diff --git a/.github/scripts/filter.py b/.github/scripts/filter.py new file mode 100644 index 000000000..bbdba868a --- /dev/null +++ b/.github/scripts/filter.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os + + +def main(): + """ + For filtering out certain CUDA versions in the build matrix that + determines which nightly builds are run. This ensures TorchRec is + always consistent in version compatibility with FBGEMM. + """ + + full_matrix_string = os.environ["MAT"] + full_matrix = json.loads(full_matrix_string) + + new_matrix_entries = [] + + for entry in full_matrix["include"]: + new_matrix_entries.append(entry) + + new_matrix = {"include": new_matrix_entries} + print(json.dumps(new_matrix)) + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/install_libs.sh b/.github/scripts/install_libs.sh new file mode 100644 index 000000000..27522ff92 --- /dev/null +++ b/.github/scripts/install_libs.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +echo "CU_VERSION: ${CU_VERSION}" +echo "CHANNEL: ${CHANNEL}" +echo "CONDA_ENV: ${CONDA_ENV}" + +if [[ $CU_VERSION = cu* ]]; then + # Setting LD_LIBRARY_PATH fixes the runtime error with fbgemm_gpu not + # being able to locate libnvrtc.so + echo "[NOVA] Setting LD_LIBRARY_PATH ..." + conda env config vars set -p ${CONDA_ENV} \ + LD_LIBRARY_PATH="/usr/local/lib:${CUDA_HOME}/lib64:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}" +else + echo "[NOVA] Setting LD_LIBRARY_PATH ..." + conda env config vars set -p ${CONDA_ENV} \ + LD_LIBRARY_PATH="/usr/local/lib:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}" +fi + +if [ "$CHANNEL" = "nightly" ]; then + ${CONDA_RUN} pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/"$CU_VERSION" +elif [ "$CHANNEL" = "test" ]; then + ${CONDA_RUN} pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/test/"$CU_VERSION" +fi + + +${CONDA_RUN} pip install importlib-metadata diff --git a/.github/scripts/tests_to_skip.txt b/.github/scripts/tests_to_skip.txt new file mode 100644 index 000000000..7c0e95f00 --- /dev/null +++ b/.github/scripts/tests_to_skip.txt @@ -0,0 +1 @@ +_disabled_in_oss_compatibility diff --git a/.github/scripts/validate_binaries.sh b/.github/scripts/validate_binaries.sh new file mode 100755 index 000000000..28060c868 --- /dev/null +++ b/.github/scripts/validate_binaries.sh @@ -0,0 +1,167 @@ +#!/usr/bin/env bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree.
+ + +export PYTORCH_CUDA_PKG="" +export CONDA_ENV="build_binary" + +if [[ ${MATRIX_PYTHON_VERSION} = '3.13t' ]]; then + echo "Conda doesn't support 3.13t yet, you can just try \`conda create -n test python=3.13t\`" + exit 0 +fi + +conda create -y -n "${CONDA_ENV}" python="${MATRIX_PYTHON_VERSION}" + +conda run -n build_binary python --version + +# Install pytorch, torchrec and fbgemm as per +# installation instructions on following page +# https://github.com/pytorch/torchrec#installations + +if [[ ${MATRIX_GPU_ARCH_TYPE} = 'rocm' ]]; then + echo "We don't support rocm" + exit 0 +fi + +# figure out CUDA VERSION +if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then + if [[ ${MATRIX_GPU_ARCH_VERSION} = '11.8' ]]; then + export CUDA_VERSION="cu118" + elif [[ ${MATRIX_GPU_ARCH_VERSION} = '12.1' ]]; then + export CUDA_VERSION="cu121" + elif [[ ${MATRIX_GPU_ARCH_VERSION} = '12.6' ]]; then + export CUDA_VERSION="cu126" + elif [[ ${MATRIX_GPU_ARCH_VERSION} = '12.8' ]]; then + export CUDA_VERSION="cu128" + else + export CUDA_VERSION="cu126" + fi +else + export CUDA_VERSION="cpu" +fi + +# figure out URL +if [[ ${MATRIX_CHANNEL} = 'nightly' ]]; then + export PYTORCH_URL="https://download.pytorch.org/whl/nightly/$%7BCUDA_VERSION%7D" +elif [[ ${MATRIX_CHANNEL} = 'test' ]]; then + export PYTORCH_URL="https://download.pytorch.org/whl/test/$%7BCUDA_VERSION%7D" +elif [[ ${MATRIX_CHANNEL} = 'release' ]]; then + export PYTORCH_URL="https://download.pytorch.org/whl/$%7BCUDA_VERSION%7D" +fi + + +echo "CU_VERSION: ${CUDA_VERSION}" +echo "MATRIX_CHANNEL: ${MATRIX_CHANNEL}" +echo "CONDA_ENV: ${CONDA_ENV}" + +# shellcheck disable=SC2155 +export CONDA_PREFIX=$(conda run -n "${CONDA_ENV}" printenv CONDA_PREFIX) + + +# Set LD_LIBRARY_PATH to fix the runtime error with fbgemm_gpu not +# being able to locate libnvrtc.so +# NOTE: The order of the entries in LD_LIBRARY_PATH matters +echo "[NOVA] Setting LD_LIBRARY_PATH ..." 
+conda env config vars set -n ${CONDA_ENV} \ + LD_LIBRARY_PATH="${CONDA_PREFIX}/lib:/usr/local/lib:/usr/lib64:${LD_LIBRARY_PATH}" + + +# install pytorch +# switch back to conda once torch nightly is fixed +# if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then +# export PYTORCH_CUDA_PKG="pytorch-cuda=${MATRIX_GPU_ARCH_VERSION}" +# fi + +conda run -n "${CONDA_ENV}" pip install importlib-metadata + +conda run -n "${CONDA_ENV}" pip install torch --index-url "$PYTORCH_URL" + +# install fbgemm +conda run -n "${CONDA_ENV}" pip install fbgemm-gpu --index-url "$PYTORCH_URL" + +# install requirements from pypi +conda run -n "${CONDA_ENV}" pip install torchmetrics==1.0.3 + +# install tensordict from pypi +conda run -n "${CONDA_ENV}" pip install tensordict==0.7.1 + +# install torchrec +conda run -n "${CONDA_ENV}" pip install torchrec --index-url "$PYTORCH_URL" + +# Run small import test +conda run -n "${CONDA_ENV}" python -c "import torch; import fbgemm_gpu; import torchrec" + +# check directory +ls -R + +# check if cuda available +conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.cuda.is_available())" + +# check cuda version +conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.version.cuda)" + +# Finally run smoke test +# python 3.11 needs torchx-nightly +conda run -n "${CONDA_ENV}" pip install torchx-nightly iopath +if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then + conda run -n "${CONDA_ENV}" torchx run -s local_cwd dist.ddp -j 1 --gpu 2 --script test_installation.py +else + conda run -n "${CONDA_ENV}" torchx run -s local_cwd dist.ddp -j 1 --script test_installation.py -- --cpu_only +fi + + +# redo for pypi release + +if [[ ${MATRIX_CHANNEL} != 'release' ]]; then + exit 0 +else + # Check version matches only for release binaries + torchrec_version=$(conda run -n "${CONDA_ENV}" pip show torchrec | grep Version | cut -d' ' -f2) + fbgemm_version=$(conda run -n "${CONDA_ENV}" pip show fbgemm_gpu | grep Version | cut -d' ' -f2) + + if [ "$torchrec_version" != "$fbgemm_version" ]; then + echo "Error: TorchRec package version does not match FBGEMM package version" + exit 1 + fi +fi + +conda create -y -n "${CONDA_ENV}" python="${MATRIX_PYTHON_VERSION}" + +conda run -n "${CONDA_ENV}" python --version + +if [[ ${MATRIX_GPU_ARCH_VERSION} != '12.4' ]]; then + exit 0 +fi + +echo "checking pypi release" +conda run -n "${CONDA_ENV}" pip install torch +conda run -n "${CONDA_ENV}" pip install fbgemm-gpu +conda run -n "${CONDA_ENV}" pip install torchrec + +# Check version matching again for PyPI +torchrec_version=$(conda run -n "${CONDA_ENV}" pip show torchrec | grep Version | cut -d' ' -f2) +fbgemm_version=$(conda run -n "${CONDA_ENV}" pip show fbgemm_gpu | grep Version | cut -d' ' -f2) + +if [ "$torchrec_version" != "$fbgemm_version" ]; then + echo "Error: TorchRec package version does not match FBGEMM package version" + exit 1 +fi + +# check directory +ls -R + +# check if cuda available +conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.cuda.is_available())" + +# check cuda version +conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.version.cuda)" + +# python 3.11 needs torchx-nightly +conda run -n "${CONDA_ENV}" pip install torchx-nightly iopath + +# Finally run smoke test +conda run -n "${CONDA_ENV}" torchx run -s local_cwd dist.ddp -j 1 --gpu 2 --script test_installation.py diff --git a/.github/workflows/build-wheels-linux.yml b/.github/workflows/build-wheels-linux.yml new file mode 100644 index 000000000..b65648ac2 --- /dev/null +++ 
b/.github/workflows/build-wheels-linux.yml @@ -0,0 +1,64 @@ +name: Build Linux Wheels + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + tags: + # Release candidate tags look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + workflow_dispatch: + +permissions: + id-token: write + contents: read + +jobs: + generate-matrix: + uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main + with: + package-type: wheel + os: linux + test-infra-repository: pytorch/test-infra + test-infra-ref: main + with-rocm: false + filter-matrix: + needs: generate-matrix + runs-on: linux.20_04.4x + outputs: + matrix: ${{ steps.filter.outputs.matrix }} + steps: + - uses: actions/setup-python@v4 + - name: Checkout torchrec repository + uses: actions/checkout@v4 + with: + repository: pytorch/torchrec + - name: Filter Generated Built Matrix + id: filter + env: + MAT: ${{ needs.generate-matrix.outputs.matrix }} + run: | + set -ex + pwd + ls + MATRIX_BLOB="$(python .github/scripts/filter.py)" + echo "${MATRIX_BLOB}" + echo "matrix=${MATRIX_BLOB}" >> "${GITHUB_OUTPUT}" + build: + needs: filter-matrix + name: pytorch/torchrec + uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main + with: + repository: pytorch/torchrec + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.filter-matrix.outputs.matrix }} + pre-script: "" + post-script: .github/scripts/install_libs.sh + package-name: torchrec + smoke-test-script: "" + trigger-event: ${{ github.event_name }} diff --git a/.github/workflows/build_dynamic_embedding_wheels.yml b/.github/workflows/build_dynamic_embedding_wheels.yml index cdf4d5f76..da8174812 100644 --- a/.github/workflows/build_dynamic_embedding_wheels.yml +++ b/.github/workflows/build_dynamic_embedding_wheels.yml @@ -20,27 +20,52 @@ jobs: fail-fast: false matrix: os: [ ubuntu-latest ] - pyver: [ cp37, cp38, cp39, cp310 ] - cuver: [ "11.6", "11.3"] + pyver: [ cp39, cp310, cp311, cp312 ] + cuver: [ "12.1", "12.4"] steps: - - uses: actions/checkout@v3 + - + name: Check disk space + run: df . -h + + - name: Remove unnecessary files + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + + - + name: Check disk space + run: df . 
-h + + - uses: actions/checkout@v4 with: submodules: recursive - - uses: pypa/cibuildwheel@v2.8.0 - with: + - uses: pypa/cibuildwheel@v2.20.0 + with: package-dir: contrib/dynamic_embedding env: CIBW_BEFORE_BUILD: "env CUDA_VERSION=${{ matrix.cuver }} contrib/dynamic_embedding/tools/before_linux_build.sh" CIBW_BUILD: "${{ matrix.pyver }}-manylinux_x86_64" CIBW_REPAIR_WHEEL_COMMAND: "env CUDA_VERSION=${{ matrix.cuver }} contrib/dynamic_embedding/tools/repair_wheel.sh {wheel} {dest_dir}" + CIBW_MANYLINUX_X86_64_IMAGE: "manylinux_2_28" - name: Verify clean directory run: git diff --exit-code shell: bash - name: Upload wheels - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: + name: artifact-${{ matrix.os }}-${{ matrix.pyver }}-cu${{ matrix.cuver }} path: wheelhouse/*.whl + + merge: + runs-on: ubuntu-latest + needs: build_wheels + steps: + - name: Merge Artifacts + uses: actions/upload-artifact/merge@v4 + with: + name: artifact + pattern: artifact-* diff --git a/.github/workflows/cpp_unittest_ci_cpu.yml b/.github/workflows/cpp_unittest_ci_cpu.yml new file mode 100644 index 000000000..ad9e6e13f --- /dev/null +++ b/.github/workflows/cpp_unittest_ci_cpu.yml @@ -0,0 +1,63 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: CPU Unit Test C++ CI + +on: + push: + paths-ignore: + - "docs/*" + - "third_party/*" + - .gitignore + - "*.md" + pull_request: + paths-ignore: + - "docs/*" + - "third_party/*" + - .gitignore + - "*.md" + +jobs: + build_test: + strategy: + fail-fast: false + matrix: + include: + - os: linux.2xlarge + python-version: 3.9 + python-tag: "py39" + - os: linux.2xlarge + python-version: '3.10' + python-tag: "py310" + - os: linux.2xlarge + python-version: '3.11' + python-tag: "py311" + - os: linux.2xlarge + python-version: '3.12' + python-tag: "py312" + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + with: + runner: ${{ matrix.os }} + timeout: 15 + script: | + ldd --version + conda create -y --name build_binary python=${{ matrix.python-version }} + conda info + python --version + echo "Starting C++ Tests" + conda install -n build_binary -y gxx_linux-64 + conda run -n build_binary \ + x86_64-conda-linux-gnu-g++ --version + conda install -n build_binary -c anaconda redis -y + conda run -n build_binary redis-server --daemonize yes + mkdir cpp-build + cd cpp-build + conda run -n build_binary cmake \ + -DBUILD_TEST=ON \ + -DBUILD_REDIS_IO=ON \ + -DCMAKE_PREFIX_PATH=/opt/conda/envs/build_binary/lib/python${{ matrix.python-version }}/site-packages/torch/share/cmake .. + conda run -n build_binary make -j + conda run -n build_binary ctest -V . 
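The `-DCMAKE_PREFIX_PATH` value hard-coded in the C++ test workflow above is simply the Torch CMake package directory inside the `build_binary` environment's site-packages. A minimal sketch of how that location can be resolved rather than hard-coded, assuming `torch` is importable in the target environment:

```
# Minimal sketch: resolve the Torch CMake prefix instead of hard-coding it.
# Assumes torch is installed in the currently active (conda) environment.
import torch.utils

# Resolves to <site-packages>/torch/share/cmake, the directory passed to
# CMake via -DCMAKE_PREFIX_PATH in the workflow above.
print(torch.utils.cmake_prefix_path)
```

The printed path could be substituted for the hard-coded `/opt/conda/envs/build_binary/...` prefix if the workflow ever needs to support a different environment layout.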
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index c69be055e..360668d4b 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -5,24 +5,30 @@ on: branches: - main workflow_dispatch: + pull_request: + jobs: build_docs_job: runs-on: ${{ matrix.os }} + permissions: + # Grant write permission here so that the doc can be pushed to gh-pages branch + contents: write strategy: matrix: include: - - os: linux.2xlarge - python-version: 3.7 + - os: linux.20_04.4x + python-version: 3.9 + python-tag: "py39" steps: - name: Check ldd --version run: ldd --version - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 # Update references - name: Update pip run: | - sudo yum update -y - sudo yum -y install git python3-pip + sudo apt-get update + sudo apt-get -y install python3-pip sudo pip3 install --upgrade pip - name: Setup conda run: | @@ -45,17 +51,24 @@ jobs: - name: Install gcc shell: bash run: | - sudo yum group install -y "Development Tools" + sudo apt-get install build-essential - name: setup Path run: | echo /usr/local/bin >> $GITHUB_PATH - name: Install PyTorch shell: bash run: | - conda run -n build_binary python -m pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + conda run -n build_binary pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu + - name: Install fbgemm + run: | + conda run -n build_binary pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu + - name: Install torchmetrics + run: | + conda run -n build_binary pip install torchmetrics==1.0.3 - name: Install TorchRec run: | - conda run -n build_binary python -m pip install --pre torchrec_nightly_cpu -f https://download.pytorch.org/whl/nightly/torchrec_nightly_cpu/index.html + conda run -n build_binary pip install -r requirements.txt + conda run -n build_binary python setup.py bdist_wheel --python-tag=${{ matrix.python-tag }} - name: Test fbgemm_gpu and torchrec installation shell: bash run: | @@ -69,11 +82,42 @@ jobs: cd ./docs conda run -n build_binary make html cd .. + - name: Upload Built-Docs + uses: actions/upload-artifact@v4 + with: + name: Built-Docs + path: docs/build/html/ - name: Get output time run: echo "The time was ${{ steps.build.outputs.time }}" - name: Deploy + if: github.ref == 'refs/heads/main' uses: JamesIves/github-pages-deploy-action@releases/v3 with: ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }} BRANCH: gh-pages # The branch the action should deploy to. FOLDER: docs/build/html # The folder the action should deploy. 
+ + doc-preview: + runs-on: [linux.2xlarge] + needs: build_docs_job + if: ${{ github.event_name == 'pull_request' }} + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Download artifact + uses: actions/download-artifact@v4 + with: + name: Built-Docs + path: docs + - name: Add no-index tag + run: | + find docs -name "*.html" -print0 | xargs -0 sed -i '/<head>/a \ \ <meta name="robots" content="noindex">'; + - name: Upload docs preview + uses: seemethere/upload-artifact-s3@v5 + if: ${{ github.event_name == 'pull_request' }} + with: + retention-days: 14 + s3-bucket: doc-previews + if-no-files-found: error + path: docs + s3-prefix: pytorch/torchrec/${{ github.event.pull_request.number }} diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml deleted file mode 100644 index be58565f1..000000000 --- a/.github/workflows/nightly_build.yml +++ /dev/null @@ -1,207 +0,0 @@ -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions -# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions - -name: Push Binary Nightly - -on: - workflow_call: - secrets: - AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID: - required: true - AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY: - required: true - PYPI_TOKEN: - required: false - # run every day at 11:15am - schedule: - - cron: '15 11 * * *' - # or manually trigger it - workflow_dispatch: - -jobs: - # build on cpu hosts and upload to GHA - build_on_cpu: - runs-on: ${{ matrix.os }} - strategy: - matrix: - include: - - os: linux.2xlarge - python-version: 3.7 - python-tag: "py37" - cuda-tag: "cu11" - - os: linux.2xlarge - python-version: 3.8 - python-tag: "py38" - cuda-tag: "cu11" - - os: linux.2xlarge - python-version: 3.9 - python-tag: "py39" - cuda-tag: "cu11" - - os: linux.2xlarge - python-version: '3.10' - python-tag: "py310" - cuda-tag: "cu11" - steps: - # Checkout the repository to the GitHub Actions runner - - name: Check ldd --version - run: ldd --version - - name: Checkout - uses: actions/checkout@v2 - - name: Update pip - run: | - sudo yum update -y - sudo yum -y install git python3-pip - sudo pip3 install --upgrade pip - - name: Setup conda - run: | - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh - bash ~/miniconda.sh -b -p $HOME/miniconda -u - - name: setup Path - run: | - echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH - echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH - - name: create conda env - run: | - conda create --name build_binary python=${{ matrix.python-version }} - conda info - - name: check python version no Conda - run: | - python --version - - name: check python version - run: | - conda run -n build_binary python --version - - name: Install C/C++ compilers - run: | - sudo yum install -y gcc gcc-c++ - - name: Install PyTorch and CUDA - shell: bash - run: | - conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia - - name: Install Dependencies - shell: bash - run: | - conda run -n build_binary python -m pip install -r requirements.txt - - name: Test Installation of dependencies - run: | - conda run -n build_binary python -c "import torch.distributed" - echo "torch.distributed succeeded" - conda run -n build_binary python -c "import skbuild" - echo "skbuild succeeded" - conda run -n build_binary python -c "import numpy" - echo "numpy succeeded" - # for the conda run with quotes, we have to use "\" and double quotes - # here is the issue: 
https://github.com/conda/conda/issues/10972 - - name: Build TorchRec Nightly - run: | - rm -r dist || true - conda run -n build_binary \ - python setup.py bdist_wheel \ - --package_name torchrec-nightly \ - --python-tag=${{ matrix.python-tag }} - - name: Upload wheel as GHA artifact - uses: actions/upload-artifact@v2 - with: - name: torchrec_nightly_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl - path: dist/torchrec_nightly-*.whl - - # download from GHA, test on gpu and push to pypi - test_on_gpu: - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [linux.4xlarge.nvidia.gpu] - python-version: [3.7, 3.8, 3.9] - cuda-tag: ["cu11"] - needs: build_on_cpu - # the glibc version should match the version of the one we used to build the binary - # for this case, it's 2.26 - steps: - - name: Check ldd --version - run: ldd --version - - name: check cpu info - shell: bash - run: | - cat /proc/cpuinfo - - name: check distribution info - shell: bash - run: | - cat /proc/version - - name: Display EC2 information - shell: bash - run: | - set -euo pipefail - function get_ec2_metadata() { - # Pulled from instance metadata endpoint for EC2 - # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html - category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/$%7Bcategory%7D" - } - echo "ami-id: $(get_ec2_metadata ami-id)" - echo "instance-id: $(get_ec2_metadata instance-id)" - echo "instance-type: $(get_ec2_metadata instance-type)" - - name: check gpu info - shell: bash - run: | - sudo yum install lshw -y - sudo lshw -C display - # Checkout the repository to the GitHub Actions runner - - name: Checkout - uses: actions/checkout@v2 - - name: Update pip - run: | - sudo yum update -y - sudo yum -y install git python3-pip - sudo pip3 install --upgrade pip - - name: Setup conda - run: | - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh - bash ~/miniconda.sh -b -p $HOME/miniconda - - name: setup Path - run: | - echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH - echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH - - name: create conda env - run: | - conda create --name build_binary python=${{ matrix.python-version }} - conda info - - name: check python version no Conda - run: | - python --version - - name: check python version - run: | - conda run -n build_binary python --version - - name: Install C/C++ compilers - run: | - sudo yum install -y gcc gcc-c++ - - name: Install PyTorch and CUDA - shell: bash - run: | - conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia - # download wheel from GHA - - name: Download wheel - uses: actions/download-artifact@v2 - with: - name: torchrec_nightly_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl - - name: Display structure of downloaded files - run: ls -R - - name: Install TorchRec Nightly - run: | - rm -r dist || true - conda run -n build_binary python -m pip install *.whl - - name: Test torchrec installation - shell: bash - run: | - conda run -n build_binary \ - python -c "import torchrec" - - name: Push TorchRec Binary to PYPI - env: - PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} - run: | - conda run -n build_binary python -m pip install twine - conda run -n build_binary \ - python -m twine upload \ - --username __token__ \ - --password "$PYPI_TOKEN" \ - --skip-existing \ - torchrec_nightly-*.whl \ - --verbose diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index eff5ce19b..74a5af78c 100644 
--- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -10,15 +10,15 @@ jobs: runs-on: ubuntu-latest steps: - name: Setup Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.9 architecture: x64 packages: | - ufmt==1.3.2 - black==22.3.0 - usort==1.0.2 + ufmt==2.5.1 + black==24.2.0 + usort==1.0.8 - name: Checkout Torchrec - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Run pre-commit - uses: pre-commit/action@v2.0.3 + uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/pyre.yml b/.github/workflows/pyre.yml new file mode 100644 index 000000000..fd773c787 --- /dev/null +++ b/.github/workflows/pyre.yml @@ -0,0 +1,27 @@ +name: Pyre Check + +on: + push: + branches: [main] + pull_request: + +jobs: + pyre-check: + runs-on: ubuntu-22.04 + defaults: + run: + shell: bash -el {0} + steps: + - uses: conda-incubator/setup-miniconda@v2 + with: + python-version: 3.9 + - name: Checkout Torchrec + uses: actions/checkout@v4 + - name: Install dependencies + run: > + pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu && + pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu && + pip install -r requirements.txt && + pip install pyre-check-nightly==$(cat .pyre_configuration | grep version | awk '{print $2}' | sed 's/\"//g') + - name: Pyre check + run: pyre check diff --git a/.github/workflows/release_build.yml b/.github/workflows/release_build.yml index 58c7c55fb..1ea837d4b 100644 --- a/.github/workflows/release_build.yml +++ b/.github/workflows/release_build.yml @@ -6,10 +6,6 @@ name: Push Binary Release on: workflow_call: secrets: - AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID: - required: true - AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY: - required: true PYPI_TOKEN: required: false workflow_dispatch: @@ -22,33 +18,32 @@ jobs: strategy: matrix: include: - - os: linux.2xlarge - python-version: 3.7 - python-tag: "py37" - cuda-tag: "cu11" - - os: linux.2xlarge - python-version: 3.8 - python-tag: "py38" - cuda-tag: "cu11" - os: linux.2xlarge python-version: 3.9 python-tag: "py39" - cuda-tag: "cu11" + cuda-tag: "cu124" - os: linux.2xlarge python-version: '3.10' python-tag: "py310" - cuda-tag: "cu11" + cuda-tag: "cu124" + - os: linux.2xlarge + python-version: '3.11' + python-tag: "py311" + cuda-tag: "cu124" + - os: linux.2xlarge + python-version: '3.12' + python-tag: "py312" + cuda-tag: "cu124" steps: # Checkout the repository to the GitHub Actions runner - name: Check ldd --version run: ldd --version - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Update pip run: | sudo yum update -y sudo yum -y install git python3-pip - sudo pip3 install --upgrade pip - name: Setup conda run: | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh @@ -73,7 +68,12 @@ jobs: - name: Install PyTorch and CUDA shell: bash run: | - conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-test -c nvidia + conda run -n build_binary pip install torch + - name: Install fbgemm + shell: bash + run: | + conda run -n build_binary pip install numpy + conda run -n build_binary pip install fbgemm-gpu - name: Install Dependencies shell: bash run: | @@ -89,26 +89,27 @@ jobs: # for the conda run with quotes, we have to use "\" and double quotes # here is the issue: https://github.com/conda/conda/issues/10972 - name: Build TorchRec + env: + OFFICIAL_RELEASE: 1 run: | rm -r dist || true conda run -n build_binary \ 
python setup.py bdist_wheel \ - --package_name torchrec \ --python-tag=${{ matrix.python-tag }} - name: Upload wheel as GHA artifact - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: torchrec_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl path: dist/torchrec-*.whl - # download from GHA, test on gpu and push to pypi - test_on_gpu: + # download from GHA, sanity check on gpu and push to pypi + sanity_check_on_gpu_and_push: runs-on: ${{ matrix.os }} strategy: matrix: - os: [linux.4xlarge.nvidia.gpu] - python-version: [3.7, 3.8, 3.9] - cuda-tag: ["cu11"] + os: [linux.g5.12xlarge.nvidia.gpu] + python-version: [3.9, "3.10", "3.11", "3.12"] + cuda-tag: ["cu124"] needs: build_on_cpu # the glibc version should match the version of the one we used to build the binary # for this case, it's 2.26 @@ -131,7 +132,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/$%7Bcategory%7D" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/$%7Bcategory%7D" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -143,12 +144,11 @@ jobs: sudo lshw -C display # Checkout the repository to the GitHub Actions runner - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Update pip run: | sudo yum update -y sudo yum -y install git python3-pip - sudo pip3 install --upgrade pip - name: Setup conda run: | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh @@ -173,10 +173,19 @@ jobs: - name: Install PyTorch and CUDA shell: bash run: | - conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-test -c nvidia + conda run -n build_binary pip install torch # download wheel from GHA + - name: Install fbgemm + shell: bash + run: | + conda run -n build_binary pip install numpy + conda run -n build_binary pip install fbgemm-gpu + - name: Install torchmetrics + shell: bash + run: | + conda run -n build_binary pip install torchmetrics==1.0.3 - name: Download wheel - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v4 with: name: torchrec_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl - name: Display structure of downloaded files @@ -189,13 +198,9 @@ jobs: shell: bash run: | conda run -n build_binary \ - python -c "import torchrec" - - name: Test with pytest - run: | - conda run -n build_binary \ - python -m pip install pytest + python -c "import fbgemm_gpu" conda run -n build_binary \ - python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors + python -c "import torchrec" # Push to Pypi - name: Push TorchRec Binary to PYPI env: diff --git a/.github/workflows/unittest_ci.yml b/.github/workflows/unittest_ci.yml index 82cbf5621..5ea25f9d4 100644 --- a/.github/workflows/unittest_ci.yml +++ b/.github/workflows/unittest_ci.yml @@ -4,188 +4,120 @@ name: Unit Test CI on: - # TODO: re-enable when GPU unit tests are working - # push: - # paths-ignore: - # - "docs/*" - # - "third_party/*" - # - .gitignore - # - "*.md" - # pull_request: - # paths-ignore: - # - "docs/*" - # - "third_party/*" - # - .gitignore - # - "*.md" + push: + branches: + - nightly + - main workflow_dispatch: 
jobs: - # build on cpu hosts and upload to GHA - build_on_cpu: - runs-on: ${{ matrix.os }} + build_test: strategy: + fail-fast: false matrix: include: - - os: linux.2xlarge - # ideally we run on 3.8 and 3.9 as well, however we are limited in resources. - python-version: 3.7 - python-tag: "py37" - cuda-tag: "cu11" - steps: - # Checkout the repository to the GitHub Actions runner - - name: Check ldd --version - run: ldd --version - - name: Checkout - uses: actions/checkout@v2 - - name: Update pip - run: | - sudo yum update -y - sudo yum -y install git python3-pip - sudo pip3 install --upgrade pip - - name: Setup conda - run: | - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh - bash ~/miniconda.sh -b -p $HOME/miniconda -u - - name: setup Path - run: | - echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH - echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH - - name: create conda env - run: | - conda create --name build_binary python=${{ matrix.python-version }} + - os: linux.g5.12xlarge.nvidia.gpu + python-version: 3.9 + python-tag: "py39" + cuda-tag: "cu118" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: 3.9 + python-tag: "py39" + cuda-tag: "cu121" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: 3.9 + python-tag: "py39" + cuda-tag: "cu124" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: '3.10' + python-tag: "py310" + cuda-tag: "cu118" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: '3.10' + python-tag: "py310" + cuda-tag: "cu121" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: '3.10' + python-tag: "py310" + cuda-tag: "cu124" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: '3.11' + python-tag: "py311" + cuda-tag: "cu118" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: '3.11' + python-tag: "py311" + cuda-tag: "cu121" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: '3.11' + python-tag: "py311" + cuda-tag: "cu124" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: '3.12' + python-tag: "py312" + cuda-tag: "cu118" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: '3.12' + python-tag: "py312" + cuda-tag: "cu121" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: '3.12' + python-tag: "py312" + cuda-tag: "cu124" + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + with: + runner: ${{ matrix.os }} + timeout: 30 + script: | + ldd --version + conda create -y --name build_binary python=${{ matrix.python-version }} conda info - - name: check python version no Conda - run: | python --version - - name: check python version - run: | conda run -n build_binary python --version - - name: Install C/C++ compilers - run: | - sudo yum install -y gcc gcc-c++ - - name: Install PyTorch and CUDA - shell: bash - run: | - conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia - - name: Install Dependencies - shell: bash - run: | - conda run -n build_binary python -m pip install -r requirements.txt - - name: Test Installation of dependencies - run: | - conda run -n build_binary python -c "import torch.distributed" - echo "torch.distributed succeeded" - conda run -n build_binary python -c "import skbuild" - echo "skbuild succeeded" - conda run -n build_binary python -c "import numpy" - echo "numpy succeeded" - # for the conda run with quotes, we have to use "\" and double quotes - # here is the issue: https://github.com/conda/conda/issues/10972 - - name: Build TorchRec Binary - run: | + conda run -n 
build_binary \ + pip install torch --index-url https://download.pytorch.org/whl/nightly/${{ matrix.cuda-tag }} + conda run -n build_binary \ + python -c "import torch" + echo "torch succeeded" + conda run -n build_binary \ + python -c "import torch.distributed" + conda run -n build_binary \ + pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/${{ matrix.cuda-tag }} + conda run -n build_binary \ + python -c "import fbgemm_gpu" + echo "fbgemm_gpu succeeded" + conda run -n build_binary \ + pip install -r requirements.txt conda run -n build_binary \ python setup.py bdist_wheel \ - --package_name torchrec-test \ --python-tag=${{ matrix.python-tag }} - - name: Upload wheel as GHA artifact - uses: actions/upload-artifact@v2 - with: - name: torchrec-test_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl - path: dist/torchrec-test-*.whl - - # download from GHA, test on gpu - test_on_gpu: - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [linux.4xlarge.nvidia.gpu] - python-version: [3.7] - cuda-tag: ["cu11"] - needs: build_on_cpu - # the glibc version should match the version of the one we used to build the binary - # for this case, it's 2.26 - steps: - - name: Check ldd --version - # Run unit tests - run: ldd --version - - name: check cpu info - shell: bash - run: | - cat /proc/cpuinfo - - name: check distribution info - shell: bash - run: | - cat /proc/version - - name: Display EC2 information - shell: bash - run: | - set -euo pipefail - function get_ec2_metadata() { - # Pulled from instance metadata endpoint for EC2 - # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html - category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/$%7Bcategory%7D" - } - echo "ami-id: $(get_ec2_metadata ami-id)" - echo "instance-id: $(get_ec2_metadata instance-id)" - echo "instance-type: $(get_ec2_metadata instance-type)" - - name: check gpu info - shell: bash - run: | - sudo yum install lshw -y - sudo lshw -C display - # Checkout the repository to the GitHub Actions runner - - name: Checkout - uses: actions/checkout@v2 - - name: Update pip - run: | - sudo yum update -y - sudo yum -y install git python3-pip - sudo pip3 install --upgrade pip - - name: Setup conda - run: | - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh - bash ~/miniconda.sh -b -p $HOME/miniconda - - name: setup Path - run: | - echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH - echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH - - name: create conda env - run: | - conda create --name build_binary python=${{ matrix.python-version }} - conda info - - name: check python version no Conda - run: | - python --version - - name: check python version - run: | - conda run -n build_binary python --version - - name: Install C/C++ compilers - run: | - sudo yum install -y gcc gcc-c++ - - name: Install PyTorch and CUDA - shell: bash - run: | - conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia - # download wheel from GHA - - name: Download wheel - uses: actions/download-artifact@v2 - with: - name: torchrec-test_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl - - name: Display structure of downloaded files - run: ls -R - - name: Install TorchRec GPU - run: | - rm -r dist || true - conda run -n build_binary python -m pip install dist/*.whl - - name: Test torchrec installation - shell: bash - run: | conda run -n build_binary \ python -c "import torchrec" - - name: Test with pytest 
- run: | + echo "torch.distributed succeeded" conda run -n build_binary \ - python -m pip install pytest + python -c "import numpy" + echo "numpy succeeded" + conda install -n build_binary -y pytest + # Read the list of tests to skip from a file, ignoring empty lines and comments + skip_expression=$(awk '!/^($|#)/ {printf " and not %s", $0}' ./.github/scripts/tests_to_skip.txt) + # Check if skip_expression is effectively empty + if [ -z "$skip_expression" ]; then + skip_expression="" + else + skip_expression=${skip_expression:5} # Remove the leading " and " + fi conda run -n build_binary \ - python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors + python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors \ + --ignore=torchrec/distributed/tests/test_comm.py --ignore=torchrec/distributed/tests/test_infer_shardings.py \ + --ignore=torchrec/distributed/tests/test_keyed_jagged_tensor_pool.py --ignore=torchrec/distributed/tests/test_pt2_multiprocess.py \ + --ignore=torchrec/distributed/tests/test_pt2.py --ignore=torchrec/distributed/tests/test_quant_model_parallel.py \ + --ignore=torchrec/distributed/tests/test_quant_pruning.py --ignore=torchrec/distributed/tests/test_quant_sequence_model_parallel.py \ + --ignore-glob='torchrec/metrics/*' --ignore-glob='torchrec/distributed/tests/test_model_parallel_gloo*' \ + --ignore-glob='torchrec/inference/inference_legacy/tests*' --ignore-glob='*test_model_parallel_nccl*' \ + --ignore=torchrec/distributed/tests/test_cache_prefetch.py --ignore=torchrec/distributed/tests/test_fp_embeddingbag_single_rank.py \ + --ignore=torchrec/distributed/tests/test_infer_utils.py --ignore=torchrec/distributed/tests/test_fx_jit.py --ignore-glob=**/test_utils/ \ + --ignore-glob='*test_train_pipeline*' --ignore=torchrec/distributed/tests/test_model_parallel_hierarchical.py \ + -k "$skip_expression" diff --git a/.github/workflows/unittest_ci_cpu.yml b/.github/workflows/unittest_ci_cpu.yml index e69c4f72f..0c53bbe7d 100644 --- a/.github/workflows/unittest_ci_cpu.yml +++ b/.github/workflows/unittest_ci_cpu.yml @@ -20,70 +20,66 @@ on: jobs: build_test: strategy: + fail-fast: false matrix: include: - - os: linux.2xlarge - python-version: 3.7 - python-tag: "py37" - - os: linux.2xlarge - python-version: 3.8 - python-tag: "py38" - os: linux.2xlarge python-version: 3.9 python-tag: "py39" - os: linux.2xlarge python-version: '3.10' python-tag: "py310" - uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + - os: linux.2xlarge + python-version: '3.11' + python-tag: "py311" + - os: linux.2xlarge + python-version: '3.12' + python-tag: "py312" + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read with: runner: ${{ matrix.os }} - timeout: 30 + timeout: 15 script: | ldd --version conda create -y --name build_binary python=${{ matrix.python-version }} conda info python --version conda run -n build_binary python --version - conda install -n build_binary \ - --yes \ - -c pytorch-nightly \ - "pytorch-nightly"::pytorch[build="*cpu*"] conda run -n build_binary \ - pip install -r requirements.txt + pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu conda run -n build_binary \ - pip uninstall fbgemm_gpu-nightly -y - conda run -n build_binary \ - pip install fbgemm-gpu-nightly-cpu + python -c "import torch" + echo "torch succeeded" conda run -n build_binary \ python -c "import torch.distributed" - echo 
"torch.distributed succeeded" conda run -n build_binary \ - python -c "import skbuild" - echo "skbuild succeeded" - conda run -n build_binary \ - python -c "import numpy" - echo "numpy succeeded" + pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu conda run -n build_binary \ python -c "import fbgemm_gpu" echo "fbgemm_gpu succeeded" - + conda run -n build_binary \ + pip install -r requirements.txt conda run -n build_binary \ python setup.py bdist_wheel \ - --package_name torchrec-test-cpu \ --python-tag=${{ matrix.python-tag }} conda run -n build_binary \ python -c "import torchrec" - conda install -n build_binary -y pytest + echo "torch.distributed succeeded" conda run -n build_binary \ - python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors -k 'not test_sharding_gloo_cw' - echo "Starting C++ Tests" - conda install -n build_binary -y gxx_linux-64 + python -c "import numpy" + echo "numpy succeeded" + conda install -n build_binary -y pytest + # Read the list of tests to skip from a file, ignoring empty lines and comments + skip_expression=$(awk '!/^($|#)/ {printf " and not %s", $0}' ./.github/scripts/tests_to_skip.txt) + # Check if skip_expression is effectively empty + if [ -z "$skip_expression" ]; then + skip_expression="" + else + skip_expression=${skip_expression:5} # Remove the leading " and " + fi conda run -n build_binary \ - x86_64-conda-linux-gnu-g++ --version - mkdir cpp-build - cd cpp-build - conda run -n build_binary cmake \ - -DBUILD_TEST=ON \ - -DCMAKE_PREFIX_PATH=/opt/conda/envs/build_binary/lib/python${{ matrix.python-version }}/site-packages/torch/share/cmake .. - conda run -n build_binary make -j - conda run -n build_binary ctest -v . + python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors \ + --ignore-glob=**/test_utils/ -k "$skip_expression" diff --git a/.github/workflows/validate-binaries.yml b/.github/workflows/validate-binaries.yml new file mode 100644 index 000000000..248857214 --- /dev/null +++ b/.github/workflows/validate-binaries.yml @@ -0,0 +1,41 @@ +name: Validate binaries + +on: + workflow_call: + inputs: + channel: + description: "Channel to use (nightly, release)" + required: false + type: string + default: release + ref: + description: 'Reference to checkout, defaults to empty' + default: "" + required: false + type: string + workflow_dispatch: + inputs: + channel: + description: "Channel to use (nightly, release, test, pypi)" + required: true + type: choice + options: + - release + - nightly + - test + ref: + description: 'Reference to checkout, defaults to empty' + default: "" + required: false + type: string + +jobs: + validate-binaries: + uses: pytorch/test-infra/.github/workflows/validate-domain-library.yml@main + with: + package_type: "wheel" + os: "linux" + channel: ${{ inputs.channel }} + repository: "pytorch/torchrec" + smoke_test: "source ./.github/scripts/validate_binaries.sh" + with_cuda: enable diff --git a/.github/workflows/validate-nightly-binaries.yml b/.github/workflows/validate-nightly-binaries.yml new file mode 100644 index 000000000..6d6369495 --- /dev/null +++ b/.github/workflows/validate-nightly-binaries.yml @@ -0,0 +1,26 @@ +# Scheduled validation of the nightly binaries +name: validate-nightly-binaries + +on: + schedule: + # At 5:30 pm UTC (7:30 am PDT) + - cron: "30 17 * * *" + # Have the ability to trigger this job manually through the API + workflow_dispatch: + push: + branches: + - main + paths: + - 
.github/workflows/validate-nightly-binaries.yml + - .github/workflows/validate-binaries.yml + - .github/scripts/validate-binaries.sh + pull_request: + paths: + - .github/workflows/validate-nightly-binaries.yml + - .github/workflows/validate-binaries.yml + - .github/scripts/validate-binaries.sh +jobs: + nightly: + uses: ./.github/workflows/validate-binaries.yml + with: + channel: nightly diff --git a/.lintrunner.toml b/.lintrunner.toml index 23a76193a..8fd1bbc1f 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -11,7 +11,7 @@ init_command = [ 'python3', 'tools/lint/pip_init.py', '--dry-run={{DRYRUN}}', - 'black==22.3.0', + 'black==24.2.0', ] is_formatter = true @@ -28,6 +28,6 @@ init_command = [ 'python3', 'tools/lint/pip_init.py', '--dry-run={{DRYRUN}}', - 'usort==1.0.2', + 'usort==1.0.8', ] is_formatter = true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ea648c6ef..705694ea9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.5.0 hooks: - id: check-toml - id: check-yaml @@ -8,9 +8,9 @@ repos: - id: end-of-file-fixer - repo: https://github.com/omnilib/ufmt - rev: v1.3.2 + rev: v2.5.1 hooks: - id: ufmt additional_dependencies: - - black == 22.3.0 - - usort == 1.0.2 + - black == 24.2.0 + - usort == 1.0.8.post1 diff --git a/.pyre_configuration b/.pyre_configuration new file mode 100644 index 000000000..4249f9c02 --- /dev/null +++ b/.pyre_configuration @@ -0,0 +1,17 @@ +{ + "exclude": [ + ".*/pyre-check/stubs/.*", + ".*/torchrec/datasets*", + ".*/torchrec/models*", + ".*/torchrec/inference/client.py" + ], + "site_package_search_strategy": "all", + "source_directories": [ + { + "import_root": ".", + "source": "torchrec" + } + ], + "strict": true, + "version": "0.0.101729681899" +} diff --git a/README.MD b/README.MD index 69c40d6ca..44fc026f6 100644 --- a/README.MD +++ b/README.MD @@ -1,85 +1,98 @@ -# TorchRec (Beta Release) -[Docs](https://pytorch.org/torchrec/) +# TorchRec -TorchRec is a PyTorch domain library built to provide common sparsity & parallelism primitives needed for large-scale recommender systems (RecSys). It allows authors to train models with large embedding tables sharded across many GPUs. +**TorchRec** is a PyTorch domain library built to provide common sparsity and parallelism primitives needed for large-scale recommender systems (RecSys). TorchRec allows training and inference of models with large embedding tables sharded across many GPUs and **powers many production RecSys models at Meta**. 
-## TorchRec contains: +## External Presence +TorchRec has been used to accelerate advancements in recommendation systems, some examples: +* [Latest version of Meta's DLRM (Deep Learning Recommendation Model)](https://github.com/facebookresearch/dlrm) is built using TorchRec +* [Disaggregated Multi-Tower: Topology-aware Modeling Technique for Efficient Large-Scale Recommendation](https://arxiv.org/abs/2403.00877) paper +* [The Algorithm ML](https://github.com/twitter/the-algorithm-ml) from Twitter +* [Training Recommendation Models with Databricks](https://docs.databricks.com/en/machine-learning/train-recommender-models.html) +* [Toward 100TB model with Embedding Offloading Paper](https://dl.acm.org/doi/10.1145/3640457.3688037) + + +## Introduction + +To begin learning about TorchRec, check out: +* Our complete [TorchRec Tutorial](https://pytorch.org/tutorials/intermediate/torchrec_intro_tutorial.html) +* The [TorchRec documentation](https://pytorch.org/torchrec/) for an overview of TorchRec and API references + + +### TorchRec Features - Parallelism primitives that enable easy authoring of large, performant multi-device/multi-node models using hybrid data-parallelism/model-parallelism. -- The TorchRec sharder can shard embedding tables with different sharding strategies including data-parallel, table-wise, row-wise, table-wise-row-wise, and column-wise sharding. -- The TorchRec planner can automatically generate optimized sharding plans for models. -- Pipelined training overlaps dataloading device transfer (copy to GPU), inter-device communications (input_dist), and computation (forward, backward) for increased performance. -- Optimized kernels for RecSys powered by FBGEMM. -- Quantization support for reduced precision training and inference. +- Sharders to shard embedding tables with different strategies including data-parallel, table-wise, row-wise, table-wise-row-wise, column-wise, and table-wise-column-wise sharding. +- Planner that can automatically generate optimized sharding plans for models. +- Pipelined training overlapping dataloading device transfer (copy to GPU), inter-device communications (input_dist), and computation (forward, backward) for increased performance. +- Optimized kernels for RecSys powered by [FBGEMM](https://github.com/pytorch/FBGEMM/tree/main). +- Quantization support for reduced precision training and inference, along with optimizing a TorchRec model for C++ inference. - Common modules for RecSys. -- Production-proven model architectures for RecSys. - RecSys datasets (criteo click logs and movielens) - Examples of end-to-end training such the dlrm event prediction model trained on criteo click logs dataset. -# Installation -Torchrec requires Python >= 3.7 and CUDA >= 11.0 (CUDA is highly recommended for performance but not required). The example below shows how to install with CUDA 11.6. This setup assumes you have conda installed. +## Installation -## Binaries +Check out the [Getting Started](https://pytorch.org/torchrec/setup-torchrec.html) section in the documentation for recommended ways to set up Torchrec. -Experimental binary on Linux for Python 3.7, 3.8 and 3.9 can be installed via pip wheels +### From Source -### Installations -``` -TO use the library without cuda, use the *-cpu fbgemm installations. However, this will be much slower than the CUDA variant. +**Generally, there isn't a need to build from source**. For most use cases, follow the section above to set up TorchRec. 
However, to build from source and to get the latest changes, do the following: -Nightly +1. Install pytorch. See [pytorch documentation](https://pytorch.org/get-started/locally/). + ``` + CUDA 12.4 -conda install pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia -pip install torchrec_nightly + pip install torch --index-url https://download.pytorch.org/whl/nightly/cu124 -Stable + CUDA 12.1 -conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia -pip install torchrec + pip install torch --index-url https://download.pytorch.org/whl/nightly/cu121 -If you have no CUDA device: + CUDA 11.8 -Nightly + pip install torch --index-url https://download.pytorch.org/whl/nightly/cu118 -pip uninstall fbgemm-gpu-nightly -y -pip install fbgemm-gpu-nightly-cpu + CPU -Stable + pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu + ``` -pip uninstall fbgemm-gpu -y -pip install fbgemm-gpu-cpu +2. Clone TorchRec. + ``` + git clone --recursive https://github.com/pytorch/torchrec + cd torchrec + ``` -``` +3. Install FBGEMM. + ``` + CUDA 12.4 + pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu124 -### Colab example: introduction + install -See our colab notebook for an introduction to torchrec which includes runnable installation. - - [Tutorial Source](https://github.com/pytorch/torchrec/blob/main/Torchrec_Introduction.ipynb) - - Open in [Google Colab](https://colab.research.google.com/github/pytorch/torchrec/blob/main/Torchrec_Introduction.ipynb) + CUDA 12.1 -## From Source + pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu121 -We are currently iterating on the setup experience. For now, we provide manual instructions on how to build from source. The example below shows how to install with CUDA 11.3. This setup assumes you have conda installed. + CUDA 11.8 -1. Install pytorch. See [pytorch documentation](https://pytorch.org/get-started/locally/) - ``` - conda install pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia + pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu118 + + CPU + + pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu ``` -2. Install Requirements +4. Install other requirements. ``` pip install -r requirements.txt ``` -3. Download and install TorchRec. +4. Install TorchRec. ``` - git clone --recursive https://github.com/pytorch/torchrec - - cd torchrec python setup.py install develop ``` -4. Test the installation. +5. Test the installation (use torchx-nightly for 3.11; for 3.12, torchx currently doesn't work). ``` GPU mode @@ -93,5 +106,32 @@ We are currently iterating on the setup experience. For now, we provide manual i 5. If you want to run a more complex example, please take a look at the torchrec [DLRM example](https://github.com/facebookresearch/dlrm/blob/main/torchrec_dlrm/dlrm_main.py). +## Contributing + +See [CONTRIBUTING.md](https://github.com/pytorch/torchrec/blob/main/CONTRIBUTING.md) for details about contributing to TorchRec! 
+ +## Citation + +If you're using TorchRec, please refer to BibTeX entry to cite this work: +``` +@inproceedings{10.1145/3523227.3547387, +author = {Ivchenko, Dmytro and Van Der Staay, Dennis and Taylor, Colin and Liu, Xing and Feng, Will and Kindi, Rahul and Sudarshan, Anirudh and Sefati, Shahin}, +title = {TorchRec: a PyTorch Domain Library for Recommendation Systems}, +year = {2022}, +isbn = {9781450392785}, +publisher = {Association for Computing Machinery}, +address = {New York, NY, USA}, +url = {https://doi.org/10.1145/3523227.3547387}, +doi = {10.1145/3523227.3547387}, +abstract = {Recommendation Systems (RecSys) comprise a large footprint of production-deployed AI today. The neural network-based recommender systems differ from deep learning models in other domains in using high-cardinality categorical sparse features that require large embedding tables to be trained. In this talk we introduce TorchRec, a PyTorch domain library for Recommendation Systems. This new library provides common sparsity and parallelism primitives, enabling researchers to build state-of-the-art personalization models and deploy them in production. In this talk we cover the building blocks of the TorchRec library including modeling primitives such as embedding bags and jagged tensors, optimized recommender system kernels powered by FBGEMM, a flexible sharder that supports a veriety of strategies for partitioning embedding tables, a planner that automatically generates optimized and performant sharding plans, support for GPU inference and common modeling modules for building recommender system models. TorchRec library is currently used to train large-scale recommender models at Meta. We will present how TorchRec helped Meta’s recommender system platform to transition from CPU asynchronous training to accelerator-based full-sync training.}, +booktitle = {Proceedings of the 16th ACM Conference on Recommender Systems}, +pages = {482–483}, +numpages = {2}, +keywords = {information retrieval, recommender systems}, +location = {Seattle, WA, USA}, +series = {RecSys '22} +} +``` + ## License TorchRec is BSD licensed, as found in the [LICENSE](LICENSE) file. 
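As a companion to the installation steps above, here is a minimal sanity check that exercises `torch`, `fbgemm_gpu`, and `torchrec` together on CPU; the table name, feature name, and sizes are arbitrary examples rather than anything defined by this repo:

```
# Minimal CPU sanity-check sketch (arbitrary table/feature names and sizes).
import torch
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig, KeyedJaggedTensor

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1",
            embedding_dim=16,
            num_embeddings=100,
            feature_names=["f1"],
        )
    ],
    device=torch.device("cpu"),
)

# Two examples for feature "f1": the first has ids [0, 1], the second has [2].
kjt = KeyedJaggedTensor(
    keys=["f1"],
    values=torch.tensor([0, 1, 2]),
    lengths=torch.tensor([2, 1]),
)

pooled = ebc(kjt)
print(pooled["f1"].shape)  # expected: torch.Size([2, 16])
```

A pooled output of shape `(2, 16)` confirms that the three packages import together and that a basic embedding-bag lookup runs end to end.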
diff --git a/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb b/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb new file mode 100644 index 000000000..015d216e4 --- /dev/null +++ b/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb @@ -0,0 +1,2972 @@ +{ + "metadata": { + "custom": { + "cells": [], + "metadata": { + "accelerator": "GPU", + "colab": { + "background_execution": "on", + "collapsed_sections": [], + "machine_shape": "hm", + "name": "Torchrec Introduction.ipynb", + "provenance": [] + }, + "fileHeader": "", + "fileUid": "c9a29462-2509-4adb-a539-0318cf56bb00", + "interpreter": { + "hash": "d4204deb07d30e7517ec64733b2d65f24aff851b061e21418071854b06459363" + }, + "isAdHoc": false, + "kernelspec": { + "display_name": "Python 3.7.13 ('torchrec': conda)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 0 + }, + "indentAmount": 2, + "last_server_session_id": "e11f329f-b395-4702-9b33-449716ea422e", + "last_kernel_id": "b6fe1a08-1d4d-40cd-afe6-8352c4e42d25", + "last_base_url": "https://bento.edge.x2p.facebook.net/", + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "last_msg_id": "c02547e3-e4c072dc430f066c4d18479a_594", + "captumWidgetMessage": [], + "outputWidgetContext": [], + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU", + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "hBgIy9eYYx35", + "originalKey": "4766a371-bf6e-4342-98fb-16dde5255d73", + "outputsInitialized": false, + "language": "markdown", + "showInput": false + }, + "source": [ + "## **Open Source Installation** (For Reference)\n", + "Requirements:\n", + "- python >= 3.9\n", + "\n", + "We highly recommend CUDA when using TorchRec. 
If using CUDA:\n", + "- cuda >= 11.8\n", + "\n", + "Installing TorchRec will also install [FBGEMM](https://github.com/pytorch/fbgemm), a collection of CUDA kernels and GPU enabled operations to run TorchRec" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "sFYvP95xaAER", + "originalKey": "27d22c43-9299-46ec-94f2-28a880546fe3", + "outputsInitialized": true, + "language": "python", + "customOutput": null, + "executionStartTime": 1726000131275, + "executionStopTime": 1726000131459, + "serverExecutionDuration": 2.2683702409267, + "requestMsgId": "27d22c43-9299-46ec-94f2-28a880546fe3" + }, + "source": [ + "# Install stable versions for best reliability\n", + "\n", + "!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U\n", + "!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121\n", + "!pip3 install torchmetrics==1.0.3\n", + "!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121" + ], + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-4DFtQNDYao1", + "originalKey": "07e2a5ae-9ca2-45d7-af10-84d8e09ce91e", + "outputsInitialized": false, + "language": "markdown", + "showInput": false + }, + "source": [ + "# Intro to TorchRec\n", + "\n", + "### Embeddings\n", + "When building recommendation systems, categorical features typically have massive cardinalities (posts, users, ads, etc.).\n", + "\n", + "In order to represent these entities and model these relationships, **embeddings** are used. In machine learning, **embeddings are vectors of real numbers in a high-dimensional space used to represent meaning in complex data like words, images, or users**.\n", + "\n", + "\n", + "### Embeddings in RecSys\n", + "\n", + "Now you might wonder, how are these embeddings generated in the first place? Well, embeddings are represented as individual rows in an **Embedding Table**, also referred to as embedding weights. The reason for this is that embeddings/embedding table weights are trained just like all of the other weights of the model via gradient descent!\n", + "\n", + "Embedding tables are simply a large matrix for storing embeddings, with two dimensions (B, N), where\n", + "* B is the number of embeddings stored by the table\n", + "* N is the number of dimensions per embedding (N-dimensional embedding).\n", + "\n", + "\n", + "The inputs to embedding tables represent embedding lookups to retrieve the embedding for a specific index/row. In recommendation systems, such as those used in Meta, unique IDs are not only used for specific users, but also across entities like posts and ads to serve as lookup indices to respective embedding tables!\n", + "\n", + "Embeddings are trained in RecSys through the following process:\n", + "1. **Input/lookup indices are fed into the model, as unique IDs**. IDs are hashed to the total size of the embedding table to prevent issues when the ID > # of rows\n", + "2. Embeddings are then retrieved and **pooled, such as taking the sum or mean of the embeddings**. This is required as there can be a variable # of embeddings per example while the model expects consistent shapes.\n", + "3. The **embeddings are used in conjunction with the rest of the model to produce a prediction**, such as [Click-Through Rate (CTR)](https://support.google.com/google-ads/answer/2615875?hl=en) for an Ad.\n", + "4. 
The loss is calculated with the prediction and the label for an example, and **all weights of the model are updated through gradient descent and backpropagation, including the embedding weights** that were associated with the example.\n", + "\n", + "These embeddings are crucial for representing categorical features, such as users, posts, and ads, in order to capture relationships and make good recommendations. Meta AI's [Deep learning recommendation model](https://arxiv.org/abs/1906.00091) (DLRM) paper talks more about the technical details of using embedding tables in RecSys.\n", + "\n", + "This tutorial will introduce the concept of embeddings, showcase TorchRec specific modules/datatypes, and depict how distributed training works with TorchRec." + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "48b50971-aeab-4754-8cff-986496689f43", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000131464, + "executionStopTime": 1726000133971, + "serverExecutionDuration": 2349.9959111214, + "requestMsgId": "48b50971-aeab-4754-8cff-986496689f43", + "customOutput": null, + "outputsInitialized": true, + "output": { + "id": "1534047040582458" + }, + "id": "AbeT4W9xcso9" + }, + "source": [ + "import torch" + ], + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "4b510f99-840d-4986-b635-33c21af48cf4", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "bjuDdEqocso-" + }, + "source": [ + "## Embeddings in PyTorch\n", + "[`torch.nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html): Embedding table where forward pass returns the embeddings themselves as is.\n", + "\n", + "[`torch.nn.EmbeddingBag`](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html): Embedding table where forward pass returns embeddings that are then pooled, i.e. sum or mean. Otherwise known as **Pooled Embeddings**.\n", + "\n", + "In this section, we will go over a very brief introduction to doing embedding lookups by passing indices into the table. Check out the links for each for more sophisticated use cases and experiments!" 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "06ebfce4-bc22-4f5a-97d7-7a8f5d8ac375", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000133982, + "executionStopTime": 1726000134201, + "serverExecutionDuration": 31.60185739398, + "requestMsgId": "06ebfce4-bc22-4f5a-97d7-7a8f5d8ac375", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "933119035309629" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1X5C_Dnccso-", + "outputId": "616cc153-67ee-4dd6-b1ab-ee6ff6f44709" + }, + "source": [ + "num_embeddings, embedding_dim = 10, 4\n", + "\n", + "# Initialize our embedding table\n", + "weights = torch.rand(num_embeddings, embedding_dim)\n", + "print(\"Weights:\", weights)" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Weights: tensor([[0.3446, 0.3614, 0.8938, 0.8157],\n", + " [0.1830, 0.0326, 0.8241, 0.2995],\n", + " [0.7328, 0.0531, 0.9528, 0.0592],\n", + " [0.7800, 0.1797, 0.0167, 0.7401],\n", + " [0.4837, 0.2052, 0.3360, 0.9656],\n", + " [0.7887, 0.3066, 0.0956, 0.3344],\n", + " [0.5904, 0.8541, 0.5963, 0.2800],\n", + " [0.5751, 0.4341, 0.6218, 0.4101],\n", + " [0.6881, 0.5363, 0.4747, 0.2301],\n", + " [0.6088, 0.1060, 0.1100, 0.7290]])\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "b2f21375-8d36-487f-b0c3-ff8a5df950a4", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000134203, + "executionStopTime": 1726000134366, + "serverExecutionDuration": 8.956927806139, + "requestMsgId": "b2f21375-8d36-487f-b0c3-ff8a5df950a4", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "831419729143778" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bxszzeGdcso-", + "outputId": "88f21deb-c2f1-4894-975e-fdac41436b36" + }, + "source": [ + "# Pass in pre generated weights just for example, typically weights are randomly initialized\n", + "embedding_collection = torch.nn.Embedding(\n", + " num_embeddings, embedding_dim, _weight=weights\n", + ")\n", + "embedding_bag_collection = torch.nn.EmbeddingBag(\n", + " num_embeddings, embedding_dim, _weight=weights\n", + ")\n", + "\n", + "# Print out the tables, we should see the same weights as above\n", + "print(\"Embedding Collection Table: \", embedding_collection.weight)\n", + "print(\"Embedding Bag Collection Table: \", embedding_bag_collection.weight)\n", + "\n", + "# Lookup rows (ids for embedding ids) from the embedding tables\n", + "# 2D tensor with shape (batch_size, ids for each batch)\n", + "ids = torch.tensor([[1, 3]])\n", + "print(\"Input row IDS: \", ids)" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Embedding Collection Table: Parameter containing:\n", + "tensor([[0.3446, 0.3614, 0.8938, 0.8157],\n", + " [0.1830, 0.0326, 0.8241, 0.2995],\n", + " [0.7328, 0.0531, 0.9528, 0.0592],\n", + " [0.7800, 0.1797, 0.0167, 0.7401],\n", + " [0.4837, 0.2052, 0.3360, 0.9656],\n", + " [0.7887, 0.3066, 0.0956, 0.3344],\n", + " [0.5904, 0.8541, 0.5963, 0.2800],\n", + " [0.5751, 0.4341, 0.6218, 0.4101],\n", + " [0.6881, 0.5363, 0.4747, 0.2301],\n", + " [0.6088, 0.1060, 0.1100, 0.7290]], requires_grad=True)\n", + "Embedding Bag Collection Table: Parameter containing:\n", + "tensor([[0.3446, 0.3614, 0.8938, 0.8157],\n", + " [0.1830, 0.0326, 0.8241, 0.2995],\n", + " [0.7328, 0.0531, 0.9528, 
0.0592],\n", + " [0.7800, 0.1797, 0.0167, 0.7401],\n", + " [0.4837, 0.2052, 0.3360, 0.9656],\n", + " [0.7887, 0.3066, 0.0956, 0.3344],\n", + " [0.5904, 0.8541, 0.5963, 0.2800],\n", + " [0.5751, 0.4341, 0.6218, 0.4101],\n", + " [0.6881, 0.5363, 0.4747, 0.2301],\n", + " [0.6088, 0.1060, 0.1100, 0.7290]], requires_grad=True)\n", + "Input row IDS: tensor([[1, 3]])\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "cb5c5906-e9a6-4315-b860-b263e08989be", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000134369, + "executionStopTime": 1726000134545, + "serverExecutionDuration": 5.9817284345627, + "requestMsgId": "cb5c5906-e9a6-4315-b860-b263e08989be", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "2201664893536578" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xkedJeTOcso_", + "outputId": "46215f2b-03ad-421b-f873-78b2be0df4d4" + }, + "source": [ + "embeddings = embedding_collection(ids)\n", + "\n", + "# Print out the embedding lookups\n", + "# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above\n", + "print(\"Embedding Collection Results: \")\n", + "print(embeddings)\n", + "print(\"Shape: \", embeddings.shape)" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Embedding Collection Results: \n", + "tensor([[[0.1830, 0.0326, 0.8241, 0.2995],\n", + " [0.7800, 0.1797, 0.0167, 0.7401]]], grad_fn=)\n", + "Shape: torch.Size([1, 2, 4])\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "a8e90b32-7c30-41f2-a5b9-bedf2b196e7f", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000134547, + "executionStopTime": 1726000134718, + "serverExecutionDuration": 7.8675262629986, + "requestMsgId": "a8e90b32-7c30-41f2-a5b9-bedf2b196e7f", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "6449977515126116" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "PmtJkxLccso_", + "outputId": "c33109b4-205e-492d-e0ca-94d2887ec6e7" + }, + "source": [ + "# nn.EmbeddingBag default pooling is mean, so should be mean of batch dimension of values above\n", + "pooled_embeddings = embedding_bag_collection(ids)\n", + "\n", + "print(\"Embedding Bag Collection Results: \")\n", + "print(pooled_embeddings)\n", + "print(\"Shape: \", pooled_embeddings.shape)\n", + "\n", + "# nn.EmbeddingBag is the same as nn.Embedding but just with pooling (mean, sum, etc.)\n", + "# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection\n", + "print(\"Mean: \", torch.mean(embedding_collection(ids), dim=1))" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Embedding Bag Collection Results: \n", + "tensor([[0.4815, 0.1062, 0.4204, 0.5198]], grad_fn=)\n", + "Shape: torch.Size([1, 4])\n", + "Mean: tensor([[0.4815, 0.1062, 0.4204, 0.5198]], grad_fn=)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "4643305e-2770-40cf-afc6-e64cd3f51063", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "SuCV1cJ8cso_" + }, + "source": [ + "Congratulations! Now you have a basic understanding on how to use embedding tables --- one of the foundations of modern recommendation systems! 
These tables represent entities and their relationships. For example, the relationship between a given user and the pages & posts they have liked." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "7dfcffeb-c7c0-4d74-9dba-569c1d882898", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "QIuAYSZ5cso_" + }, + "source": [ + "# TorchRec\n", + "\n", + "Now you know how to use embedding tables, one of the foundations of modern recommendation systems! These tables represent entities and relationships, such as users, pages, posts, etc. Given that these entities are always increasing, a **hash** function is typically applied to make sure the ids are within the bounds of a certain embedding table. However, in order to represent a vast amount of entities and reduce hash collisions, these tables can become quite massive (think about # of ads for example). In fact, these tables can become so massive that they won't be able to fit on 1 GPU, even with 80G of memory!\n", + "\n", + "In order to train models with massive embedding tables, sharding these tables across GPUs is required, which then introduces a whole new set of problems/opportunities in parallelism and optimization. Luckily, we have the TorchRec library that has encountered, consolidated, and addressed many of these concerns. TorchRec serves as a **library that provides primitives for large scale distributed embeddings**.\n", + "\n", + "From here on out, we will explore the major features of the TorchRec library. We will start with torch.nn.Embedding and will extend that to custom TorchRec modules, explore distributed training environment with generating a sharding plan for embeddings, look at inherent TorchRec optimizations, and extend the model to be ready for inference in C++. Below is a quick outline of what the journey will consist of - buckle in!\n", + "\n", + "1. TorchRec Modules and DataTypes\n", + "2. Distributed Training, Sharding, and Optimizations\n", + "3. 
Inference\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "8395ed9c-8336-4686-8e73-cb815b808f2a", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000134724, + "executionStopTime": 1726000139238, + "serverExecutionDuration": 4317.9145939648, + "requestMsgId": "8395ed9c-8336-4686-8e73-cb815b808f2a", + "outputsInitialized": true, + "customOutput": null, + "id": "5vzmNV0IcspA" + }, + "source": [ + "import torchrec" + ], + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "0c95b385-e07a-43e1-aaeb-31f66deb5b35", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "42PwMZnNcspA" + }, + "source": [ + "## TorchRec Modules and Datatypes" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZdSUWBRxoP8R", + "originalKey": "309c4d38-8f19-46d9-a8bb-7d3d1c166e84", + "outputsInitialized": false, + "language": "markdown", + "showInput": false + }, + "source": [ + "### From EmbeddingBag to EmbeddingBagCollection\n", + "We have already explored [`torch.nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) and [`torch.nn.EmbeddingBag`](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html).\n", + "\n", + "TorchRec extends these modules by creating collections of embeddings, in other words modules that can have multiple embedding tables, with [`EmbeddingCollection`](https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingCollection) and [`EmbeddingBagCollection`](https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection). We will use `EmbeddingBagCollection` to represent a group of EmbeddingBags.\n", + "\n", + "Here, we create an EmbeddingBagCollection (EBC) with two embedding bags, 1 representing **products** and 1 representing **users**. Each table, `product_table` and `user_table`, is represented by 64 dimension embedding of size 4096." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Iz_GZDp_oQ19", + "originalKey": "219c4ee9-c4f1-43ff-9d1c-b15b16a1dc8e", + "outputsInitialized": true, + "language": "python", + "customOutput": null, + "executionStartTime": 1726000139247, + "executionStopTime": 1726000139433, + "serverExecutionDuration": 13.643965125084, + "requestMsgId": "219c4ee9-c4f1-43ff-9d1c-b15b16a1dc8e", + "output": { + "id": "1615870128957785" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "cec92f4f-d9eb-464e-8e22-26fd910bf8c1" + }, + "source": [ + "ebc = torchrec.EmbeddingBagCollection(\n", + " device=\"cpu\",\n", + " tables=[\n", + " torchrec.EmbeddingBagConfig(\n", + " name=\"product_table\",\n", + " embedding_dim=64,\n", + " num_embeddings=4096,\n", + " feature_names=[\"product\"],\n", + " pooling=torchrec.PoolingType.SUM,\n", + " ),\n", + " torchrec.EmbeddingBagConfig(\n", + " name=\"user_table\",\n", + " embedding_dim=64,\n", + " num_embeddings=4096,\n", + " feature_names=[\"user\"],\n", + " pooling=torchrec.PoolingType.SUM,\n", + " )\n", + " ]\n", + ")\n", + "print(ebc.embedding_bags)" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "ModuleDict(\n", + " (product_table): EmbeddingBag(4096, 64, mode='sum')\n", + " (user_table): EmbeddingBag(4096, 64, mode='sum')\n", + ")\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "c587a298-4d38-4a69-89a2-5d5c4a26cc2c", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "xjcA0Di1cspA" + }, + "source": [ + "Let’s inspect the forward method for EmbeddingBagcollection and the module’s inputs and outputs." + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "c9d2717b-b753-4e0b-97bd-1596123d081d", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000139437, + "executionStopTime": 1726000139616, + "serverExecutionDuration": 6.011176854372, + "requestMsgId": "c9d2717b-b753-4e0b-97bd-1596123d081d", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "398959426640405" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "UuIrEWupcspA", + "outputId": "300c86c2-82c1-4657-fa1a-6a319eb40177" + }, + "source": [ + "import inspect\n", + "\n", + "# Let's look at the EmbeddingBagCollection forward method\n", + "# What is a KeyedJaggedTensor and KeyedTensor?\n", + "print(inspect.getsource(ebc.forward))" + ], + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:\n", + " \"\"\"\n", + " Args:\n", + " features (KeyedJaggedTensor): KJT of form [F X B X L].\n", + "\n", + " Returns:\n", + " KeyedTensor\n", + " \"\"\"\n", + " if is_non_strict_exporting() and not torch.jit.is_scripting():\n", + " return self._non_strict_exporting_forward(features)\n", + " flat_feature_names: List[str] = []\n", + " for names in self._feature_names:\n", + " flat_feature_names.extend(names)\n", + " inverse_indices = reorder_inverse_indices(\n", + " inverse_indices=features.inverse_indices_or_none(),\n", + " feature_names=flat_feature_names,\n", + " )\n", + " pooled_embeddings: List[torch.Tensor] = []\n", + " feature_dict = features.to_dict()\n", + " for i, embedding_bag in enumerate(self.embedding_bags.values()):\n", + " for feature_name in self._feature_names[i]:\n", + " f = 
feature_dict[feature_name]\n", + " res = embedding_bag(\n", + " input=f.values(),\n", + " offsets=f.offsets(),\n", + " per_sample_weights=f.weights() if self._is_weighted else None,\n", + " ).float()\n", + " pooled_embeddings.append(res)\n", + " return KeyedTensor(\n", + " keys=self._embedding_names,\n", + " values=process_pooled_embeddings(\n", + " pooled_embeddings=pooled_embeddings,\n", + " inverse_indices=inverse_indices,\n", + " ),\n", + " length_per_key=self._lengths_per_embedding,\n", + " )\n", + "\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "d6b9bfc2-544d-499f-ad61-d7471b819f8a", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "C_UAtHsMcspA" + }, + "source": [ + "### TorchRec Input/Output Data Types\n", + "TorchRec has distinct data types for input and output of its modules: `JaggedTensor`, `KeyedJaggedTensor`, and `KeyedTensor`. Now you might ask, why create new datatypes to represent sparse features? To answer that question, we must understand how sparse features are represented in code.\n", + "\n", + "Sparse features are otherwise known as `id_list_feature` and `id_score_list_feature`, and are the **IDs** that will be used as indices to an embedding table to retrieve the embedding for that ID. To give a very simple example, imagine a single sparse feature being Ads that a user interacted with. The input itself would be a set of Ad IDs that a user interacted with, and the embeddings retrieved would be a semantic representation of those Ads. The tricky part of representing these features in code is that in each input example, **the number of IDs is variable**. 1 day a user might have interacted with only 1 ad while the next day they interact with 3.\n", + "\n", + "A simple representation is shown below, where we have a `lengths` tensor denoting how many indices are in an example for a batch and a `values` tensor containing the indices themselves.\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "13225ead-a798-4db2-8de6-1c13a758d676", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000139620, + "executionStopTime": 1726000139790, + "serverExecutionDuration": 3.692839294672, + "requestMsgId": "13225ead-a798-4db2-8de6-1c13a758d676", + "outputsInitialized": true, + "customOutput": null, + "id": "RB77aL08cspA" + }, + "source": [ + "# Batch Size 2\n", + "# 1 ID in example 1, 2 IDs in example 2\n", + "id_list_feature_lengths = torch.tensor([1, 2])\n", + "\n", + "# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2\n", + "id_list_feature_values = torch.tensor([5, 7, 1])" + ], + "execution_count": 10, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "65d31fca-7b7f-4c0f-9ca2-56e07243a5c0", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "aKmgGqdNcspA" + }, + "source": [ + "Let’s look at the offsets as well as what is contained in each Batch" + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "9510cebd-1875-461e-9243-53928632abfa", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000139794, + "executionStopTime": 1726000139966, + "serverExecutionDuration": 6.6289491951466, + "requestMsgId": "9510cebd-1875-461e-9243-53928632abfa", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "869913611744322" + }, + "colab": 
{ + "base_uri": "https://localhost:8080/" + }, + "id": "t5T5S8_mcspB", + "outputId": "87e78b11-7497-4387-c0bc-b4d277ba8ab3" + }, + "source": [ + "# Lengths can be converted to offsets for easy indexing of values\n", + "id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)\n", + "\n", + "print(\"Offsets: \", id_list_feature_offsets)\n", + "print(\"First Batch: \", id_list_feature_values[: id_list_feature_offsets[0]])\n", + "print(\n", + " \"Second Batch: \",\n", + " id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],\n", + ")" + ], + "execution_count": 11, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Offsets: tensor([1, 3])\n", + "First Batch: tensor([5])\n", + "Second Batch: tensor([7, 1])\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "4bc3fac5-16b9-4f63-b841-9b26ee0ccfc0", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000139968, + "executionStopTime": 1726000140161, + "serverExecutionDuration": 7.3191449046135, + "requestMsgId": "4bc3fac5-16b9-4f63-b841-9b26ee0ccfc0", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1254783359215069" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2OOK2BBecspB", + "outputId": "b27c1547-fbb7-47c4-efb6-48aff3300d1a" + }, + "source": [ + "from torchrec import JaggedTensor\n", + "\n", + "# JaggedTensor is just a wrapper around lengths/offsets and values tensors!\n", + "jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)\n", + "\n", + "# Automatically compute offsets from lengths\n", + "print(\"Offsets: \", jt.offsets())\n", + "\n", + "# Convert to list of values\n", + "print(\"List of Values: \", jt.to_dense())\n", + "\n", + "# __str__ representation\n", + "print(jt)" + ], + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Offsets: tensor([0, 1, 3])\n", + "List of Values: [tensor([5]), tensor([7, 1])]\n", + "JaggedTensor({\n", + " [[5], [7, 1]]\n", + "})\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "ad069058-2329-4ab9-bee8-60775ead4c33", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000140165, + "executionStopTime": 1726000140355, + "serverExecutionDuration": 10.361641645432, + "requestMsgId": "ad069058-2329-4ab9-bee8-60775ead4c33", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "530006499497328" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fs10Fxu2cspB", + "outputId": "3bc754c2-ac30-4c01-b0c8-7f98eedb9c52" + }, + "source": [ + "from torchrec import KeyedJaggedTensor\n", + "\n", + "# JaggedTensor represents IDs for 1 feature, but we have multiple features in an EmbeddingBagCollection\n", + "# That's where KeyedJaggedTensor comes in! KeyedJaggedTensor is just multiple JaggedTensors for multiple id_list_feature_offsets\n", + "# From before, we have our two features \"product\" and \"user\". 
Let's create JaggedTensors for both!\n", + "\n", + "product_jt = JaggedTensor(\n", + " values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])\n", + ")\n", + "user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))\n", + "\n", + "# Q1: How many batches are there, and which values are in the first batch for product_jt and user_jt?\n", + "kjt = KeyedJaggedTensor.from_jt_dict({\"product\": product_jt, \"user\": user_jt})\n", + "\n", + "# Look at our feature keys for the KeyedJaggedTensor\n", + "print(\"Keys: \", kjt.keys())\n", + "\n", + "# Look at the overall lengths for the KeyedJaggedTensor\n", + "print(\"Lengths: \", kjt.lengths())\n", + "\n", + "# Look at all values for KeyedJaggedTensor\n", + "print(\"Values: \", kjt.values())\n", + "\n", + "# Can convert KJT to dictionary representation\n", + "print(\"to_dict: \", kjt.to_dict())\n", + "\n", + "# KeyedJaggedTensor(KJT) string representation\n", + "print(kjt)\n", + "\n", + "# Q2: What are the offsets for the KeyedJaggedTensor?" + ], + "execution_count": 13, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Keys: ['product', 'user']\n", + "Lengths: tensor([3, 1, 2, 2])\n", + "Values: tensor([1, 2, 1, 5, 2, 3, 4, 1])\n", + "to_dict: {'product': , 'user': }\n", + "KeyedJaggedTensor({\n", + " \"product\": [[1, 2, 1], [5]],\n", + " \"user\": [[2, 3], [4, 1]]\n", + "})\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "b13fdf10-45a7-4e57-b50e-cc18547a715b", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000140357, + "executionStopTime": 1726000140549, + "serverExecutionDuration": 17.695877701044, + "requestMsgId": "b13fdf10-45a7-4e57-b50e-cc18547a715b", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "496557126663787" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JeLwyCNRcspB", + "outputId": "0723906d-0aba-4d48-e9a5-7ac618d711c5" + }, + "source": [ + "# Now we can run a forward pass on our ebc from before\n", + "result = ebc(kjt)\n", + "result" + ], + "execution_count": 14, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 14 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "57a01464-de39-4bfb-8355-83cd97e519c0", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000140552, + "executionStopTime": 1726000140732, + "serverExecutionDuration": 6.0368701815605, + "requestMsgId": "57a01464-de39-4bfb-8355-83cd97e519c0", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1457290878317732" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "R2K4v2vqcspB", + "outputId": "c507e47e-32bc-4440-8c8a-2b3ea334467c" + }, + "source": [ + "# Result is a KeyedTensor, which contains a list of the feature names and the embedding results\n", + "print(result.keys())\n", + "\n", + "# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined\n", + "# 128 for dimension of embedding. If you look at where we initialized the EmbeddingBagCollection, we have two tables \"product\" and \"user\" of dimension 64 each\n", + "# meaning emebddings for both features are of size 64. 
64 + 64 = 128\n", + "print(result.values().shape)\n", + "\n", + "# Nice to_dict method to determine the embeddings that belong to each feature\n", + "result_dict = result.to_dict()\n", + "for key, embedding in result_dict.items():\n", + " print(key, embedding.shape)" + ], + "execution_count": 15, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "['product', 'user']\n", + "torch.Size([2, 128])\n", + "product torch.Size([2, 64])\n", + "user torch.Size([2, 64])\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "d0fc8635-dac3-444b-978b-421b5d77b70c", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "EE-YYDv7cspB" + }, + "source": [ + "Congrats! Give yourself a pat on the back for making it this far." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "70816a78-7671-411c-814f-d2c98c3a912c", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "djLHn0CIcspB" + }, + "source": [ + "## Distributed Training and Sharding\n", + "Now that we have a grasp on TorchRec modules and data types, it's time to take it to the next level.\n", + "\n", + "Remember, TorchRec's main purpose is to provide primitives for distributed embeddings. So far, we've only worked with embedding tables on 1 device. This has been possible given how small the embedding tables have been, but in a production setting this isn't generally the case. Embedding tables often get massive, where 1 table can't fit on a single GPU, creating the requirement for multiple devices and a distributed environment\n", + "\n", + "In this section, we will explore setting up a distributed environment, exactly how actual production training is done, and explore sharding embedding tables, all with Torchrec.\n", + "\n", + "**This section will also only use 1 gpu, though it will be treated in a distributed fashion. This is only a limitation for training, as training has a process per gpu. 
Inference does not run into this requirement**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4-v17rxkopQw", + "originalKey": "df0d09f0-5e8e-46bf-a086-dd991c8be0b4", + "outputsInitialized": true, + "language": "python", + "customOutput": null, + "executionStartTime": 1726000140740, + "executionStopTime": 1726000142256, + "serverExecutionDuration": 1350.0418178737, + "requestMsgId": "df0d09f0-5e8e-46bf-a086-dd991c8be0b4", + "output": { + "id": "1195358511578142" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "efa5689f-b794-41a6-8c06-86856e1e698a" + }, + "source": [ + "# Here we set up our torch distributed environment\n", + "# WARNING: You can only call this cell once, calling it again will cause an error\n", + "# as you can only initialize the process group once\n", + "\n", + "import os\n", + "\n", + "import torch.distributed as dist\n", + "\n", + "# Set up environment variables for distributed training\n", + "# RANK is which GPU we are on, default 0\n", + "os.environ[\"RANK\"] = \"0\"\n", + "# How many devices in our \"world\", since Bento can only handle 1 process, 1 GPU\n", + "os.environ[\"WORLD_SIZE\"] = \"1\"\n", + "# Localhost as we are training locally\n", + "os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", + "# Port for distributed training\n", + "os.environ[\"MASTER_PORT\"] = \"29500\"\n", + "\n", + "# Note - you will need a V100 or A100 to run tutorial as!\n", + "# nccl backend is for GPUs, gloo is for CPUs\n", + "dist.init_process_group(backend=\"gloo\")\n", + "\n", + "print(f\"Distributed environment initialized: {dist}\")" + ], + "execution_count": 16, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Distributed environment initialized: \n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "480e69dc-3e9d-4e86-b73c-950e18afb0f5", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "hQOjNci3cspB" + }, + "source": [ + "### Distributed Embeddings\n", + "\n", + "We have already worked with the main TorchRec module: `EmbeddingBagCollection`. We have examined how it works along with how data is represented in TorchRec. However, we have not yet explored one of the main parts of TorchRec, which is **distributed embeddings**.\n", + "\n", + "GPUs are the most popular choice for ML workloads by far today, as they are able to do magnitudes more floating point operations/s ([FLOPs](https://en.wikipedia.org/wiki/FLOPS)) than CPU. However, GPUs come with the limitation of scarce fast memory (HBM which is analgous to RAM for CPU), typically ~10s of GBs.\n", + "\n", + "A RecSys model can contain embedding tables that far exceed the memory limit for 1 GPU, hence the need for distribution of the embedding tables across multiple GPUs, otherwise known as **model parallel**. On the other hand, **data parallel** is where the entire model is replicated on each GPU, which each GPU taking in a distinct batch of data for training, syncing gradients on the backwards pass.\n", + "\n", + "Parts of the model that **require less compute but more memory (embeddings) are distributed with model parallel** while parts that **require more compute and less memory (dense layers, MLP, etc.) 
are distributed with data parallel**.\n", + "\n", + "\n", + "### Sharding\n", + "In order to distribute an embedding table, we split up the embedding table into parts and place those parts onto different devices, also known as “sharding”.\n", + "\n", + "There are many ways to shard embedding tables. The most common ways are:\n", + "* Table-Wise: the table is placed entirely onto one device\n", + "* Column-Wise: columns of embedding tables are sharded\n", + "* Row-Wise: rows of embedding tables are sharded\n", + "\n", + "\n", + "### Sharded Modules\n", + "While all of this seems like a lot to deal with and implement, you're in luck. **TorchRec provides all the primitives for easy distributed training/inference**! In fact, TorchRec modules have two corresponding classes for working with any TorchRec module in a distributed environment:\n", + "1. The module sharder: This class exposes a `shard` API that handles sharding a TorchRec Module, producing a sharded module.\n", + " * For `EmbeddingBagCollection`, the sharder is [`EmbeddingBagCollectionSharder`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder)\n", + "2. Sharded module: This class is a sharded variant of a TorchRec module. It has the same input/output as a the regular TorchRec module, but much more optimized and works in a distributed environment.\n", + " * For `EmbeddingBagCollection`, the sharded variant is [`ShardedEmbeddingBagCollection`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection)\n", + "\n", + "Every TorchRec module has an unsharded and sharded variant.\n", + "* The unsharded version is meant to be prototyped and experimented with\n", + "* The sharded version is meant to be used in a distributed environment for distributed training/inference.\n", + "\n", + "The sharded versions of TorchRec modules, for example EmbeddingBagCollection, will handle everything that is needed for Model Parallelism, such as communication between GPUs for distributing embeddings to the correct GPUs.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "eb2a064d-0b67-4cba-a199-c99573c7e6cd", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000142261, + "executionStopTime": 1726000142430, + "serverExecutionDuration": 8.3460621535778, + "requestMsgId": "eb2a064d-0b67-4cba-a199-c99573c7e6cd", + "customOutput": null, + "outputsInitialized": true, + "output": { + "id": "791089056311464" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FX65VcQ6cspB", + "outputId": "1aa3bc52-569b-46fb-8a94-cd1873e987ca" + }, + "source": [ + "# Refresher of our EmbeddingBagCollection module\n", + "ebc" + ], + "execution_count": 17, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "EmbeddingBagCollection(\n", + " (embedding_bags): ModuleDict(\n", + " (product_table): EmbeddingBag(4096, 64, mode='sum')\n", + " (user_table): EmbeddingBag(4096, 64, mode='sum')\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 17 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "1442636d-8617-4785-b40c-8544374253b6", + "showInput": true, + "customInput": null, + "language": "python", + "outputsInitialized": true, + "executionStartTime": 1726000142433, + "executionStopTime": 1726000142681, + "serverExecutionDuration": 4.4135116040707, + "requestMsgId": 
"1442636d-8617-4785-b40c-8544374253b6", + "customOutput": null, + "output": { + "id": "502189589096046" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1hSzTg4pcspC", + "outputId": "d7d86592-4fdc-4d0b-f2ba-a40a80af1fcf" + }, + "source": [ + "from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder\n", + "from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology\n", + "from torchrec.distributed.types import ShardingEnv\n", + "\n", + "# Corresponding sharder for EmbeddingBagCollection module\n", + "sharder = EmbeddingBagCollectionSharder()\n", + "\n", + "# ProcessGroup from torch.distributed initialized 2 cells above\n", + "pg = dist.GroupMember.WORLD\n", + "assert pg is not None, \"Process group is not initialized\"\n", + "\n", + "print(f\"Process Group: {pg}\")" + ], + "execution_count": 18, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Process Group: \n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "29cc17eb-9e2f-480b-aed2-60b15024fbf7", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "qU7A980qcspC" + }, + "source": [ + "### Planner\n", + "\n", + "Before we can show how sharding works, we must know about the **planner**, which helps us determine the best sharding configuration.\n", + "\n", + "Given a number of embedding tables and a number of ranks, there are many different sharding configurations that are possible. For example, given 2 embedding tables and 2 GPUs, you can:\n", + "* Place 1 table on each GPU\n", + "* Place both tables on a single GPU and no tables on the other\n", + "* Place certain rows/columns on each GPU\n", + "\n", + "Given all of these possibilities, we typically want a sharding configuration that is optimal for performance.\n", + "\n", + "That is where the planner comes in. The planner is able to determine given the # of embedding tables and the # of GPUs, what is the optimal configuration. Turns out, this is incredibly difficult to do manually, with tons of factors that engineers have to consider to ensure an optimal sharding plan. Luckily, TorchRec provides an auto planner when the planner is used. The TorchRec planner:\n", + "* assesses memory constraints of hardware,\n", + "* estimates compute based on memory fetches as embedding lookups,\n", + "* addresses data specific factors,\n", + "* considers other hardware specifics like bandwidth to generate an optimal sharding plan.\n", + "\n", + "In order to take into consideration all these variables, The TorchRec planner can take in [various amounts of data for embedding tables, constraints, hardware information, and topology](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/planner/planners.py#L147-L155) to aid in generating the optimal sharding plan for a model, which is routinely provided across stacks\n", + "\n", + "\n", + "To learn more about sharding, see our [sharding tutorial](https://pytorch.org/tutorials/advanced/sharding.html)." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "64936203-2e59-4bc3-8d76-1b652b7891c2", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000142687, + "executionStopTime": 1726000143033, + "serverExecutionDuration": 145.92137932777, + "requestMsgId": "64936203-2e59-4bc3-8d76-1b652b7891c2", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1247084956198777" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "PQeXnuAGcspC", + "outputId": "e6d366b7-f064-4e38-a477-fc2e0d4b9736" + }, + "source": [ + "# In our case, 1 GPU and compute on CUDA device\n", + "planner = EmbeddingShardingPlanner(\n", + " topology=Topology(\n", + " world_size=1,\n", + " compute_device=\"cuda\",\n", + " )\n", + ")\n", + "\n", + "# Run planner to get plan for sharding\n", + "plan = planner.collective_plan(ebc, [sharder], pg)\n", + "\n", + "print(f\"Sharding Plan generated: {plan}\")" + ], + "execution_count": 19, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Sharding Plan generated: module: \n", + "\n", + " param | sharding type | compute kernel | ranks\n", + "------------- | ------------- | -------------- | -----\n", + "product_table | table_wise | fused | [0] \n", + "user_table | table_wise | fused | [0] \n", + "\n", + " param | shard offsets | shard sizes | placement \n", + "------------- | ------------- | ----------- | -------------\n", + "product_table | [0, 0] | [4096, 64] | rank:0/cuda:0\n", + "user_table | [0, 0] | [4096, 64] | rank:0/cuda:0\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "bbbbbf60-5691-4357-9943-4d7f8b2b1d5c", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "2TTLj_0PcspC" + }, + "source": [ + "### Planner Result\n", + "As you can see, when running the planner there is quite a bit of output above. We can see a ton of stats being calculated along with where our tables end up being placed.\n", + "\n", + "The result of running the planner is a static plan, which can be reused for sharding! This allows sharding to be static for production models instead of determining a new sharding plan everytime. 
Below, we use the sharding plan to finally generate our `ShardedEmbeddingBagCollection.`" + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "533be12d-a3c5-4c9e-9351-7770251c8fa5", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000143037, + "executionStopTime": 1726000143259, + "serverExecutionDuration": 5.2368640899658, + "requestMsgId": "533be12d-a3c5-4c9e-9351-7770251c8fa5", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "901470115170971" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JIci5Gz6cspC", + "outputId": "32d3cb73-80b6-4646-9c1d-3cff5a498f86" + }, + "source": [ + "# The static plan that was generated\n", + "plan" + ], + "execution_count": 20, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "ShardingPlan(plan={'': {'product_table': ParameterSharding(sharding_type='table_wise', compute_kernel='fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)]), cache_params=None, enforce_hbm=None, stochastic_rounding=None, bounds_check_mode=None, output_dtype=None), 'user_table': ParameterSharding(sharding_type='table_wise', compute_kernel='fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)]), cache_params=None, enforce_hbm=None, stochastic_rounding=None, bounds_check_mode=None, output_dtype=None)}})" + ] + }, + "metadata": {}, + "execution_count": 20 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "5dcbfda0-0abb-4a51-ba8f-a6a4023f0e2f", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000143262, + "executionStopTime": 1726000147680, + "serverExecutionDuration": 4229.5375689864, + "requestMsgId": "5dcbfda0-0abb-4a51-ba8f-a6a4023f0e2f", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1231077634880712" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2__Do2tqcspC", + "outputId": "7c19830a-9a11-4fbe-ddf3-8c3cb8a6c3b3" + }, + "source": [ + "env = ShardingEnv.from_process_group(pg)\n", + "\n", + "# Shard the EmbeddingBagCollection module using the EmbeddingBagCollectionSharder\n", + "sharded_ebc = sharder.shard(ebc, plan.plan[\"\"], env, torch.device(\"cuda\"))\n", + "\n", + "print(f\"Sharded EBC Module: {sharded_ebc}\")" + ], + "execution_count": 21, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Sharded EBC Module: ShardedEmbeddingBagCollection(\n", + " (lookups): \n", + " GroupedPooledEmbeddingsLookup(\n", + " (_emb_modules): ModuleList(\n", + " (0): BatchedFusedEmbeddingBag(\n", + " (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n", + " )\n", + " )\n", + " )\n", + " (_output_dists): \n", + " TwPooledEmbeddingDist()\n", + " (embedding_bags): ModuleDict(\n", + " (product_table): Module()\n", + " (user_table): Module()\n", + " )\n", + ")\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [], + "metadata": { + "id": "ErXXbYzJmVzI" + } + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "3ba44a6d-a6f7-4da2-83a6-e8ac974c64ac", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "QBLpkKYIcspC" + }, + "source": [ + "#### Awaitable\n", + "Remember that TorchRec is a highly optimized library 
for distributed embeddings. A concept that TorchRec introduces to enable higher performance for training on GPU is a [`LazyAwaitable`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.types.LazyAwaitable). You will see `LazyAwaitable` types as outputs of various sharded TorchRec modules. All a `LazyAwaitable` does is delay calculating some result as long as possible, and it does it by acting like an async type." + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "e450dc00-bd30-4bc2-8c71-4c01979b0948", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000147687, + "executionStopTime": 1726000147874, + "serverExecutionDuration": 9.098757058382, + "requestMsgId": "e450dc00-bd30-4bc2-8c71-4c01979b0948", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1236006950908310" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rwYzKwyNcspC", + "outputId": "a8979bb2-6fac-4db6-b997-c31962014e24" + }, + "source": [ + "from typing import List\n", + "\n", + "from torchrec.distributed.types import LazyAwaitable\n", + "\n", + "\n", + "# Demonstrate a LazyAwaitable type\n", + "class ExampleAwaitable(LazyAwaitable[torch.Tensor]):\n", + " def __init__(self, size: List[int]) -> None:\n", + " super().__init__()\n", + " self._size = size\n", + "\n", + " def _wait_impl(self) -> torch.Tensor:\n", + " return torch.ones(self._size)\n", + "\n", + "\n", + "awaitable = ExampleAwaitable([3, 2])\n", + "awaitable.wait()" + ], + "execution_count": 22, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[1., 1.],\n", + " [1., 1.],\n", + " [1., 1.]])" + ] + }, + "metadata": {}, + "execution_count": 22 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "c958c791-a62c-423a-9a95-1e6ae4e8fbd9", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000147878, + "executionStopTime": 1726000154861, + "serverExecutionDuration": 6806.3651248813, + "requestMsgId": "c958c791-a62c-423a-9a95-1e6ae4e8fbd9", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1255627342282843" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cs41RfzGcspC", + "outputId": "ba7cfdbe-59c4-48c2-a767-d6f5f5cbb915" + }, + "source": [ + "kjt = kjt.to(\"cuda\")\n", + "output = sharded_ebc(kjt)\n", + "# The output of our sharded EmbeddingBagCollection module is a an Awaitable?\n", + "print(output)" + ], + "execution_count": 23, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "6f2957f2-2e7e-47e4-9237-f0b6c8b0da94", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000154865, + "executionStopTime": 1726000155069, + "serverExecutionDuration": 6.0432851314545, + "requestMsgId": "6f2957f2-2e7e-47e4-9237-f0b6c8b0da94", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1057638405967561" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_1sdt75rcspG", + "outputId": "365cf6cb-adae-4b41-f68d-1a7acf5fe332" + }, + "source": [ + "kt = output.wait()\n", + "# Now we have out KeyedTensor after calling .wait()\n", + "# If you are confused as to why we have a KeyedTensor output,\n", + "# give yourself a refresher on the unsharded EmbeddingBagCollection module\n", + 
"print(type(kt))\n", + "\n", + "print(kt.keys())\n", + "\n", + "print(kt.values().shape)\n", + "\n", + "# Same output format as unsharded EmbeddingBagCollection\n", + "result_dict = kt.to_dict()\n", + "for key, embedding in result_dict.items():\n", + " print(key, embedding.shape)" + ], + "execution_count": 24, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "['product', 'user']\n", + "torch.Size([2, 128])\n", + "product torch.Size([2, 64])\n", + "user torch.Size([2, 64])\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "4c464de0-20ef-4ef2-89e2-5d58ca224660", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "bEgB987CcspG" + }, + "source": [ + "### Anatomy of Sharded TorchRec modules\n", + "\n", + "We have now successfully sharded an EmbeddingBagCollection given a sharding plan that we generated! The sharded module has common APIs from TorchRec which abstract away distributed communication/compute amongst multiple GPUs. In fact, these APIs are highly optimized for performance in training and inference. **Below are the three common APIs for distributed training/inference** that are provided by TorchRec:\n", + "\n", + "1. **input_dist**: Handles distributing inputs from GPU to GPU\n", + "\n", + "2. **lookups**: Does the actual embedding lookup in an optimized, batched manner using FBGEMM TBE (more on this later)\n", + "\n", + "3. **output_dist**: Handles distributing outputs from GPU to GPU\n", + "\n", + "The distribution of inputs/outputs is done through [NCCL Collectives](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/overview.html), namely [All-to-Alls](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#all-to-all), which is where all GPUs send/receive data to and from one another. TorchRec interfaces with PyTorch distributed for collectives and provides clean abstractions to the end users, removing the concern for the lower level details.\n", + "\n", + "\n", + "The backwards pass does all of these collectives but in the reverse order for distribution of gradients. input_dist, lookup, and output_dist all depend on the sharding scheme. 
Since we sharded in a table-wise fashion, these APIs are modules that are constructed by [TwPooledEmbeddingSharding](https://pytorch.org/torchrec/torchrec.distributed.sharding.html#torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding).\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "03e6e163-af3a-4443-a5a8-3f877fc401d2", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000155075, + "executionStopTime": 1726000155253, + "serverExecutionDuration": 5.8192722499371, + "requestMsgId": "03e6e163-af3a-4443-a5a8-3f877fc401d2", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1042737524520351" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "O2ptES89cspG", + "outputId": "2b801648-6501-4463-d743-4887da340974" + }, + "source": [ + "sharded_ebc" + ], + "execution_count": 25, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "ShardedEmbeddingBagCollection(\n", + " (lookups): \n", + " GroupedPooledEmbeddingsLookup(\n", + " (_emb_modules): ModuleList(\n", + " (0): BatchedFusedEmbeddingBag(\n", + " (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n", + " )\n", + " )\n", + " )\n", + " (_input_dists): \n", + " TwSparseFeaturesDist(\n", + " (_dist): KJTAllToAll()\n", + " )\n", + " (_output_dists): \n", + " TwPooledEmbeddingDist(\n", + " (_dist): PooledEmbeddingsAllToAll()\n", + " )\n", + " (embedding_bags): ModuleDict(\n", + " (product_table): Module()\n", + " (user_table): Module()\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 25 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "c2a34340-d5fd-4dc8-9b7e-3a761a0c5f82", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000155256, + "executionStopTime": 1726000155442, + "serverExecutionDuration": 5.3565315902233, + "requestMsgId": "c2a34340-d5fd-4dc8-9b7e-3a761a0c5f82", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1063399165221115" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "PHjJt3BQcspG", + "outputId": "aaedadd8-da43-4225-d5f8-a7a43fd0250a" + }, + "source": [ + "# Distribute input KJTs to all other GPUs and receive KJTs\n", + "sharded_ebc._input_dists" + ], + "execution_count": 26, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[TwSparseFeaturesDist(\n", + " (_dist): KJTAllToAll()\n", + " )]" + ] + }, + "metadata": {}, + "execution_count": 26 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "88abe892-1ed1-4806-84ad-35f43247a772", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000155445, + "executionStopTime": 1726000155695, + "serverExecutionDuration": 5.3521953523159, + "requestMsgId": "88abe892-1ed1-4806-84ad-35f43247a772", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1513800839249249" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "jrEXMc7TcspG", + "outputId": "81ab40a5-135b-494c-f2bc-91be16a338cc" + }, + "source": [ + "# Distribute output embeddingts to all other GPUs and receive embeddings\n", + "sharded_ebc._output_dists" + ], + "execution_count": 27, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[TwPooledEmbeddingDist(\n", + " (_dist): PooledEmbeddingsAllToAll()\n", + " )]" + ] + }, + 
"metadata": {}, + "execution_count": 27 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "2eaf16f1-ac14-4f7a-b443-e707ff85c3f0", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "C2jfo5ilcspH" + }, + "source": [ + "### Optimizing Embedding Lookups\n", + "\n", + "In performing lookups for a collection of embedding tables, a trivial solution would be to iterate through all the `nn.EmbeddingBags` and do a lookup per table. This is exactly what the standard, unsharded TorchRec's `EmbeddingBagCollection` does. However, while this solution is simple, it is extremely slow.\n", + "\n", + "[FBGEMM](https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu) is a library that provides GPU operators (otherewise known as kernels) that are very optimized. One of these operators is known as **Table Batched Embedding** (TBE), provides two major optimizations:\n", + "\n", + "* Table batching, which allows you to look up multiple embeddings with one kernel call.\n", + "* Optimizer Fusion, which allows the module to update itself given the canonical pytorch optimizers and arguments.\n", + "\n", + "The `ShardedEmbeddingBagCollection` uses the FBGEMM TBE as the lookup instead of traditional `nn.EmbeddingBags` for optimized embedding lookups." + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "801c50b9-e1a2-465a-9fa3-3cd87d676ed4", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000155699, + "executionStopTime": 1726000155879, + "serverExecutionDuration": 5.0756596028805, + "requestMsgId": "801c50b9-e1a2-465a-9fa3-3cd87d676ed4", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "911093750838903" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1GoHWI6OcspH", + "outputId": "cd67815b-00bd-4a30-89cf-7b5d9c7051e9" + }, + "source": [ + "sharded_ebc._lookups" + ], + "execution_count": 28, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[GroupedPooledEmbeddingsLookup(\n", + " (_emb_modules): ModuleList(\n", + " (0): BatchedFusedEmbeddingBag(\n", + " (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n", + " )\n", + " )\n", + " )]" + ] + }, + "metadata": {}, + "execution_count": 28 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "f2b31d78-81a9-426f-b017-ca8404383939", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "1zcbZX1lcspH" + }, + "source": [ + "### DistributedModelParallel\n", + "\n", + "We have now explored sharding a single EmbeddingBagCollection! We were able to take the `EmbeddingBagCollectionSharder` and use the unsharded `EmbeddingBagCollection` to generate a `ShardedEmbeddingBagCollection` module. This workflow is fine, but typically when doing model parallel, [`DistributedModelParallel`](https://pytorch.org/torchrec/model-parallel-api-reference.html#model-parallel) (DMP) is used as the standard interface. When wrapping your model (in our case `ebc`), with DMP, the following will occur:\n", + "\n", + "1. Decide how to shard the model. DMP will collect the available ‘sharders’ and come up with a ‘plan’ of the optimal way to shard the embedding table(s) (i.e, the EmbeddingBagCollection)\n", + "2. Actually shard the model. 
This includes allocating memory for each embedding table on the appropriate device(s).\n", + "\n", + "DMP takes in everything that we've just experimented with, like a static sharding plan, a list of sharders, etc. However, it also has some nice defaults to seamlessly shard a TorchRec model. In this toy example, since we have two EmbeddingTables and one GPU, TorchRec will place both on the single GPU.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "e0e198e1-db2a-46b0-91f0-51a5ff80abbb", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000155883, + "executionStopTime": 1726000156073, + "serverExecutionDuration": 7.8761726617813, + "requestMsgId": "e0e198e1-db2a-46b0-91f0-51a5ff80abbb", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1207953610328397" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ypVUDwpzcspH", + "outputId": "26a7a957-c231-459a-dfc8-f0c1cd6f697e" + }, + "source": [ + "ebc" + ], + "execution_count": 29, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "EmbeddingBagCollection(\n", + " (embedding_bags): ModuleDict(\n", + " (product_table): EmbeddingBag(4096, 64, mode='sum')\n", + " (user_table): EmbeddingBag(4096, 64, mode='sum')\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 29 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "73fec38d-947a-49d5-a2ba-61e3828b7117", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000156075, + "executionStopTime": 1726000156438, + "serverExecutionDuration": 165.43522849679, + "requestMsgId": "73fec38d-947a-49d5-a2ba-61e3828b7117", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1838328716783594" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5EdlyWAycspH", + "outputId": "05e90aa2-cb83-4ddf-9da5-6aa31d6da278" + }, + "source": [ + "model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device(\"cuda\"))" + ], + "execution_count": 30, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "f8d87a4e-6a7a-4a02-92f9-9baa794266af", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000156441, + "executionStopTime": 1726000156665, + "serverExecutionDuration": 6.8417005240917, + "requestMsgId": "f8d87a4e-6a7a-4a02-92f9-9baa794266af", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1059040285804352" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "b5NgRErjcspH", + "outputId": "8f4de40c-3a3b-43e5-d645-814bf03dab0b" + }, + "source": [ + "out = model(kjt)\n", + "out.wait()" + ], + "execution_count": 31, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 31 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "e7e02648-dee7-4b3a-8953-47e8b8771c3b", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000156669, + "executionStopTime": 1726000156885, + "serverExecutionDuration": 5.4804161190987, + "requestMsgId": "e7e02648-dee7-4b3a-8953-47e8b8771c3b", + 
"outputsInitialized": true, + "customOutput": null, + "output": { + "id": "3346626825643095" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "VJrSysjjcspH", + "outputId": "2920a3d0-dd96-43ea-ab0d-627de00d1e42" + }, + "source": [ + "model" + ], + "execution_count": 32, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "DistributedModelParallel(\n", + " (_dmp_wrapped_module): ShardedEmbeddingBagCollection(\n", + " (lookups): \n", + " GroupedPooledEmbeddingsLookup(\n", + " (_emb_modules): ModuleList(\n", + " (0): BatchedFusedEmbeddingBag(\n", + " (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n", + " )\n", + " )\n", + " )\n", + " (_input_dists): \n", + " TwSparseFeaturesDist(\n", + " (_dist): KJTAllToAll()\n", + " )\n", + " (_output_dists): \n", + " TwPooledEmbeddingDist(\n", + " (_dist): PooledEmbeddingsAllToAll()\n", + " )\n", + " (embedding_bags): ModuleDict(\n", + " (product_table): Module()\n", + " (user_table): Module()\n", + " )\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 32 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "4b6171d5-ae60-4cc8-a47a-f01236c02e6c", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "BLM673eTcspH" + }, + "source": [ + "### Sharding Best Practices\n", + "\n", + "Currently, our configuration is only sharding on 1 GPU (or rank), which is trivial: just place all the tables on 1 GPUs memory. However, in real production use cases, embedding tables are **typically sharded on hundreds of GPUs**, with different sharding methods such as table-wise, row-wise, and column-wise. It is incredibly important to determine a proper sharding configuration (to prevent out of memory issues) while keeping it balanced not only in terms of memory but also compute for optimal performance." + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Adding in the Optimizer\n", + "\n", + "Remember that TorchRec modules are hyperoptimized for large scale distributed training. An important optimization is in regards to the optimizer. **TorchRec modules provide a seamless API to fuse the backwards pass and optimize step in training, providing a significant optimization in performance and decreasing the memory used, alongside granularity in assigning distinct optimizers to distinct model parameters.**\n", + "\n", + "#### Optimizer Classes\n", + "\n", + "TorchRec uses `CombinedOptimizer`, which contains a collection of `KeyedOptimizers`. A `CombinedOptimizer` effectively makes it easy to handle multiple optimizers for various sub groups in the model. A `KeyedOptimizer` extends the `torch.optim.Optimizer` and is initialized through a dictionary of parameters exposes the parameters. Each `TBE` module in a `EmbeddingBagCollection` will have it's own `KeyedOptimizer` which combines into one `CombinedOptimizer`.\n", + "\n", + "#### Fused optimizer in TorchRec\n", + "\n", + "Using `DistributedModelParallel`, the **optimizer is fused, which means that the optimizer update is done in the backward**. This is an optimization in TorchRec and FBGEMM, where the optimizer embedding gradients are not materialized and applied directly to the parameters. 
This brings significant memory savings as embedding gradients are typically size of the parameters themselves.\n", + "\n", + "You can, however, choose to make the optimizer `dense` which does not apply this optimization and let's you inspect the embedding gradients or apply computations to it as you wish. A dense optimizer in this case would be your [canonical PyTorch model training loop with optimizer.](https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html)\n", + "\n", + "Once the optimizer is created through `DistributedModelParallel`, you still need to manage an optimizer for the other parameters not associated with TorchRec embedding modules. To find the other parameters, use`in_backward_optimizer_filter(model.named_parameters())`.\n", + "\n", + "Apply an optimizer to those parameters as you would a normal Torch optimizer and combine this and the `model.fused_optimizer` into one `CombinedOptimizer` that you can use in your training loop to `zero_grad` and `step` through.\n", + "\n", + "#### Let's add an optimizer to our EmbeddingBagCollection\n", + "We will do this in two ways, which are equivalent, but give you options depending on your preferences:\n", + "1. Passing optimizer kwargs through fused parameters (fused_params) in sharder\n", + "2. Through `apply_optimizer_in_backward`\n", + "Note: `apply_optimizer_in_backward` converts the optimizer parameters to `fused_params` to pass to the `TBE` in the `EmbeddingBagCollection`/`EmbeddingCollection`." + ], + "metadata": { + "id": "zFhggkUCmd7f" + } + }, + { + "cell_type": "code", + "source": [ + "# Approach 1: passing optimizer kwargs through fused parameters\n", + "from torchrec.optim.optimizers import in_backward_optimizer_filter\n", + "from fbgemm_gpu.split_embedding_configs import EmbOptimType\n", + "\n", + "\n", + "# We initialize the sharder with\n", + "fused_params = {\n", + " \"optimizer\": EmbOptimType.EXACT_ROWWISE_ADAGRAD,\n", + " \"learning_rate\": 0.02,\n", + " \"eps\": 0.002,\n", + "}\n", + "\n", + "# Init sharder with fused_params\n", + "sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)\n", + "\n", + "# We'll use same plan and unsharded EBC as before but this time with our new sharder\n", + "sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[\"\"], env, torch.device(\"cuda\"))\n", + "\n", + "# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correclty.\n", + "# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied\n", + "print(f\"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}\")\n", + "print(f\"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}\")\n", + "\n", + "print(f\"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "h5BCEFidmnEw", + "outputId": "202c64f7-ae95-4b0d-9f53-16138a680d7d" + }, + "execution_count": 33, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Original Sharded EBC fused optimizer: : EmbeddingFusedOptimizer (\n", + "Parameter Group 0\n", + " lr: 0.01\n", + ")\n", + "Sharded EBC with fused parameters fused optimizer: : EmbeddingFusedOptimizer (\n", + "Parameter Group 0\n", + " lr: 0.02\n", + ")\n", + "Type of optimizer: \n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "from 
torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward\n", + "import copy\n", + "# Approach 2: applying optimizer through apply_optimizer_in_backward\n", + "# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it\n", + "\n", + "# We can achieve the same result as we did in the previous\n", + "ebc_apply_opt = copy.deepcopy(ebc)\n", + "optimizer_kwargs = {\"lr\": 0.5}\n", + "\n", + "for name, param in ebc_apply_opt.named_parameters():\n", + " print(f\"{name=}\")\n", + " apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)\n", + "\n", + "sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[\"\"], env, torch.device(\"cuda\"))\n", + "\n", + "# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted\n", + "print(sharded_ebc_apply_opt.fused_optimizer)\n", + "print(type(sharded_ebc_apply_opt.fused_optimizer))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "T-xx724MmoKv", + "outputId": "0f58fb18-f423-4c84-ee57-d37bdba28eb8" + }, + "execution_count": 34, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "name='embedding_bags.product_table.weight'\n", + "name='embedding_bags.user_table.weight'\n", + ": EmbeddingFusedOptimizer (\n", + "Parameter Group 0\n", + " lr: 0.5\n", + ")\n", + "\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + ":1: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.\n", + " from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# We can also check through the filter other parameters that aren't associated with the \"fused\" optimizer(s)\n", + "# Pratically, just non TorchRec module parameters. 
Since our module is just a TorchRec EBC\n", + "# there are no other parameters that aren't associated with TorchRec\n", + "print(\"Non Fused Model Parameters:\")\n", + "print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "UEyhlmbmlwsW", + "outputId": "f6219673-d14d-444e-a451-98f33ddeb54d" + }, + "execution_count": 35, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Non Fused Model Parameters:\n", + "dict_keys([])\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Here we do a dummy backwards call and see that parameter updates for fused\n", + "# optimizers happen as a result of the backward pass\n", + "\n", + "ebc_output = sharded_ebc_fused_params(kjt).wait().values()\n", + "loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)\n", + "print(f\"First Iteration Loss: {loss}\")\n", + "\n", + "loss.backward()\n", + "\n", + "ebc_output = sharded_ebc_fused_params(kjt).wait().values()\n", + "loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)\n", + "# We don't call an optimizer.step(), so for the loss to have changed here,\n", + "# that means that the gradients were somehow updated, which is what the\n", + "# fused optimizer automatically handles for us\n", + "print(f\"Second Iteration Loss: {loss}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Bga-zM2OfnMW", + "outputId": "6c5d45f2-c479-4932-b39d-5ff8abe27d3c" + }, + "execution_count": 36, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "First Iteration Loss: 255.94378662109375\n", + "Second Iteration Loss: 245.72166442871094\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "e3bdc895-54c4-4fc6-9175-28dd75021c6a", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "Xc-RUDwDcspH" + }, + "source": [ + "## Inference\n", + "\n", + "Now that we are able to train distributed embeddings, how can we take the trained model and optimize it for inference? Inference is typically very sensitive to **performance and size of the model**. Running just the trained model in a Python environment is incredibly inefficient. There are two key differences between inference and training environments:\n", + "* **Quantization**: Inference models are typically quantized, where model parameters lose precision for lower latency in predictions and reduced model size. For example FP32 (4 bytes) in trained model to INT8 (1 byte) for each embedding weight. 
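For a rough sense of scale, the back-of-the-envelope arithmetic below compares the weight storage of this tutorial's two 4096 x 64 toy tables in FP32 versus INT8. It deliberately ignores the small per-row scale/bias overhead that quantized kernels add.

```python
# Back-of-the-envelope weight storage for the two toy tables in this notebook
# (4096 rows x 64 dims each), ignoring per-row scale/bias overhead.
num_tables, rows, dim = 2, 4096, 64

fp32_bytes = num_tables * rows * dim * 4  # 4 bytes per FP32 weight
int8_bytes = num_tables * rows * dim * 1  # 1 byte per INT8 weight

print(f"FP32: {fp32_bytes / 2**20:.2f} MiB")        # 2.00 MiB
print(f"INT8: {int8_bytes / 2**20:.2f} MiB")        # 0.50 MiB
print(f"reduction: {fp32_bytes / int8_bytes:.0f}x")  # 4x
```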
This is also necessary given the vast scale of embedding tables, as we want to use as few devices as possible for inference to minimize latency.\n", + "* **C++ environment**: Inference latency is a big deal, so in order to ensure ample performance, the model is typically ran in a C++ environment (along with situations where we don't have a Python runtime, like on device)\n", + "\n", + "TorchRec provides primitives for converting a TorchRec model into being inference ready with:\n", + "* APIs for quantizing the model, introducing optimizations automatically with FBGEMM TBE\n", + "* sharding embeddings for distributed inference\n", + "* compiling the model to [TorchScript](https://pytorch.org/docs/stable/jit.html) (compatible in C++)\n", + "\n", + "In this section, we will go over this entire workflow of:\n", + "* Quantizing the model\n", + "* Sharding the quantized model\n", + "* Compiling the sharded quantized model into TorchScript" + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "aae8ef10-f7a4-421a-b71c-f177ff74e96a", + "showInput": true, + "customInput": null, + "language": "python", + "outputsInitialized": true, + "executionStartTime": 1726000156892, + "executionStopTime": 1726000157069, + "serverExecutionDuration": 7.4504055082798, + "requestMsgId": "aae8ef10-f7a4-421a-b71c-f177ff74e96a", + "customOutput": null, + "output": { + "id": "456742254014129" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8JypsUNmcspH", + "outputId": "0a745234-d316-4850-d84a-f08b0f045595" + }, + "source": [ + "ebc" + ], + "execution_count": 37, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "EmbeddingBagCollection(\n", + " (embedding_bags): ModuleDict(\n", + " (product_table): EmbeddingBag(4096, 64, mode='sum')\n", + " (user_table): EmbeddingBag(4096, 64, mode='sum')\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 37 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "30694976-da54-48d6-922e-ca53f22c385f", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000157071, + "executionStopTime": 1726000157317, + "serverExecutionDuration": 2.9501467943192, + "requestMsgId": "30694976-da54-48d6-922e-ca53f22c385f", + "outputsInitialized": true, + "customOutput": null, + "id": "t2plfyrWcspH" + }, + "source": [ + "class InferenceModule(torch.nn.Module):\n", + " def __init__(self, ebc: torchrec.EmbeddingBagCollection):\n", + " super().__init__()\n", + " self.ebc_ = ebc\n", + "\n", + " def forward(self, kjt: KeyedJaggedTensor):\n", + " return self.ebc_(kjt)" + ], + "execution_count": 38, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "2a4a83f1-449d-493e-8f24-7c1975ecad9d", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000157320, + "executionStopTime": 1726000157494, + "serverExecutionDuration": 3.8229525089264, + "requestMsgId": "2a4a83f1-449d-493e-8f24-7c1975ecad9d", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1619365005294308" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5FRioGEmcspH", + "outputId": "e4d9efd3-2427-4602-bb28-2f30c4f3f985" + }, + "source": [ + "module = InferenceModule(ebc)\n", + "for name, param in module.named_parameters():\n", + " # Here, the parameters should still be FP32, as we are using a standard EBC\n", + " # FP32 is default, regularly used for training\n", + " 
print(name, param.shape, param.dtype)" + ], + "execution_count": 39, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "ebc_.embedding_bags.product_table.weight torch.Size([4096, 64]) torch.float32\n", + "ebc_.embedding_bags.user_table.weight torch.Size([4096, 64]) torch.float32\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "665352e2-208f-4951-8601-282d036b0e4e", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "OSTy4SU8cspH" + }, + "source": [ + "### Quantization\n", + "\n", + "As you can see above, the normal EBC contains embedding table weights as FP32 precision (32 bits for each weight). Here, we will use the TorchRec inference library to quantize the embedding weights of the model to INT8" + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "796919b4-f9dd-4d14-a40e-f20668c8257b", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000157499, + "executionStopTime": 1726000157696, + "serverExecutionDuration": 14.22468572855, + "requestMsgId": "796919b4-f9dd-4d14-a40e-f20668c8257b", + "customOutput": null, + "outputsInitialized": true, + "output": { + "id": "560049189691202" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "oV-KPRqDcspH", + "outputId": "e91220a1-a26c-4f2a-e7c8-e91a3d56d8dc" + }, + "source": [ + "from torch import quantization as quant\n", + "from torchrec.modules.embedding_configs import QuantConfig\n", + "from torchrec.quant.embedding_modules import (\n", + " EmbeddingBagCollection as QuantEmbeddingBagCollection,\n", + ")\n", + "\n", + "\n", + "quant_dtype = torch.int8\n", + "\n", + "\n", + "qconfig = QuantConfig(\n", + " # dtype of the result of the embedding lookup, post activation\n", + " # torch.float generally for compatability with rest of the model\n", + " # as rest of the model here usually isn't quantized\n", + " activation=quant.PlaceholderObserver.with_args(dtype=torch.float),\n", + " # quantized type for embedding weights, aka parameters to actually quantize\n", + " weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),\n", + ")\n", + "qconfig_spec = {\n", + " # Map of module type to qconfig\n", + " torchrec.EmbeddingBagCollection: qconfig,\n", + "}\n", + "mapping = {\n", + " # Map of module type to quantized module type\n", + " torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,\n", + "}\n", + "\n", + "\n", + "module = InferenceModule(ebc)\n", + "\n", + "# Quantize the module\n", + "qebc = quant.quantize_dynamic(\n", + " module,\n", + " qconfig_spec=qconfig_spec,\n", + " mapping=mapping,\n", + " inplace=False,\n", + ")\n", + "\n", + "\n", + "print(f\"Quantized EBC: {qebc}\")" + ], + "execution_count": 40, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Quantized EBC: InferenceModule(\n", + " (ebc_): QuantizedEmbeddingBagCollection(\n", + " (_kjt_to_jt_dict): ComputeKJTToJTDict()\n", + " (embedding_bags): ModuleDict(\n", + " (product_table): Module()\n", + " (user_table): Module()\n", + " )\n", + " )\n", + ")\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "c1fdd88b-73af-47a8-8aec-4f9422051ee7", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000157700, + "executionStopTime": 1726000157862, + "serverExecutionDuration": 4.0535479784012, + "requestMsgId": "c1fdd88b-73af-47a8-8aec-4f9422051ee7", + 
"outputsInitialized": true, + "customOutput": null, + "id": "fAztesVacspI" + }, + "source": [ + "kjt = kjt.to(\"cpu\")" + ], + "execution_count": 41, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "f5f911e8-ab78-4fd7-b4a1-7a545b5bd24b", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000157865, + "executionStopTime": 1726000158060, + "serverExecutionDuration": 9.1104581952095, + "requestMsgId": "f5f911e8-ab78-4fd7-b4a1-7a545b5bd24b", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "434299789062153" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Wnpwa0TmcspI", + "outputId": "88007466-88ce-4f6b-b7e8-22d042e5378b" + }, + "source": [ + "qebc(kjt)" + ], + "execution_count": 42, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 42 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "99559efa-baaa-4de1-91d3-7899f87fe659", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000158063, + "executionStopTime": 1726000158228, + "serverExecutionDuration": 3.4465603530407, + "requestMsgId": "99559efa-baaa-4de1-91d3-7899f87fe659", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "499581679596627" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "UUs5fXNncspI", + "outputId": "79d814ed-7772-4bc3-ba84-f5f0d0a45e36" + }, + "source": [ + "# Once quantized, goes from parameters -> buffers, as no longer trainable\n", + "for name, buffer in qebc.named_buffers():\n", + " # The shapes of the tables should be the same but the dtype should be int8 now\n", + " # post quantization\n", + " print(name, buffer.shape, buffer.dtype)" + ], + "execution_count": 43, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "ebc_.embedding_bags.product_table.weight torch.Size([4096, 80]) torch.uint8\n", + "ebc_.embedding_bags.user_table.weight torch.Size([4096, 80]) torch.uint8\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "2b1a9c89-b921-4a35-9f64-0c63b09a2579", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "fdM7UihocspI" + }, + "source": [ + "### Shard\n", + "\n", + "Here we perform sharding of the TorchRec quantized model. This is to ensure we are using the performant module through FBGEMM TBE. Here we are using one device to be consistent with training (1 TBE)." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "19c18bbb-6376-468a-a6dc-8346d30ceb48", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000158234, + "executionStopTime": 1726000158552, + "serverExecutionDuration": 108.51271077991, + "requestMsgId": "19c18bbb-6376-468a-a6dc-8346d30ceb48", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "882684747065056" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mha4FntncspI", + "outputId": "40a5557c-531f-4726-f06a-5f73546a8fe0" + }, + "source": [ + "from torchrec import distributed as trec_dist\n", + "from torchrec.distributed.shard import _shard_modules\n", + "\n", + "\n", + "sharded_qebc = _shard_modules(\n", + " module=qebc,\n", + " device=torch.device(\"cpu\"),\n", + " env=trec_dist.ShardingEnv.from_local(\n", + " 1,\n", + " 0,\n", + " ),\n", + ")\n", + "\n", + "\n", + "print(f\"Sharded Quantized EBC: {sharded_qebc}\")" + ], + "execution_count": 44, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Sharded Quantized EBC: InferenceModule(\n", + " (ebc_): ShardedQuantEmbeddingBagCollection(\n", + " (lookups): \n", + " InferGroupedPooledEmbeddingsLookup()\n", + " (_output_dists): ModuleList()\n", + " (embedding_bags): ModuleDict(\n", + " (product_table): Module()\n", + " (user_table): Module()\n", + " )\n", + " (_input_dist_module): ShardedQuantEbcInputDist()\n", + " )\n", + ")\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "f00ae63f-0ac4-49c0-93fe-32d7fac76693", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000158555, + "executionStopTime": 1726000159111, + "serverExecutionDuration": 345.11629864573, + "requestMsgId": "f00ae63f-0ac4-49c0-93fe-32d7fac76693", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "876807203893705" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0iBD90t3cspI", + "outputId": "86a17997-aa50-4427-cf2b-55f4c0aef456" + }, + "source": [ + "sharded_qebc(kjt)" + ], + "execution_count": 45, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 45 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "897037bb-9d81-4a33-aea1-de1691217d41", + "showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "08ue1zeVcspI" + }, + "source": [ + "### Compilation\n", + "Now we have the optimized eager TorchRec inference model. The next step is to ensure that this model is loadable in C++, as currently it is only runnable in a Python runtime.\n", + "\n", + "The recommended method of compilation at Meta is two fold: [torch.fx tracing](https://pytorch.org/docs/stable/fx.html) (generate intermediate representation of model) and converting the result to TorchScript, where TorchScript is C++ compatible." 
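Once the module has been traced and scripted in the next two cells, the resulting `scripted_gm` can be serialized to disk. The snippet below is only a sketch of that last step: the file name is arbitrary, and loading the artifact from C++ (via `torch::jit::load`) additionally requires the TorchRec/FBGEMM custom-operator libraries to be available in the C++ binary, as in the inference example linked at the end of this tutorial.

```python
# Sketch: saving the scripted module produced by the cells below. The path is
# arbitrary; the C++ side would torch::jit::load the same file, with the
# TorchRec/FBGEMM custom ops linked in.
import torch

torch.jit.save(scripted_gm, "sharded_qebc_inference.pt")

# Optional round-trip check from Python.
reloaded = torch.jit.load("sharded_qebc_inference.pt")
print(type(reloaded))
```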
+ ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "bdab6e95-3a71-4c3d-b188-115873f1f5d5", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000159118, + "executionStopTime": 1726000159308, + "serverExecutionDuration": 28.788283467293, + "requestMsgId": "bdab6e95-3a71-4c3d-b188-115873f1f5d5", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "491668137118498" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "SRzo1jljcspI", + "outputId": "e0f94cf0-c5cf-4480-b0d9-d5dd2abb459b" + }, + "source": [ + "from torchrec.fx import Tracer\n", + "\n", + "\n", + "tracer = Tracer(leaf_modules=[\"IntNBitTableBatchedEmbeddingBagsCodegen\"])\n", + "\n", + "graph = tracer.trace(sharded_qebc)\n", + "gm = torch.fx.GraphModule(sharded_qebc, graph)\n", + "\n", + "print(\"Graph Module Created!\")" + ], + "execution_count": 46, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Graph Module Created!\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "909178d6-4dae-45da-9c39-6827019f53a3", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000159312, + "executionStopTime": 1726000159490, + "serverExecutionDuration": 2.2248737514019, + "requestMsgId": "909178d6-4dae-45da-9c39-6827019f53a3", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1555501808508272" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NsgzbUdHcspI", + "outputId": "c5b67630-19c7-46df-c0ab-216d24309603" + }, + "source": [ + "print(gm.code)" + ], + "execution_count": 47, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "torch.fx._symbolic_trace.wrap(\"torchrec_distributed_quant_embeddingbag_flatten_feature_lengths\")\n", + "torch.fx._symbolic_trace.wrap(\"torchrec_fx_utils__fx_marker\")\n", + "torch.fx._symbolic_trace.wrap(\"torchrec_distributed_quant_embedding_kernel__unwrap_kjt\")\n", + "torch.fx._symbolic_trace.wrap(\"torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference\")\n", + "\n", + "def forward(self, kjt : torchrec_sparse_jagged_tensor_KeyedJaggedTensor):\n", + " flatten_feature_lengths = torchrec_distributed_quant_embeddingbag_flatten_feature_lengths(kjt); kjt = None\n", + " _fx_marker = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_BEGIN', flatten_feature_lengths)\n", + " split = flatten_feature_lengths.split([2])\n", + " getitem = split[0]; split = None\n", + " to = getitem.to(device(type='cuda', index=0), non_blocking = True); getitem = None\n", + " _fx_marker_1 = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_END', flatten_feature_lengths); flatten_feature_lengths = None\n", + " _unwrap_kjt = torchrec_distributed_quant_embedding_kernel__unwrap_kjt(to); to = None\n", + " getitem_1 = _unwrap_kjt[0]\n", + " getitem_2 = _unwrap_kjt[1]\n", + " getitem_3 = _unwrap_kjt[2]; _unwrap_kjt = None\n", + " _tensor_constant0 = self._tensor_constant0\n", + " _tensor_constant1 = self._tensor_constant1\n", + " bounds_check_indices = torch.ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_1, getitem_2, 1, _tensor_constant1, None); _tensor_constant0 = _tensor_constant1 = None\n", + " _tensor_constant2 = self._tensor_constant2\n", + " _tensor_constant3 = self._tensor_constant3\n", + " _tensor_constant4 = self._tensor_constant4\n", + " _tensor_constant5 = self._tensor_constant5\n", + " 
_tensor_constant6 = self._tensor_constant6\n", + " _tensor_constant7 = self._tensor_constant7\n", + " _tensor_constant8 = self._tensor_constant8\n", + " _tensor_constant9 = self._tensor_constant9\n", + " int_nbit_split_embedding_codegen_lookup_function = torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(dev_weights = _tensor_constant2, uvm_weights = _tensor_constant3, weights_placements = _tensor_constant4, weights_offsets = _tensor_constant5, weights_tys = _tensor_constant6, D_offsets = _tensor_constant7, total_D = 128, max_int2_D = 0, max_int4_D = 0, max_int8_D = 64, max_float16_D = 0, max_float32_D = 0, indices = getitem_1, offsets = getitem_2, pooling_mode = 0, indice_weights = None, output_dtype = 0, lxu_cache_weights = _tensor_constant8, lxu_cache_locations = _tensor_constant9, row_alignment = 16, max_float8_D = 0, fp8_exponent_bits = -1, fp8_exponent_bias = -1); _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 = _tensor_constant6 = _tensor_constant7 = getitem_1 = getitem_2 = _tensor_constant8 = _tensor_constant9 = None\n", + " embeddings_cat_empty_rank_handle_inference = torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference([int_nbit_split_embedding_codegen_lookup_function], dim = 1, device = 'cuda:0', dtype = torch.float32); int_nbit_split_embedding_codegen_lookup_function = None\n", + " to_1 = embeddings_cat_empty_rank_handle_inference.to(device(type='cpu')); embeddings_cat_empty_rank_handle_inference = None\n", + " keyed_tensor = torchrec_sparse_jagged_tensor_KeyedTensor(keys = ['product', 'user'], length_per_key = [64, 64], values = to_1, key_dim = 1); to_1 = None\n", + " return keyed_tensor\n", + " \n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "ec77b6ea-f5b1-4c08-9cb9-93faf6a57532", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000159494, + "executionStopTime": 1726000160206, + "serverExecutionDuration": 540.64276814461, + "requestMsgId": "ec77b6ea-f5b1-4c08-9cb9-93faf6a57532", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "978016470760577" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CjjJLc6pcspI", + "outputId": "be3a9486-e4b5-43f6-aed0-711e827a0040" + }, + "source": [ + "scripted_gm = torch.jit.script(gm)\n", + "print(\"Scripted Graph Module Created!\")" + ], + "execution_count": 48, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. 
Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Scripted Graph Module Created!\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "9eb089f1-2771-419d-a48e-3b7330c0a1e4", + "showInput": true, + "customInput": null, + "language": "python", + "executionStartTime": 1726000160212, + "executionStopTime": 1726000160395, + "serverExecutionDuration": 2.8529539704323, + "requestMsgId": "9eb089f1-2771-419d-a48e-3b7330c0a1e4", + "outputsInitialized": true, + "customOutput": null, + "output": { + "id": "1020643789855657" + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "BWKPRaI3cspI", + "outputId": "273181a2-7c91-4167-e814-4a07b51c6b10" + }, + "source": [ + "print(scripted_gm.code)" + ], + "execution_count": 49, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "def forward(self,\n", + " kjt: __torch__.torchrec.sparse.jagged_tensor.KeyedJaggedTensor) -> __torch__.torchrec.sparse.jagged_tensor.KeyedTensor:\n", + " _0 = __torch__.torchrec.distributed.quant_embeddingbag.flatten_feature_lengths\n", + " _1 = __torch__.torchrec.fx.utils._fx_marker\n", + " _2 = __torch__.torchrec.distributed.quant_embedding_kernel._unwrap_kjt\n", + " _3 = __torch__.torchrec.distributed.embedding_lookup.embeddings_cat_empty_rank_handle_inference\n", + " flatten_feature_lengths = _0(kjt, )\n", + " _fx_marker = _1(\"KJT_ONE_TO_ALL_FORWARD_BEGIN\", flatten_feature_lengths, )\n", + " split = (flatten_feature_lengths).split([2], )\n", + " getitem = split[0]\n", + " to = (getitem).to(torch.device(\"cuda\", 0), True, None, )\n", + " _fx_marker_1 = _1(\"KJT_ONE_TO_ALL_FORWARD_END\", flatten_feature_lengths, )\n", + " _unwrap_kjt = _2(to, )\n", + " getitem_1 = (_unwrap_kjt)[0]\n", + " getitem_2 = (_unwrap_kjt)[1]\n", + " _tensor_constant0 = self._tensor_constant0\n", + " _tensor_constant1 = self._tensor_constant1\n", + " ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_1, getitem_2, 1, _tensor_constant1)\n", + " _tensor_constant2 = self._tensor_constant2\n", + " _tensor_constant3 = self._tensor_constant3\n", + " _tensor_constant4 = self._tensor_constant4\n", + " _tensor_constant5 = self._tensor_constant5\n", + " _tensor_constant6 = self._tensor_constant6\n", + " _tensor_constant7 = self._tensor_constant7\n", + " _tensor_constant8 = self._tensor_constant8\n", + " _tensor_constant9 = self._tensor_constant9\n", + " int_nbit_split_embedding_codegen_lookup_function = ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(_tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5, _tensor_constant6, _tensor_constant7, 128, 0, 0, 64, 0, 0, getitem_1, getitem_2, 0, None, 0, _tensor_constant8, _tensor_constant9, 16)\n", + " _4 = [int_nbit_split_embedding_codegen_lookup_function]\n", + " embeddings_cat_empty_rank_handle_inference = _3(_4, 1, \"cuda:0\", 6, )\n", + " to_1 = torch.to(embeddings_cat_empty_rank_handle_inference, torch.device(\"cpu\"))\n", + " _5 = [\"product\", \"user\"]\n", + " _6 = [64, 64]\n", + " keyed_tensor = __torch__.torchrec.sparse.jagged_tensor.KeyedTensor.__new__(__torch__.torchrec.sparse.jagged_tensor.KeyedTensor)\n", + " _7 = (keyed_tensor).__init__(_5, _6, to_1, 1, None, None, )\n", + " return keyed_tensor\n", + "\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "9a1dda10-b9cf-4d9f-b068-51ae3ce3ffc1", + 
"showInput": false, + "customInput": null, + "language": "markdown", + "outputsInitialized": false, + "id": "DQiGRYOgcspI" + }, + "source": [ + "## Congrats!\n", + "\n", + "You have now gone from training a distributed RecSys model all the way to making it inference ready. https://github.com/pytorch/torchrec/tree/main/torchrec/inference has a full example of how to load a TorchRec TorchScript model into C++ for inference." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ebXfh7oW9fHH", + "originalKey": "4ca6a593-9ac9-4e2f-bc9a-8c8a1887ad41", + "outputsInitialized": false, + "language": "markdown", + "showInput": false + }, + "source": [ + "## More resources\n", + "For more information, please see our [dlrm](https://github.com/facebookresearch/dlrm/tree/main/torchrec_dlrm/) example, which includes multinode training on the criteo terabyte dataset, using Meta’s [DLRM](https://arxiv.org/abs/1906.00091)." + ] + } + ] +} diff --git a/Torchrec_Introduction.ipynb b/Torchrec_Introduction.ipynb index 5da052875..970fbf8cc 100644 --- a/Torchrec_Introduction.ipynb +++ b/Torchrec_Introduction.ipynb @@ -19,44 +19,11 @@ "source": [ "## **Installation**\n", "Requirements:\n", - "- python >= 3.7\n", + "- python >= 3.9\n", "\n", "We highly recommend CUDA when using TorchRec. If using CUDA:\n", - "- cuda >= 11.0\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BB2K68OYUJ_t" - }, - "outputs": [], - "source": [ - "# install conda to make installying pytorch with cudatoolkit 11.3 easier. \n", - "!wget https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh\n", - "!chmod +x Miniconda3-py37_4.9.2-Linux-x86_64.sh\n", - "!bash ./Miniconda3-py37_4.9.2-Linux-x86_64.sh -b -f -p /usr/local" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sFYvP95xaAER" - }, - "outputs": [], - "source": [ - "# install pytorch with cudatoolkit 11.6\n", - "!conda install pytorch pytorch-cuda=11.6 -c pytorch-nightly -c nvidia -y" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7iY7Uv11mJYK" - }, - "source": [ + "- cuda >= 11.8\n", + "\n", "Installing TorchRec will also install [FBGEMM](https://github.com/pytorch/fbgemm), a collection of CUDA kernels and GPU enabled operations to run " ] }, @@ -64,53 +31,14 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "tUnIw-ZREQJy" - }, - "outputs": [], - "source": [ - "# install torchrec\n", - "!pip3 install torchrec-nightly" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b6EHgotRXFQh" - }, - "source": [ - "The following steps are needed for the Colab runtime to detect the added shared libraries. The runtime searches for shared libraries in /usr/lib, so we copy over the libraries which were installed in /usr/local/lib/. **This is a very necessary step, only in the colab runtime**. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "_P45pDteRcWj" - }, - "outputs": [], - "source": [ - "!cp /usr/local/lib/lib* /usr/lib/" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n5_X2WOAYG3c" - }, - "source": [ - "\\**Restart your runtime at this point for the newly installed packages to be seen.** Run the step below immediately after restarting so that python knows where to look for packages. 
**Always run this step after restarting the runtime.**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8cktNrh8R9rC" + "id": "sFYvP95xaAER" }, "outputs": [], "source": [ - "import sys\n", - "sys.path = ['', '/env/python', '/usr/local/lib/python37.zip', '/usr/local/lib/python3.7', '/usr/local/lib/python3.7/lib-dynload', '/usr/local/lib/python3.7/site-packages']" + "!pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U\n", + "!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/nightly/cu121\n", + "!pip3 install torchmetrics==1.0.3\n", + "!pip3 install torchrec --index-url https://download.pytorch.org/whl/nightly/cu121" ] }, { @@ -236,11 +164,11 @@ "metadata": {}, "outputs": [], "source": [ - "from torchrec.optim.apply_overlapped_optimizer import apply_overlapped_optimizer\n", + "from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward\n", "\n", - "apply_overlapped_optimizer(\n", - " ebc.parameters(),\n", - " optimizer_type=torch.optim.SGD,\n", + "apply_optimizer_in_backward(\n", + " optimizer_class=torch.optim.SGD,\n", + " params=ebc.parameters(),\n", " optimizer_kwargs={\"lr\": 0.02},\n", ")" ] @@ -625,35 +553,46 @@ } ], "metadata": { - "accelerator": "GPU", - "colab": { - "background_execution": "on", - "collapsed_sections": [], - "machine_shape": "hm", - "name": "Torchrec Introduction.ipynb", - "provenance": [] - }, - "interpreter": { - "hash": "d4204deb07d30e7517ec64733b2d65f24aff851b061e21418071854b06459363" - }, - "kernelspec": { - "display_name": "Python 3.7.13 ('torchrec': conda)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 + "custom": { + "cells": [], + "metadata": { + "accelerator": "GPU", + "colab": { + "background_execution": "on", + "collapsed_sections": [], + "machine_shape": "hm", + "name": "Torchrec Introduction.ipynb", + "provenance": [] + }, + "fileHeader": "", + "fileUid": "c9a29462-2509-4adb-a539-0318cf56bb00", + "interpreter": { + "hash": "d4204deb07d30e7517ec64733b2d65f24aff851b061e21418071854b06459363" + }, + "isAdHoc": false, + "kernelspec": { + "display_name": "Python 3.7.13 ('torchrec': conda)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.13" - } + "nbformat": 4, + "nbformat_minor": 0 + }, + "indentAmount": 2 }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 2 } diff --git a/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp b/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp index b0f6a4bde..23001de25 100644 --- a/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp +++ b/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp @@ -12,7 +12,7 @@ namespace torchrec { void BM_MixedLFULRUStrategy(benchmark::State& state) { size_t num_ext_values = state.range(0); - std::vector ext_values(num_ext_values); + std::vector ext_values(num_ext_values); MixedLFULRUStrategy strategy; for (auto& v : ext_values) { diff --git 
a/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_evict_benchmark.cpp b/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_evict_benchmark.cpp index 1f4183d74..f8e2a0cd2 100644 --- a/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_evict_benchmark.cpp +++ b/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_evict_benchmark.cpp @@ -18,13 +18,13 @@ class RecordIterator { public: RecordIterator(Container::const_iterator begin, Container::const_iterator end) : begin_(begin), end_(end) {} - std::optional> operator()() { + std::optional operator()() { if (begin_ == end_) { return std::nullopt; } - TransformerRecord record{}; - record.global_id_ = next_global_id_++; - record.lxu_record_ = *reinterpret_cast(&(*begin_++)); + record_t record{}; + record.global_id = next_global_id_++; + record.lxu_record = *reinterpret_cast(&(*begin_++)); return record; } @@ -52,8 +52,8 @@ class RandomizeMixedLXUSet { std::uniform_int_distribution time_dist(0, max_time - 1); for (size_t i = 0; i < n; ++i) { MixedLFULRUStrategy::Record record{}; - record.freq_power_ = freq_dist(engine) + min_freq; - record.time_ = time_dist(engine); + record.freq_power = freq_dist(engine) + min_freq; + record.time = time_dist(engine); records_.emplace_back(record); } } @@ -68,8 +68,9 @@ class RandomizeMixedLXUSet { void BM_MixedLFULRUStrategyEvict(benchmark::State& state) { RandomizeMixedLXUSet lxuSet(state.range(0), state.range(1), state.range(2)); + MixedLFULRUStrategy strategy; for (auto _ : state) { - MixedLFULRUStrategy::evict(lxuSet.Iterator(), state.range(3)); + strategy.evict(lxuSet.Iterator(), state.range(3)); } } diff --git a/benchmarks/cpp/dynamic_embedding/naive_id_transformer_benchmark.cpp b/benchmarks/cpp/dynamic_embedding/naive_id_transformer_benchmark.cpp index 17faadb67..ae66fbfcc 100644 --- a/benchmarks/cpp/dynamic_embedding/naive_id_transformer_benchmark.cpp +++ b/benchmarks/cpp/dynamic_embedding/naive_id_transformer_benchmark.cpp @@ -13,8 +13,7 @@ namespace torchrec { static void BM_NaiveIDTransformer(benchmark::State& state) { - using Tag = int32_t; - NaiveIDTransformer transformer(2e8); + NaiveIDTransformer transformer(2e8); torch::Tensor global_ids = torch::empty({1024, 1024}, torch::kLong); torch::Tensor cache_ids = torch::empty_like(global_ids); for (auto _ : state) { diff --git a/benchmarks/ebc_benchmarks.py b/benchmarks/ebc_benchmarks.py index 082f05bc0..98a0ffcbd 100644 --- a/benchmarks/ebc_benchmarks.py +++ b/benchmarks/ebc_benchmarks.py @@ -5,13 +5,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import argparse import sys from typing import List, Tuple +# pyre-fixme[21]: Could not find module `ebc_benchmarks_utils`. 
+import ebc_benchmarks_utils import torch -from fbgemm_gpu.split_table_batched_embeddings_ops import EmbeddingLocation -from torchrec.github.benchmarks import ebc_benchmarks_utils +from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.modules.fused_embedding_modules import FusedEmbeddingBagCollection @@ -251,5 +254,9 @@ def parse_args(argv: List[str]) -> argparse.Namespace: return parser.parse_args(argv) -if __name__ == "__main__": +def invoke_main() -> None: main(sys.argv[1:]) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/benchmarks/ebc_benchmarks_utils.py b/benchmarks/ebc_benchmarks_utils.py index 74f7f67d8..b15ec2b66 100644 --- a/benchmarks/ebc_benchmarks_utils.py +++ b/benchmarks/ebc_benchmarks_utils.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import time from typing import Dict, List, Optional, Tuple diff --git a/contrib/dynamic_embedding/README.md b/contrib/dynamic_embedding/README.md new file mode 100644 index 000000000..975963f35 --- /dev/null +++ b/contrib/dynamic_embedding/README.md @@ -0,0 +1,80 @@ +# TorchRec Dynamic Embedding + +This folder contains the extension to support dynamic embedding for torchrec. Specifically, this extension enable torchrec to attach an external PS, so that when local GPU embedding is not big enough, we could pull/evict embeddings from/to the PS. + +## Installation + +After install torchrec, please clone the torchrec repo and manually install the dynamic embedding: + +```bash +git clone git@github.com:pytorch/torchrec.git +cd contrib/dynamic_embedding +python setup.py install +``` + +And the dynamic embedding will be installed as a separate package named `torchrec_dynamic_embedding`. + +Notice that for C++20 supports we recommend gcc version higher or equal to 10. Conda users could install the lastest gcc utilities with: + +```bash +conda install gxx_linux-64 +``` + +We incorporate `gtest` for the C++ code and use unittest for the python APIs. The tests make sure that the implementation does not have any precision loss. Please turn on the `TDE_WITH_TESTING` in `setup.py` to run tests. Note that for the python test, one needs to set the environment variable `TDE_MEMORY_IO_PATH` to the path of the compiled `memory_io.so`. + +## Usage + +The dynamic embedding extension has only one api, `tde.wrap`, when wrapping the dataloader and model with it, we will automatically pipeline the data processing and model training. And example of `tde.wrap` is: + +```python +import torchrec_dynamic_embedding as tde + +class Model(nn.Module): + def __init__(self, config1, config2): + super().__init__() + self.emb1 = EmbeddingCollection(tables=config1, device=torch.device("meta")) + self.emb2 = EmbeddingCollection(tables=config2, device=torch.device("meta")) + ... + + def forward(self, kjt1, kjt2): + ... + +m = Model(config1, config2) +m = DistributedModelParallel(m) +dataloader = tde.wrap( + "redis://127.0.0.1:6379/?prefix=model", + dataloader, + m, + # configs of the embedding collections in the model + { "emb1": config1, "emb2": config2 }) + +for label, kjt1, kjt2 in dataloader: + output = m(kjt1, kjt2) + ... 
+``` + +The internal of `tde.wrap` is in `src/torchrec_dynamic_embedding/dataloader.py`, where we will attach hooks to the embedding tensor as well as creating the dataloader thread for pipelining. + +## Custom PS Extension + +The dynamic embedding extension supports connecting with your PS cluster. To write your own PS extension, you need to create an dynamic library (`*.so`) with these 4 functions and 1 variable: + +```c++ +const char* IO_type = "your-ps"; + +void* IO_Initialize(const char* cfg); + +void IO_Finalize(void* instance); + +void IO_Pull(void* instance, IOPullParameter cfg); + +void IO_Push(void* instance, IOPushParameter cfg); +``` + +And then use the following python API to register it: + +```python +torch.ops.tde.register_io(so_path) +``` + +After that, you could use your own PS extension by passing the corresponding URL into `tde.wrap`, where the protocol name would be the `IO_type` and the string after `"://"` will be passed to `IO_Finalize` (`"type://cfg"`). diff --git a/contrib/dynamic_embedding/setup.py b/contrib/dynamic_embedding/setup.py index b4a6e968d..c5e269200 100644 --- a/contrib/dynamic_embedding/setup.py +++ b/contrib/dynamic_embedding/setup.py @@ -1,7 +1,16 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + import os import sys import torch +from setuptools import find_packages from skbuild import setup @@ -40,7 +49,7 @@ setup( name="torchrec_dynamic_embedding", package_dir={"": "src"}, - packages=["torchrec_dynamic_embedding"], + packages=find_packages("src"), cmake_args=[ "-DCMAKE_BUILD_TYPE=Release", f"-DTDE_TORCH_BASE_DIR={os.path.dirname(torch.__file__)}", diff --git a/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer.h b/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer.h index 6d6a328f3..bada56602 100644 --- a/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer.h +++ b/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer.h @@ -1,3 +1,11 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + #pragma once #include #include @@ -30,7 +38,7 @@ class CachelineIDTransformerIterator { continue; } TransformerRecord result{}; - result.global_id_ = -record.global_id_not_; + result.global_id_ = ~record.global_id_not_; result.cache_id_ = record.cache_id_; result.lxu_record_ = record.lxu_record_; return result; diff --git a/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer_impl.h b/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer_impl.h index 5cf36a3b7..6e28fe170 100644 --- a/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer_impl.h +++ b/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer_impl.h @@ -1,3 +1,11 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + #pragma once #include #include @@ -64,38 +72,40 @@ inline bool CachelineIDTransformer< int64_t global_id_not = ~global_id; CacheValue* group_begin = &cache_values_[group_id * group_size_]; int64_t k = 0; + int64_t empty_slot = -1; + int64_t cache_id = -1; for (; k < group_size_; k++, intra_id++) { intra_id %= group_size_; auto& cache_value = group_begin[intra_id]; // tricky but fast :p int64_t xor_value = cache_value.global_id_not_ ^ global_id_not; - if (xor_value > 0) { - continue; - } - int64_t cache_id; if (xor_value == 0) { // found cache_id = cache_value.cache_id_; cache_value.lxu_record_ = update(cache_value.lxu_record_, global_id, cache_id); - } else { // empty slot + break; + } else if (xor_value < 0 && empty_slot < 0) { // empty slot + empty_slot = intra_id; + } + } + if (cache_id < 0) { + if (empty_slot >= 0) { // The transformer is full. if (C10_UNLIKELY(bitmap_.Full())) { return false; } + auto& cache_value = group_begin[empty_slot]; cache_id = bitmap_.NextFreeBit(); cache_value.global_id_not_ = global_id_not; cache_value.cache_id_ = cache_id; cache_value.lxu_record_ = update(std::nullopt, global_id, cache_id); fetch(global_id, cache_id); + } else { + return false; } - cache_ids[i] = cache_id; - break; - } - - if (k == group_size_) { - return false; } + cache_ids[i] = cache_id; } return true; } @@ -115,14 +125,14 @@ inline void CachelineIDTransformer< for (const int64_t global_id : global_ids) { auto [group_id, intra_id] = FindGroupIndex(global_id); + int64_t global_id_not = ~global_id; for (int64_t k = 0; k < group_size_; k++) { int64_t offset = group_id * group_size_ + (intra_id + k) % group_size_; auto& cache_value = cache_values_[offset]; // tricky but fast :p - int64_t global_id_not = ~global_id; int64_t xor_value = global_id_not ^ cache_value.global_id_not_; if (xor_value < 0) { // not exist - break; + continue; } else if (xor_value == 0) { // found slot bitmap_.FreeBit(cache_value.cache_id_); cache_value.global_id_not_ = 0; diff --git a/contrib/dynamic_embedding/src/tde/ps.cpp b/contrib/dynamic_embedding/src/tde/ps.cpp index 4e3bdecb9..e9fcb3c25 100644 --- a/contrib/dynamic_embedding/src/tde/ps.cpp +++ b/contrib/dynamic_embedding/src/tde/ps.cpp @@ -1,3 +1,11 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + #include "tde/ps.h" #include "tde/details/io.h" @@ -17,9 +25,13 @@ c10::intrusive_ptr PS::Fetch( if (cache_ids_to_fetch_or_evict_.empty()) { return c10::make_intrusive(time, c10::intrusive_ptr()); } - fetch_notifications_.emplace_back(time, c10::make_intrusive()); - c10::intrusive_ptr notification = - fetch_notifications_.back().second; + c10::intrusive_ptr notification; + { + std::unique_lock lock_fetch(fetch_notifications_mutex_); + fetch_notifications_.emplace_back( + time, c10::make_intrusive()); + notification = fetch_notifications_.back().second; + } uint32_t num_os_ids = os_ids_.size(); io_.Pull( table_name_, @@ -97,9 +109,7 @@ void PS::Evict(torch::Tensor ids_to_evict) { notification.Done(); // The shared data for all chunks. 
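 // Each pass through the loop below packs one chunk of evicted rows into a
 // contiguous buffer, waits for the previous chunk's push to complete, and then
 // pushes the chunk to the PS asynchronously.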
std::vector offsets; - offsets.reserve(num_ids_per_chunk_ * num_os_ids * col_ids.size() + 1); - std::vector data( - num_ids_per_chunk_ * num_os_ids * col_ids.size() * col_size_); + offsets.resize(num_ids_per_chunk_ * num_os_ids * col_ids.size() + 1); for (uint32_t i = 0; i < num_ids_to_fetch; i += num_ids_per_chunk_) { uint32_t num_ids_in_chunk = std::min( @@ -107,22 +117,22 @@ void PS::Evict(torch::Tensor ids_to_evict) { uint32_t data_size = num_ids_in_chunk * num_os_ids * col_ids.size(); uint32_t offsets_size = num_ids_in_chunk * num_os_ids * col_ids.size() + 1; - offsets.clear(); - offsets.emplace_back(0); + std::vector all_tensors; for (uint32_t j = i; j < i + num_ids_in_chunk; ++j) { int64_t cache_id = cache_ids_to_fetch_or_evict_[j]; std::vector tensors = GetTensorViews(cache_id); - for (uint32_t k : os_ids_) { - // this cause 2 copy. is this avoidable? - torch::Tensor tensor = tensors[k].cpu(); - // need to change this when considering col - memcpy( - reinterpret_cast(data.data()) + offsets.back(), - tensor.data_ptr(), - tensor.numel() * tensor.element_size()); - offsets.emplace_back( - offsets.back() + tensor.numel() * tensor.element_size()); - } + all_tensors.insert(all_tensors.end(), tensors.begin(), tensors.end()); + } + torch::Tensor data = torch::cat(all_tensors, 0).cpu(); + TORCH_CHECK(data.numel() == data_size * col_size_); + + // to prevent the original data from being prematurely recycled + auto data_shared_ptr = std::make_shared(data); + + offsets[0] = 0; + for (uint32_t j = 0; j < all_tensors.size(); ++j) { + offsets[j + 1] = + offsets[j] + all_tensors[j].numel() * all_tensors[j].element_size(); } // waiting for the Push of last chunk finishes. notification.Wait(); @@ -133,21 +143,30 @@ void PS::Evict(torch::Tensor ids_to_evict) { col_ids, os_ids_, tcb::span{ - reinterpret_cast(data.data()), data_size * sizeof(float)}, + reinterpret_cast(data_shared_ptr->data_ptr()), + data_size * sizeof(float)}, tcb::span{offsets.data(), offsets_size}, - [¬ification] { notification.Done(); }); + [¬ification, data_shared_ptr] { notification.Done(); }); } notification.Wait(); } void PS::SyncFetch(int64_t time) { - while (!fetch_notifications_.empty()) { - auto& [t, notification] = fetch_notifications_.front(); - if (t != time && time >= 0) { + std::unique_lock lock( + fetch_notifications_mutex_, std::defer_lock); + + while (true) { + lock.lock(); + if (fetch_notifications_.empty() || + fetch_notifications_.front().first != time && time >= 0) { + lock.unlock(); break; } - notification->Wait(); + auto notification = fetch_notifications_.front().second; fetch_notifications_.pop_front(); + lock.unlock(); + + notification->Wait(); } } diff --git a/contrib/dynamic_embedding/src/tde/ps.h b/contrib/dynamic_embedding/src/tde/ps.h index aed24fdc6..cad040f4c 100644 --- a/contrib/dynamic_embedding/src/tde/ps.h +++ b/contrib/dynamic_embedding/src/tde/ps.h @@ -1,3 +1,11 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + #pragma once #include #include @@ -99,6 +107,7 @@ class PS : public torch::CustomClassHolder { void Filter(const torch::Tensor& tensor); std::mutex mu_; + std::mutex fetch_notifications_mutex_; std::string table_name_; c10::intrusive_ptr shards_; int64_t col_size_; diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/dataloader.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/dataloader.py index 1a7645fd3..96911a449 100644 --- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/dataloader.py +++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/dataloader.py @@ -1,3 +1,10 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import queue import threading from typing import Dict, List, Union @@ -34,6 +41,9 @@ def transform_loop(dataloader, transform_fn, out_queue, done_event): # save memory del transformed_data + if not done_event.is_set(): + done_event.set() + class DataLoaderIter: def __init__(self, dataloader, transform_fn, num_prefetch=0): @@ -55,6 +65,8 @@ def __del__(self): self._done_event.set() def _get_data(self): + if self._done_event.is_set(): + raise StopIteration if not self._transform_thread.is_alive(): raise RuntimeError("Transform thread exited unexpectedly") data, handles = self._data_queue.get() diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/__init__.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/__init__.py index cdf9a93a9..594213334 100644 --- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/__init__.py +++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/__init__.py @@ -1 +1,15 @@ -from .comm import default_group, gather_kjts +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python3 + +from .comm import ( + broadcast_ids_to_evict, + broadcast_transform_result, + gather_global_ids, + scatter_cache_ids, +) diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/comm.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/comm.py index b6da3be0a..f15f3ce18 100644 --- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/comm.py +++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/comm.py @@ -1,127 +1,98 @@ -from dataclasses import dataclass +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
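+# Collective helpers used by IDTransformerCollection when world_size > 1:
+# gather_global_ids collects each rank's global ids onto rank 0 (which owns the
+# id transformer), scatter_cache_ids sends the mapped cache ids back to every
+# rank, and the broadcast_* helpers mirror the transform/eviction results from
+# rank 0 to the other ranks.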
+ from typing import List, Optional import torch import torch.distributed as dist -from torchrec import KeyedJaggedTensor - __all__ = [] -_group = [None] - - -@dataclass -class GatherResult: - values_list: List[List[torch.Tensor]] - offset_per_key_list: List[List[torch.Tensor]] - - -def default_group(group): - if group is not None: - return group - - if _group[0] is None: - _group[0] = dist.new_group(backend="gloo") - - return _group[0] - -def gather_tensor_list( - tensors: List[torch.Tensor], numel_lists: List[List[int]], dst=0, group=None -) -> Optional[List[List[torch.Tensor]]]: - rank = dist.get_rank() +def gather_global_ids(global_ids: List[torch.Tensor], group): world_size = dist.get_world_size() - if dst == rank: - if len(numel_lists) != world_size: - raise ValueError("dst rank should know size of tensors on all ranks") - else: - if len(numel_lists) != 1: - raise ValueError("non dst rank should pass its own tensor sizes") - - group = default_group(group) - dtype = tensors[0].dtype - device = tensors[0].device + rank = dist.get_rank() - concated_tensor = torch.cat(tensors) + concat_global_ids = torch.cat(global_ids) - if rank == dst: - # gather can only accept same-size tensors. - max_numel = max(sum(numel_list) for numel_list in numel_lists) - concat_results = [ - torch.empty(max_numel, dtype=dtype, device=device) - for _ in range(world_size) - ] - dist.gather( - concated_tensor, - gather_list=concat_results, - dst=dst, - group=group, - async_op=False, - ) + concat_numel = torch.tensor(concat_global_ids.numel(), dtype=torch.int64) + concat_numel_list = [torch.tensor(0, dtype=torch.int64) for _ in range(world_size)] + dist.all_gather(concat_numel_list, concat_numel, group=group, async_op=False) - results = [] - for i in range(world_size): - splited_tensors = [] - offset = 0 - for numel in numel_lists[i]: - splited_tensors.append(concat_results[i][offset : offset + numel]) - offset += numel - results.append(splited_tensors) + max_numel = max(concat_numel_list) + concat_global_ids.resize_(max_numel) - return results + if rank == 0: + concat_global_ids_list = [ + torch.empty_like(concat_global_ids) for _ in range(world_size) + ] + dist.gather(concat_global_ids, concat_global_ids_list, 0, group, async_op=False) + return [ + concat_global_ids_list[i][: concat_numel_list[i]] for i in range(world_size) + ], concat_numel_list else: - dist.gather( - concated_tensor, gather_list=None, dst=dst, group=group, async_op=False - ) - - return None + dist.gather(concat_global_ids, None, 0, group, async_op=False) + return None, concat_numel_list -def gather_kjts(kjts: List[KeyedJaggedTensor], dst=0, group=None) -> GatherResult: +def scatter_cache_ids( + cache_ids_list: Optional[List[torch.Tensor]], concat_numel_list: List[int], group +): world_size = dist.get_world_size() - if world_size == 1: - return GatherResult( - values_list=[[kjt.values()] for kjt in kjts], - offset_per_key_list=[[kjt.offset_per_key()] for kjt in kjts], - ) - rank = dist.get_rank() - group = default_group(group) - offset_per_key_list = [torch.tensor(kjt.offset_per_key()) for kjt in kjts] - values_list = [kjt.values() for kjt in kjts] + max_numel = max(concat_numel_list) - offset_numel_list = [tensor.numel() for tensor in offset_per_key_list] - values_numel_list = [tensor.numel() for tensor in values_list] - - if rank == dst: - global_offset_numel_list = [offset_numel_list] * world_size - offset_results = gather_tensor_list( - offset_per_key_list, - numel_lists=global_offset_numel_list, - dst=dst, - group=group, - ) - - 
global_values_numel_list = [ - [offset[-1].item() for offset in offsets] for offsets in offset_results + concat_cache_ids = torch.empty(max_numel, dtype=torch.int64) + if rank == 0: + concat_cache_ids_list = [concat_cache_ids] + [ + cache_ids.resize_(max_numel) + for cache_ids in cache_ids_list[-world_size + 1 :] ] - - values_result = gather_tensor_list( - values_list, numel_lists=global_values_numel_list, dst=dst, group=group - ) - - return GatherResult( - values_list=values_result, offset_per_key_list=offset_results - ) + assert len(concat_cache_ids_list) == world_size + dist.scatter(concat_cache_ids, concat_cache_ids_list, group=group) else: - gather_tensor_list( - offset_per_key_list, numel_lists=[offset_numel_list], dst=dst, group=group + dist.scatter(concat_cache_ids, None, group=group) + offset = 0 + for cache_ids in cache_ids_list: + cache_ids[:] = concat_cache_ids[offset : offset + cache_ids.numel()] + offset += cache_ids.numel() + + +def broadcast_transform_result( + success: bool, ids_to_fetch: Optional[torch.Tensor], group +): + if dist.get_rank() == 0: + success_and_numel = torch.tensor( + [1 if success else 0, ids_to_fetch.numel()], dtype=torch.int64 ) + dist.broadcast(success_and_numel, src=0, group=group) + else: + success_and_numel = torch.tensor([0, 0], dtype=torch.int64) + dist.broadcast(success_and_numel, src=0, group=group) + success, numel = success_and_numel.tolist() + success = success != 0 + ids_to_fetch = torch.empty((numel // 2, 2), dtype=torch.int64) + + if ids_to_fetch.numel() > 0: + dist.broadcast(ids_to_fetch, src=0, group=group) + return success, ids_to_fetch - gather_tensor_list( - values_list, numel_lists=[values_numel_list], dst=dst, group=group - ) - return None +def broadcast_ids_to_evict(ids, group): + if dist.get_rank() == 0: + numel = torch.tensor(ids.numel(), dtype=torch.int64) + dist.broadcast(numel, src=0, group=group) + else: + numel = torch.tensor(0, dtype=torch.int64) + dist.broadcast(numel, src=0, group=group) + numel = numel.item() + ids = torch.empty((numel // 2, 2), dtype=torch.int64) + + if numel > 0: + dist.broadcast(ids, src=0, group=group) + return ids diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer.py index 6d7169183..735412d35 100644 --- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer.py +++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer.py @@ -1,3 +1,10 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import json import os @@ -33,9 +40,10 @@ def transform(self, global_ids: TensorList, cache_ids: TensorList, time: int): """ Transform `global_ids` and store the results in `cache_ids`. 
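+        Returns a tuple of ``(success, ids_to_fetch)``.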
""" - return self._transformer.transform( + result = self._transformer.transform( global_ids.tensor_list, cache_ids.tensor_list, time ) + return result.success, result.ids_to_fetch def evict(self, num_to_evict): """ diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_collection.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_collection.py index ced6780df..1132d6be6 100644 --- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_collection.py +++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_collection.py @@ -1,8 +1,24 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python3 + from typing import List, Tuple, Union import torch +import torch.distributed as dist from torchrec import EmbeddingBagConfig, EmbeddingConfig, KeyedJaggedTensor +from .distributed import ( + broadcast_ids_to_evict, + broadcast_transform_result, + gather_global_ids, + scatter_cache_ids, +) from .id_transformer import IDTransformer, TensorList from .ps import PSCollection @@ -27,7 +43,7 @@ def __init__( tables: list of `Embedding(Bag)Config` or `EmbeddingBagConfig` one passed to `Embedding(Bag)Collection`. eviction_config: config of the eviction strategy for IDTransformers. - transformer_config: config of the transform strategy for IDTransformers. + transform_config: config of the transform strategy for IDTransformers. ps_collection: `PSCollection` of the collection, if `None`, won't do eviction or fetch. By default, IDTransformerCollection will evict half the ids when full. """ @@ -46,19 +62,61 @@ def __init__( for feature_name in config.feature_names: if feature_name in feature_names: raise ValueError(f"Shared feature not allowed yet.") - self._transformers.append( - IDTransformer( + # only rank 0 will have the id transformer + # and other ranks will gather their to rank 0. 
+ if dist.get_rank() == 0: + transformer = IDTransformer( num_embedding=config.num_embeddings, eviction_config=eviction_config, transform_config=transform_config, ) - ) + else: + transformer = None + self._transformers.append(transformer) self._feature_names: List[List[str]] = [ config.feature_names for config in tables ] self._ever_evicted = False self._time = 0 + if dist.get_world_size() > 1: + self._pg = dist.new_group(backend="gloo") + self._stream = torch.cuda.Stream() + + def _transform( + self, transformer, global_ids: List[torch.Tensor], cache_ids: List[torch.Tensor] + ): + with torch.cuda.stream(self._stream): + total_numel = sum([tensor.numel() for tensor in global_ids]) + if total_numel > 1e6: + all_tensor = torch.cat(global_ids).to("cuda:0") + unique_all_tensor, index = torch.unique(all_tensor, return_inverse=True) + unique_all_tensor = unique_all_tensor.to("cpu") + all_cache = torch.empty_like(unique_all_tensor) + success, ids_to_fetch = transformer.transform( + TensorList([unique_all_tensor]), + TensorList([all_cache]), + self._time, + ) + del all_tensor + all_tensor = torch.take(all_cache.to("cuda:0"), index) + offset = 0 + for tensor in cache_ids: + numel = tensor.numel() + tensor.copy_(all_tensor[offset : offset + numel]) + offset += numel + assert ( + total_numel == offset + ), f"total_numel not equal offset, {total_numel} vs {offset}" + else: + # broadcast result + success, ids_to_fetch = transformer.transform( + TensorList(global_ids), + TensorList(cache_ids), + self._time, + ) + return success, ids_to_fetch + def transform( self, global_features: KeyedJaggedTensor ) -> Tuple[KeyedJaggedTensor, List[torch.classes.tde.FetchHandle]]: @@ -92,47 +150,128 @@ def transform( for idx in feature_indices ] - result = transformer.transform( - TensorList(global_ids), TensorList(cache_ids), self._time - ) - if self._ps_collection is not None: - table_name = self._table_names[i] - ps = self._ps_collection[table_name] - if result.ids_to_fetch.numel() > 0: - handle = ps.fetch( - result.ids_to_fetch, - self._time, - self._ever_evicted, - self._configs[i].get_weight_init_min(), - self._configs[i].get_weight_init_max(), - ) - fetch_handles.append(handle) - if not result.success: - # TODO(zilinzhu): make this configurable - ids_to_evict = transformer.evict(transformer._num_embedding // 2) - ps.evict(ids_to_evict) - self._ever_evicted = True - - # retry after eviction. - result = transformer.transform( - TensorList(global_ids), TensorList(cache_ids), self._time + if dist.get_world_size() > 1: + concat_global_ids, concat_numel_list = gather_global_ids( + global_ids, self._pg + ) + if dist.get_rank() == 0: + global_ids = global_ids + concat_global_ids[1:] + cache_ids = cache_ids + [ + torch.empty_like(tensor) for tensor in concat_global_ids[1:] + ] + + success, ids_to_fetch = self._transform( + transformer, global_ids, cache_ids ) - if not result.success: - raise RuntimeError( - "Failed to transform global ids after eviction. " - f"Maybe the num_embedding of table {table_name} is too small?" 
+ else: + success, ids_to_fetch = True, None + success, ids_to_fetch = broadcast_transform_result( + success, ids_to_fetch, self._pg + ) + + if self._ps_collection is not None: + table_name = self._table_names[i] + ps = self._ps_collection[table_name] + if ids_to_fetch.numel() > 0: + handle = ps.fetch( + ids_to_fetch, + self._time, + self._ever_evicted, + self._configs[i].get_weight_init_min(), + self._configs[i].get_weight_init_max(), ) - if result.ids_to_fetch is not None: - fetch_handles.append( - ps.fetch( - result.ids_to_fetch, + fetch_handles.append(handle) + if not success: + # TODO(zilinzhu): make this configurable + # broadcast ids_to_evict + if dist.get_rank() == 0: + ids_to_evict = transformer.evict( + transformer._num_embedding // 2 + ) + else: + ids_to_evict = None + ids_to_evict = broadcast_ids_to_evict(ids_to_evict, self._pg) + + ps.evict(ids_to_evict) + self._ever_evicted = True + + # retry after eviction. + # broadcast result + if dist.get_rank() == 0: + success, ids_to_fetch = transformer.transform( + TensorList(global_ids), + TensorList(cache_ids), self._time, - self._ever_evicted, - self._configs[i].get_weight_init_min(), - self._configs[i].get_weight_init_max(), ) + else: + success, ids_to_fetch = True, None + success, ids_to_fetch = broadcast_transform_result( + success, ids_to_fetch, self._pg ) + if not success: + raise RuntimeError( + "Failed to transform global ids after eviction. " + f"Maybe the num_embedding of table {table_name} is too small?" + ) + if ids_to_fetch.numel() > 0: + fetch_handles.append( + ps.fetch( + ids_to_fetch, + self._time, + self._ever_evicted, + self._configs[i].get_weight_init_min(), + self._configs[i].get_weight_init_max(), + ) + ) + + scatter_cache_ids(cache_ids, concat_numel_list, self._pg) + else: + success, ids_to_fetch = self._transform( + transformer, global_ids, cache_ids + ) + if self._ps_collection is not None: + table_name = self._table_names[i] + ps = self._ps_collection[table_name] + if ids_to_fetch.numel() > 0: + handle = ps.fetch( + ids_to_fetch, + self._time, + self._ever_evicted, + self._configs[i].get_weight_init_min(), + self._configs[i].get_weight_init_max(), + ) + fetch_handles.append(handle) + if not success: + # TODO(zilinzhu): make this configurable + ids_to_evict = transformer.evict( + transformer._num_embedding // 2 + ) + ps.evict(ids_to_evict) + self._ever_evicted = True + + # retry after eviction. + success, ids_to_fetch = transformer.transform( + TensorList(global_ids), + TensorList(cache_ids), + self._time, + ) + if not success: + raise RuntimeError( + "Failed to transform global ids after eviction. " + f"Maybe the num_embedding of table {table_name} is too small?" 
+ ) + if ids_to_fetch is not None: + fetch_handles.append( + ps.fetch( + ids_to_fetch, + self._time, + self._ever_evicted, + self._configs[i].get_weight_init_min(), + self._configs[i].get_weight_init_max(), + ) + ) + cache_values = KeyedJaggedTensor( keys=global_features.keys(), values=cache_values, @@ -147,5 +286,18 @@ def save(self): return for i, transformer in enumerate(self._transformers): table_name = self._table_names[i] - ids = transformer.save() + if dist.get_world_size() > 1: + if dist.get_rank() == 0: + ids = transformer.save() + numel = torch.tensor(ids.numel()) + dist.broadcast(numel, src=0, group=self._pg) + dist.broadcast(ids, src=0, group=self._pg) + else: + numel = torch.tensor(0) + dist.broadcast(numel, src=0, group=self._pg) + ids = torch.empty((numel // 2, 2), dtype=torch.int64) + dist.broadcast(ids, src=0, group=self._pg) + else: + ids = transformer.save() + self._ps_collection[table_name].evict(ids) diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_group.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_group.py index c26b2ad4f..9c888a15f 100644 --- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_group.py +++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_group.py @@ -57,8 +57,8 @@ def __init__( configs or embeddingbag configs. The plan of `module` should contain the module path in `configs_dict`. eviction_config: configuration for eviction policy. Default is `{"type": "mixed_lru_lfu"}` - transformer_config: configuration for the transformer. Default is `{"type": "naive"}` - parallel: Whether the IDTransformerCollections will run paralell. When set to True, + transform_config: configuration for the transformer. Default is `{"type": "naive"}` + parallel: Whether the IDTransformerCollections will run parallel. When set to True, IDTransformerGroup will start a thread for each IDTransformerCollection. Example: @@ -166,12 +166,6 @@ def __contains__(self, path): """ return path in self._id_transformer_collections - def __contains__(self, path): - """ - Check if there is transformer for the path. - """ - return path in self._id_transformer_collections - def __del__(self): """ Stop the parallel threads diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/ps.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/ps.py index 690e89b97..763f3e8ef 100644 --- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/ps.py +++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/ps.py @@ -1,10 +1,15 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import os from typing import Callable, Dict, List, Tuple, Union import torch -import torch.nn as nn from torch.distributed._shard.sharded_tensor import ShardedTensor -from torchrec.distributed.model_parallel import DistributedModelParallel as DMP from torchrec.distributed.types import ParameterSharding from .tensor_list import TensorList @@ -54,7 +59,6 @@ def __init__( # This assumes all shard have the same column size. 
col_size = shard.tensor.shape[1] elif isinstance(tensors[0], torch.Tensor): - tensors shards.append( 0, 0, diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/tensor_list.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/tensor_list.py index fc11cce9a..591eda4c3 100644 --- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/tensor_list.py +++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/tensor_list.py @@ -1,4 +1,9 @@ -import json +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import os from typing import List diff --git a/contrib/dynamic_embedding/tests/test_integral_precision.py b/contrib/dynamic_embedding/tests/test_integral_precision.py index 842a2b72d..965bb092c 100644 --- a/contrib/dynamic_embedding/tests/test_integral_precision.py +++ b/contrib/dynamic_embedding/tests/test_integral_precision.py @@ -1,3 +1,10 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import unittest import torch @@ -14,6 +21,7 @@ from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper +from torchrec.optim.optimizers import in_backward_optimizer_filter from torchrec_dynamic_embedding.id_transformer_group import IDTransformerGroup from utils import init_dist, register_memory_io @@ -93,7 +101,7 @@ def get_dmp(model): model = DMP(module=model, device=device, plan=plan, sharders=sharders) dense_optimizer = KeyedOptimizerWrapper( - dict(model.named_parameters()), + dict(in_backward_optimizer_filter(model.named_parameters())), lambda params: torch.optim.Adam(params, lr=1e-1), ) optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer]) diff --git a/contrib/dynamic_embedding/tests/test_ps_collection.py b/contrib/dynamic_embedding/tests/test_ps_collection.py index a7fffe1e0..d293749e1 100644 --- a/contrib/dynamic_embedding/tests/test_ps_collection.py +++ b/contrib/dynamic_embedding/tests/test_ps_collection.py @@ -1,4 +1,9 @@ -import os +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import unittest import torch diff --git a/contrib/dynamic_embedding/tools/before_linux_build.sh b/contrib/dynamic_embedding/tools/before_linux_build.sh index 7cf1d154f..527ee2128 100755 --- a/contrib/dynamic_embedding/tools/before_linux_build.sh +++ b/contrib/dynamic_embedding/tools/before_linux_build.sh @@ -1,13 +1,20 @@ #!/usr/bin/env bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + set -xe distro=rhel7 arch=x86_64 -CUDA_VERSION="${CUDA_VERSION:-11.6}" +CUDA_VERSION="${CUDA_VERSION:-11.8}" CUDA_MAJOR_VERSION=$(echo "${CUDA_VERSION}" | tr '.' ' ' | awk '{print $1}') CUDA_MINOR_VERSION=$(echo "${CUDA_VERSION}" | tr '.' 
' ' | awk '{print $2}') +yum install -y yum-utils yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$distro/$arch/cuda-$distro.repo yum install -y \ cuda-toolkit-"${CUDA_MAJOR_VERSION}"-"${CUDA_MINOR_VERSION}" \ diff --git a/contrib/dynamic_embedding/tools/build_wheels.sh b/contrib/dynamic_embedding/tools/build_wheels.sh index 51cd42500..3647547de 100755 --- a/contrib/dynamic_embedding/tools/build_wheels.sh +++ b/contrib/dynamic_embedding/tools/build_wheels.sh @@ -6,6 +6,8 @@ export CIBW_BEFORE_BUILD="tools/before_linux_build.sh" # all kinds of CPython. export CIBW_BUILD=${CIBW_BUILD:-"cp39-manylinux_x86_64"} +export CIBW_MANYLINUX_X86_64_IMAGE=${CIBW_MANYLINUX_X86_64_IMAGE:-"manylinux_2_28"} + # Do not auditwheels since tde uses torch's shared libraries. export CIBW_REPAIR_WHEEL_COMMAND="tools/repair_wheel.sh {wheel} {dest_dir}" diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000..fb00038dc --- /dev/null +++ b/docs/README.md @@ -0,0 +1,23 @@ +Docs +========== + + +## Building the docs + +To build and preview the docs run the following commands: + +```bash +cd docs +pip3 install -r requirements.txt +make html +python3 -m http.server 8082 --bind :: +``` + +Now you should be able to view the docs in your browser at the link provided in your terminal. + +To reload the preview after making changes, rerun: + +```bash +make html +python3 -m http.server 8082 --bind :: +``` diff --git a/docs/requirements.txt b/docs/requirements.txt index 851694777..96d50ad98 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,7 @@ -sphinx +sphinx==5.0.0 +pyre-extensions +sphinx-design +sphinx_copybutton # torch # PyTorch Theme -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css new file mode 100644 index 000000000..55fee00b4 --- /dev/null +++ b/docs/source/_static/css/custom.css @@ -0,0 +1,114 @@ +/** +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. 
+*/ + +/* sphinx-design styles for cards/tabs */ + +:root { + --sd-color-info: #ee4c2c; + --sd-color-info-highlight: #ee4c2c; + --sd-color-primary: #6c6c6d; + --sd-color-primary-highligt: #f3f4f7; + --sd-color-card-border-hover: #ee4c2c; + --sd-color-card-border: #f3f4f7; + --sd-color-card-background: #fff; + --sd-color-card-text: inherit; + --sd-color-card-header: transparent; + --sd-color-card-footer: transparent; + --sd-color-tabs-label-active: #ee4c2c; + --sd-color-tabs-label-hover: #ee4c2c; + --sd-color-tabs-label-inactive: #6c6c6d; + --sd-color-tabs-underline-active: #ee4c2c; + --sd-color-tabs-underline-hover: #fabdbd; + --sd-color-tabs-underline-inactive: transparent; + --sd-color-tabs-overline: rgb(222, 222, 222); + --sd-color-tabs-underline: rgb(222, 222, 222); +} + +.sd-text-info { + color: #ee4c2c; +} + +.sd-card-img-top { + background: #ee4c2c; + height: 5px !important; +} + +.sd-card { + position: relative; + background-color: #fff; + opacity: 1.0; + border-radius: 0px; + width: 30%; + border: none; + padding-bottom: 0px; +} + +.sd-card-img { + opacity: 0.5; + width: 200px; + padding: 0px; +} + +.sd-card-img:hover { + opacity: 1.0; + background-color: #f3f4f7; +} + + +.sd-card:after { + display: block; + opacity: 1; + content: ''; + border-bottom: solid 1px #ee4c2c; + background-color: #fff; + transform: scaleX(0); + transition: transform .250s ease-in-out; + transform-origin: 0% 50%; +} + +.sd-card:hover { + background-color: #fff; + opacity: 1; + border-top: 1px solid #f3f4f7; + border-left: 1px solid #f3f4f7; + border-right: 1px solid #f3f4f7; +} + +.sd-card:hover:after { + transform: scaleX(1); +} + +.card-prerequisites:hover { + transition: none; + border: none; +} + +.card-prerequisites:hover:after { + transition: none; + transform: none; +} + +.card-prerequisites:after { + display: block; + content: ''; + border-bottom: none; + background-color: #fff; + transform: none; + transition: none; + transform-origin: none; +} + +details.sd-dropdown { + font-weight: 300; + width: auto; +} + +.center-content { + display: flex; + justify-content: center; +} diff --git a/docs/source/_static/img/card-background.svg b/docs/source/_static/img/card-background.svg new file mode 100644 index 000000000..d97193223 --- /dev/null +++ b/docs/source/_static/img/card-background.svg @@ -0,0 +1,13 @@ + + + + + + + diff --git a/docs/source/_static/img/full_training_loop.png b/docs/source/_static/img/full_training_loop.png new file mode 100644 index 000000000..221e7c387 Binary files /dev/null and b/docs/source/_static/img/full_training_loop.png differ diff --git a/docs/source/_static/img/fused_backward_optimizer.png b/docs/source/_static/img/fused_backward_optimizer.png new file mode 100644 index 000000000..3c3d3593e Binary files /dev/null and b/docs/source/_static/img/fused_backward_optimizer.png differ diff --git a/docs/source/_static/img/fused_embedding_tables.png b/docs/source/_static/img/fused_embedding_tables.png new file mode 100644 index 000000000..6874e4d75 Binary files /dev/null and b/docs/source/_static/img/fused_embedding_tables.png differ diff --git a/docs/source/_static/img/model_parallel.png b/docs/source/_static/img/model_parallel.png new file mode 100644 index 000000000..81d18e9e1 Binary files /dev/null and b/docs/source/_static/img/model_parallel.png differ diff --git a/docs/source/_static/img/sharding.png b/docs/source/_static/img/sharding.png new file mode 100644 index 000000000..653eb7cdb Binary files /dev/null and b/docs/source/_static/img/sharding.png differ diff --git 
a/docs/source/_static/img/torchrec_forward.png b/docs/source/_static/img/torchrec_forward.png new file mode 100644 index 000000000..156f13e8f Binary files /dev/null and b/docs/source/_static/img/torchrec_forward.png differ diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html new file mode 100644 index 000000000..dbc08744a --- /dev/null +++ b/docs/source/_templates/layout.html @@ -0,0 +1,10 @@ +{% extends "!layout.html" %} + +{% block footer %} +{{ super() }} + + + +{% endblock %} diff --git a/docs/source/concepts.rst b/docs/source/concepts.rst new file mode 100644 index 000000000..b1d7f3776 --- /dev/null +++ b/docs/source/concepts.rst @@ -0,0 +1,311 @@ +.. meta:: + :description: TorchRec Concepts + :keywords: recommendation systems, sharding, distributed training, torchrec, embedding bags, embeddings, keyedjaggedtensor, row wise, table wise, column wise, table row wise, planner, sharder + +################### +TorchRec Concepts +################### + +In this section, we will learn about the key concepts of TorchRec, +designed to optimize large-scale recommendation systems using PyTorch. +We will learn how each concept works in detail and review how it is used +with the rest of TorchRec. + +TorchRec has specific input/output data types of its modules to +efficiently represent sparse features, including: + +- **JaggedTensor:** a wrapper around the lengths/offsets and values + tensors for a singular sparse feature. +- **KeyedJaggedTensor:** efficiently represent multiple sparse + features, can think of it as multiple ``JaggedTensor``\s. +- **KeyedTensor:** a wrapper around ``torch.Tensor`` that allows access + to tensor values through keys. + +With the goal of high performance and efficiency, the canonical +``torch.Tensor`` is highly inefficient for representing sparse data. +TorchRec introduces these new data types because they provide efficient +storage and representation of sparse input data. As you will see later +on, the ``KeyedJaggedTensor`` makes communication of input data in a +distributed environment very efficient leading to one of the key +performance advantages that TorchRec provides. + +In the end-to-end training loop, TorchRec comprises of the following +main components: + +- **Planner:** Takes in the configuration of embedding tables, + environment setup, and generates an optimized sharding plan for the + model. + +- **Sharder:** Shards model according to sharding plan with different + sharding strategies including data-parallel, table-wise, row-wise, + table-wise-row-wise, column-wise, and table-wise-column-wise + sharding. + +- **DistributedModelParallel:** Combines sharder, optimizer, and + provides an entry point into the training the model in a distributed + manner. + +************** +JaggedTensor +************** + +A ``JaggedTensor`` represents a sparse feature through lengths, values, +and offsets. It is called "jagged" because it efficiently represents +data with variable-length sequences. In contrast, a canonical +``torch.Tensor`` assumes that each sequence has the same length, which +is often not the case with real world data. A ``JaggedTensor`` +facilitates the representation of such data without padding making it +highly efficient. + +Key Components: + +- ``Lengths``: A list of integers representing the number of elements + for each entity. + +- ``Offsets``: A list of integers representing the starting index of + each sequence in the flattened values tensor. These provide an + alternative to lengths. 
+ +- ``Values``: A 1D tensor containing the actual values for each entity, + stored contiguously. + +Here is a simple example demonstrating how each of the components would +look like: + +.. code:: python + + # User interactions: + # - User 1 interacted with 2 items + # - User 2 interacted with 3 items + # - User 3 interacted with 1 item + lengths = [2, 3, 1] + offsets = [0, 2, 5] # Starting index of each user's interactions + values = torch.Tensor([101, 102, 201, 202, 203, 301]) # Item IDs interacted with + jt = JaggedTensor(lengths=lengths, values=values) + # OR + jt = JaggedTensor(offsets=offsets, values=values) + +******************* +KeyedJaggedTensor +******************* + +A ``KeyedJaggedTensor`` extends the functionality of ``JaggedTensor`` by +introducing keys (which are typically feature names) to label different +groups of features, for example, user features and item features. This +is the data type used in ``forward`` of ``EmbeddingBagCollection`` and +``EmbeddingCollection`` as they are used to represent multiple features +in a table. + +A ``KeyedJaggedTensor`` has an implied batch size, which is the number +of features divided by the length of ``lengths`` tensor. The example +below has a batch size of 2. Similar to a ``JaggedTensor``, the +``offsets`` and ``lengths`` function in the same manner. You can also +access the ``lengths``, ``offsets``, and ``values`` of a feature by +accessing the key from the ``KeyedJaggedTensor``. + +.. code:: python + + keys = ["user_features", "item_features"] + # Lengths of interactions: + # - User features: 2 users, with 2 and 3 interactions respectively + # - Item features: 2 items, with 1 and 2 interactions respectively + lengths = [2, 3, 1, 2] + values = torch.Tensor([11, 12, 21, 22, 23, 101, 102, 201]) + # Create a KeyedJaggedTensor + kjt = KeyedJaggedTensor(keys=keys, lengths=lengths, values=values) + # Access the features by key + print(kjt["user_features"]) + # Outputs user features + print(kjt["item_features"]) + +********* +Planner +********* + +The TorchRec planner helps determine the best sharding configuration for +a model. It evaluates multiple possibilities for sharding embedding +tables and optimizes for performance. The planner performs the +following: + +- Assesses the memory constraints of the hardware. +- Estimates compute requirements based on memory fetches, such as + embedding lookups. +- Addresses data-specific factors. +- Considers other hardware specifics, such as bandwidth, to generate an + optimal sharding plan. + +To ensure accurate consideration of these factors, the Planner can +incorporate data about the embedding tables, constraints, hardware +information, and topology to help in generating an optimal plan. + +***************************** +Sharding of EmbeddingTables +***************************** + +TorchRec sharder provides multiple sharding strategies for various use +cases, we outline some of the sharding strategies and how they work as +well as their benefits and limitations. Generally, we recommend using +the TorchRec planner to generate a sharding plan for you as it will find +the optimal sharding strategy for each embedding table in your model. + +Each sharding strategy determines how to do the table split, whether the +table should be cut up and how, whether to keep one or a few copies of +some tables, and so on. Each piece of the table from the outcome of +sharding, whether it is one embedding table or part of it, is referred +to as a shard. + +.. 
figure:: _static/img/sharding.png + :alt: Visualizing the difference of sharding types offered in TorchRec + :align: center + + *Figure 1: Visualizing the placement of table shards under different sharding schemes offered in TorchRec* + +Here is the list of all sharding types available in TorchRec: + +- Table-wise (TW): as the name suggests, embedding table is kept as a + whole piece and placed on one rank. + +- Column-wise (CW): the table is split along the ``emb_dim`` dimension, + for example, ``emb_dim=256`` is split into 4 shards: ``[64, 64, 64, + 64]``. + +- Row-wise (RW): the table is split along the ``hash_size`` dimension, + usually split evenly among all the ranks. + +- Table-wise-row-wise (TWRW): table is placed on one host, split + row-wise among the ranks on that host. + +- Grid-shard (GS): a table is CW sharded and each CW shard is placed + TWRW on a host. + +- Data parallel (DP): each rank keeps a copy of the table. + +Once sharded, the modules are converted to sharded versions of +themselves, known as ``ShardedEmbeddingCollection`` and +``ShardedEmbeddingBagCollection`` in TorchRec. These modules handle the +communication of input data, embedding lookups, and gradients. + +**************************************************** +Distributed Training with TorchRec Sharded Modules +**************************************************** + +With many sharding strategies available, how do we determine which one +to use? There is a cost associated with each sharding scheme, which in +conjunction with model size and number of GPUs determines which sharding +strategy is best for a model. + +Without sharding, where each GPU keeps a copy of the embedding table +(DP), the main cost is computation in which each GPU looks up the +embedding vectors in its memory in the forward pass and updates the +gradients in the backward pass. + +With sharding, there is an added communication cost: each GPU needs to +ask the other GPUs for embedding vector lookup and communicate the +gradients computed as well. This is typically referred to as ``all2all`` +communication. In TorchRec, for input data on a given GPU, we determine +where the embedding shard for each part of the data is located and send +it to the target GPU. That target GPU then returns the embedding vectors +back to the original GPU. In the backward pass, the gradients are sent +back to the target GPU and the shards are updated accordingly with the +optimizer. + +As described above, sharding requires us to communicate the input data +and embedding lookups. TorchRec handles this in three main stages, we +will refer to this as the sharded embedding module forward that is used +in training and inference of a TorchRec model: + +- Feature All to All/Input distribution (``input_dist``) + + - Communicate input data (in the form of a ``KeyedJaggedTensor``) to + the appropriate device containing relevant embedding table shard + +- Embedding Lookup + + - Lookup embeddings with new input data formed after feature all to + all exchange + +- Embedding All to All/Output Distribution (``output_dist``) + + - Communicate embedding lookup data back to the appropriate device + that asked for it (in accordance with the input data the device + received) + +- The backward pass does the same operations but in reverse order. + +The diagram below demonstrates how it works: + +.. 
figure:: _static/img/torchrec_forward.png + :alt: Visualizing the forward pass including the input_dist, lookup, and output_dist of a sharded TorchRec module + :align: center + + *Figure 2: Forward pass of a table wise sharded table including the input_dist, lookup, and output_dist of a sharded TorchRec module* + +************************** +DistributedModelParallel +************************** + +All of the above culminates into the main entrypoint that TorchRec uses +to shard and integrate the plan. At a high level, +``DistributedModelParallel`` does the following: + +- Initializes the environment by setting up process groups and + assigning device type. + +- Uses default sharders if no sharders are provided, the default includes + ``EmbeddingBagCollectionSharder``. + +- Takes in the provided sharding plan, if none is provided, it + generates one. + +- Creates a sharded version of modules and replaces the original + modules with them, for example, converts ``EmbeddingCollection`` to + ``ShardedEmbeddingCollection``. + +- By default, wraps the ``DistributedModelParallel`` with + ``DistributedDataParallel`` to make the module both model and data + parallel. + +*********** +Optimizer +*********** + +TorchRec modules provide a seamless API to fuse the backwards pass and +optimizer step in training, providing a significant optimization in +performance and decreasing the memory used, alongside granularity in +assigning distinct optimizers to distinct model parameters. + +.. figure:: _static/img/fused_backward_optimizer.png + :alt: Visualizing fusing of optimizer in backward to update sparse embedding table + :align: center + + *Figure 3: Fusing embedding backward with sparse optimizer* + +*********** +Inference +*********** + +Inference environments are different from training, they are very +sensitive to performance and the size of the model. There are two key +differences TorchRec inference optimizes for: + +- **Quantization:** inference models are quantized for lower latency + and reduced model size. This optimization lets us use as few devices + as possible for inference to minimize latency. + +- **C++ environment:** to minimize latency even further, the model is + ran in a C++ environment. + +TorchRec provides the following to convert a TorchRec model into being +inference ready: + +- APIs for quantizing the model, including optimizations automatically + with FBGEMM TBE +- Sharding embeddings for distributed inference +- Compiling the model to TorchScript (compatible in C++) + +********* +See Also +********* + +- `TorchRec Interactive Notebook using the concepts + `_ diff --git a/docs/source/conf.py b/docs/source/conf.py index 4011be287..0f64cefe7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -30,19 +30,34 @@ # -- Project information ----------------------------------------------------- project = "TorchRec" -copyright = "2022, Meta" +copyright = "2024, Meta" author = "Meta" +try: + # pyre-ignore + version = "1.0.0" # TODO: Hardcode stable version for now +except Exception: + # when run internally, we don't have a version yet + version = "0.0.0" # The full version, including alpha/beta/rc tags -release = "0.0.1" - +# First 3 as format is 0.x.x.* +release = ".".join(version.split(".")[:3]) # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. 
-extensions = ["sphinx.ext.napoleon", "sphinx.ext.autodoc"] +extensions = [ + "sphinx.ext.napoleon", + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx_copybutton", + "sphinx.ext.mathjax", + "sphinx_design", +] # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] @@ -62,7 +77,21 @@ html_theme = "pytorch_sphinx_theme" html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +html_theme_options = { + "pytorch_project": "torchrec", + "display_version": True, + "logo_only": True, + "collapse_navigation": False, + "includehidden": True, +} + # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] + +html_css_files = ["css/custom.css"] diff --git a/docs/source/datatypes-api-reference.rst b/docs/source/datatypes-api-reference.rst new file mode 100644 index 000000000..8454a15f2 --- /dev/null +++ b/docs/source/datatypes-api-reference.rst @@ -0,0 +1,22 @@ +Data Types +------------------- + + +TorchRec contains data types for representing embedding, otherwise known as sparse features. +Sparse features are typically indices that are meant to be fed into embedding tables. For a given +batch, the number of embedding lookup indices are variable. Therefore, there is a need for a **jagged** +dimension to represent the variable amount of embedding lookup indices for a batch. + +This section covers the classes for the 3 TorchRec data types for representing sparse features: +**JaggedTensor**, **KeyedJaggedTensor**, and **KeyedTensor**. + +.. automodule:: torchrec.sparse.jagged_tensor + +.. autoclass:: JaggedTensor + :members: + +.. autoclass:: KeyedJaggedTensor + :members: + +.. autoclass:: KeyedTensor + :members: diff --git a/docs/source/high-level-arch.rst b/docs/source/high-level-arch.rst new file mode 100644 index 000000000..2c2263b6c --- /dev/null +++ b/docs/source/high-level-arch.rst @@ -0,0 +1,129 @@ +.. meta:: + :description: TorchRec High Level Architecture + :keywords: recommendation systems, sharding, distributed training, torchrec, architecture + +################################## + TorchRec High Level Architecture +################################## + +In this section, you will learn about the high-level architecture of +TorchRec, designed to optimize large-scale recommendation systems using +PyTorch. You will learn how TorchRec employs model parallelism to +distribute complex models across multiple GPUs, enhancing memory +management and GPU utilization, as well as get introduced to TorchRec's +base components and sharding strategies. + +In effect, TorchRec provides parallelism primitives allowing hybrid data +parallelism/model parallelism, embedding table sharding, planner to +generate sharding plans, pipelined training, and more. + +**************************************************** + TorchRec's Parallelism Strategy: Model Parallelism +**************************************************** + +As modern deep learning models have scaled, distributed deep learning +has become required to successfully train models in sufficient time. 
In +this paradigm, two main approaches have been developed: data parallelism +and model parallelism. TorchRec focuses on the latter for the sharding +of embedding tables. + +.. figure:: _static/img/model_parallel.png + :alt: Visualizing the difference of sharding a model in model parallel or data parallel approach + :align: center + + *Figure 1. Comparison between model parallelism and data parallelism approach* + +As you can see in the diagram above, model parallelism and data +parallelism are two approaches to distribute workloads across multiple +GPUs, + +- **Model Parallelism** + + - Divide the model into segments and distribute them across GPUs + - Each segment processes data independently + - Suitable for large models that don't fit on a single GPU + +- **Data Parallel** + + - Distribute the copies of entire model on each GPU + - Each GPU processes a subset of the data and contributes to the + overall computation + - Effecive for models that fit on single GPU but need to handle + large datasets + +- **Benefits of Model Parallelism** + + - Optimizes memory usage and computational efficiency for large + models + - Particularly beneficial for recommendation systems with large + embedding tables + - Enables parallel computation of embeddings in DLRM-type + architectures + +****************** + Embedding Tables +****************** + +For TorchRec to figure out what to recommend, we need to be able to +represent entities and their relationships, this is what embeddings are +used for. Embeddings are vectors of real numbers in a high dimensional +space used to represent meaning in complex data like words, images, or +users. An embedding table is an aggregation of multiple embeddings into +one matrix. Most commonly, embedding tables are represented as a 2D +matrix with dimensions (B, N). + +- *B* is the number of embeddings stored by the table +- *N* is number of dimensions per embedding. + +Each of *B* can also be referred to as an ID (representing information +such as movie title, user, ad, and so on), when accessing an ID we are +returned the corresponding embedding vector which has size of embedding +dimension *N*. + +There is also the choice of pooling embeddings, often, we’re looking up +multiple rows for a given feature which gives rise to the question of +what we do with looking up multiple embedding vectors. Pooling is a +common technique where we combine the embedding vectors, usually through +sum or mean of the rows, to produce one embedding vector. This is the +main difference between the PyTorch ``nn.Embedding`` and +``nn.EmbeddingBag``. + +PyTorch represents embeddings through ``nn.Embedding`` and +``nn.EmbeddingBag``. Building on these modules, TorchRec introduces +``EmbeddingCollection`` and ``EmbeddingBagCollection``, which are +collections of the corresponding PyTorch modules. This extension enables +TorchRec to batch tables and perform lookups on multiple embeddings in a +single kernel call, improving efficiency. + +Here is the end-to-end flow diagram that describes how embeddings are +used in the training process for recommendation models: + +.. figure:: _static/img/full_training_loop.png + :alt: Demonstrating the full training loop from embedding lookup to optimizer update in backward + :align: center + + *Figure 2. 
TorchRec End-to-end Embedding Flow* + +In the diagram above, we show the general TorchRec end to end embedding +lookup process, + +- In the forward pass we do the embedding lookup and pooling +- In the backward pass we compute the gradients of the output lookups + and pass them into the optimizer to update the embedding tables + +**Note here, the embeddings gradients are grayed out since we do not +fully materialize these into memory and instead fuse them with the +optimizer update. This results in a significant memory reduction which +we detail later in the optimizer concepts section.** + +We recommend going through the TorchRec Concepts page to get a +understanding of the fundamentals of how everything ties together +end-to-end. It contains lots of useful information to get the most out +of TorchRec. + +********** + See also +********** + +- `What is Distributed Data Parallel (DDP) Tutorial + `_ diff --git a/docs/source/index.rst b/docs/source/index.rst index 09f3a6f2c..c6fa49282 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -3,51 +3,109 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. +.. meta:: + :description: TorchRec documentation homepage + :keywords: recommendation systems, sharding, distributed training + Welcome to the TorchRec documentation! ====================================== -TorchRec is a PyTorch domain library built to provide common -sparsity & parallelism primitives needed for large-scale recommender -systems (RecSys). It allows authors to train models with large -embedding tables sharded across many GPUs. +TorchRec is a specialized library within the PyTorch ecosystem, +tailored for building, scaling, and deploying large-scale +**recommendation systems**, a niche not directly addressed by standard +PyTorch. TorchRec offers advanced features such as complex sharding +techniques for massive embedding tables, and enhanced distributed +training capabilities. + +Getting Started +--------------- + +Topics in this section will help you get started with TorchRec. + +.. grid:: 3 + + .. grid-item-card:: :octicon:`file-code;1em` + TorchRec Overview + :img-top: _static/img/card-background.svg + :link: overview.html + :link-type: url + + A short intro to TorchRec and why you need it. + + .. grid-item-card:: :octicon:`file-code;1em` + Set up TorchRec + :img-top: _static/img/card-background.svg + :link: setup-torchrec.html + :link-type: url + + Learn how to install and start using TorchRec + in your environment. + + .. grid-item-card:: :octicon:`file-code;1em` + Getting Started with TorchRec Tutorial + :img-top: _static/img/card-background.svg + :link: https://colab.research.google.com/github/pytorch/torchrec/blob/main/Torchrec_Introduction.ipynb + :link-type: url + + Follow our interactive step-by-step tutorial + to learn how to use TorchRec in a real-life + example. + + -For installation instructions, visit +How to Contribute +----------------- -https://github.com/pytorch/torchrec#readme +We welcome contributions and feedback from the PyTorch community! +If you are interested in helping improve the TorchRec project, here is +how you can contribute: -Tutorial --------- -In this tutorial, we introduce the primary torchRec -API called DistributedModelParallel, or DMP. -Like pytorch’s DistributedDataParallel, -DMP wraps a model to enable distributed training. +1. **Visit Our** `GitHub Repository `__: + There you can find the source code, issues, and ongoing projects. 
-* `Tutorial Source `_ -* Open in `Google Colab `_ +2. **Submit Feedback or Issues**: If you encounter any bugs or have + suggestions for improvements, please submit an issue through the + `GitHub issue tracker `__. -TorchRec API ------------- +3. **Propose changes**: Fork the repository and submit pull requests. + Whether it's fixing a bug, adding new features, or improving + documentation, your contributions are always welcome! Please make sure to + review our `CONTRIBUTING.md `__. + +| +| + +.. container:: center-content + + .. button-link:: https://github.com/pytorch/torchrec + :color: info + + :octicon:`mark-github` Go to TorchRec Repo + + +.. toctree:: + :maxdepth: 1 + :caption: Introduction + :hidden: + + overview.rst + high-level-arch.rst + concepts.rst + +.. toctree:: + :maxdepth: 1 + :caption: Getting Started + :hidden: + + setup-torchrec.rst .. toctree:: :maxdepth: 2 - :caption: Contents: - - torchrec.datasets.rst - torchrec.datasets.scripts.rst - torchrec.distributed.rst - torchrec.distributed.planner.rst - torchrec.distributed.sharding.rst - torchrec.fx.rst - torchrec.inference.rst - torchrec.models.rst - torchrec.modules.rst - torchrec.optim.rst - torchrec.quant.rst - torchrec.sparse.rst - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` + :caption: API References + :hidden: + + datatypes-api-reference.rst + modules-api-reference.rst + planner-api-reference.rst + model-parallel-api-reference.rst + inference-api-reference.rst diff --git a/docs/source/inference-api-reference.rst b/docs/source/inference-api-reference.rst new file mode 100644 index 000000000..0e10067c7 --- /dev/null +++ b/docs/source/inference-api-reference.rst @@ -0,0 +1,19 @@ +Inference +---------------------------------- + +TorchRec provides easy-to-use APIs for transforming an authored TorchRec model +into an optimized inference model for distributed inference, via eager module swaps. + +This transforms TorchRec modules like ``EmbeddingBagCollection`` in the model to +a quantized, sharded version that can be compiled using torch.fx and TorchScript +for inference in a C++ environment. + +The intended use is calling ``quantize_inference_model`` on the model followed by +``shard_quant_model``. + +.. automodule:: torchrec.inference.modules + +.. autofunction:: quantize_inference_model +.. autofunction:: shard_quant_model diff --git a/docs/source/model-parallel-api-reference.rst b/docs/source/model-parallel-api-reference.rst new file mode 100644 index 000000000..196fd27fb --- /dev/null +++ b/docs/source/model-parallel-api-reference.rst @@ -0,0 +1,10 @@ +Model Parallel +---------------------------------- + +``DistributedModelParallel`` is the main API for distributed training with TorchRec optimizations. + + +.. automodule:: torchrec.distributed.model_parallel + +.. 
autoclass:: DistributedModelParallel + :members: diff --git a/docs/source/modules-api-reference.rst b/docs/source/modules-api-reference.rst new file mode 100644 index 000000000..98cff7ad8 --- /dev/null +++ b/docs/source/modules-api-reference.rst @@ -0,0 +1,30 @@ +Modules +---------------------------------- + +Standard TorchRec modules represent collections of embedding tables: + +* ``EmbeddingBagCollection`` is a collection of ``torch.nn.EmbeddingBag`` +* ``EmbeddingCollection`` is a collection of ``torch.nn.Embedding`` + +These modules are constructed through standardized config classes: + +* ``EmbeddingBagConfig`` for ``EmbeddingBagCollection`` +* ``EmbeddingConfig`` for ``EmbeddingCollection`` + +.. automodule:: torchrec.modules.embedding_configs + +.. autoclass:: EmbeddingBagConfig + :show-inheritance: + +.. autoclass:: EmbeddingConfig + :show-inheritance: + +.. autoclass:: BaseEmbeddingConfig + +.. automodule:: torchrec.modules.embedding_modules + +.. autoclass:: EmbeddingBagCollection + :members: + +.. autoclass:: EmbeddingCollection + :members: diff --git a/docs/source/overview.rst b/docs/source/overview.rst new file mode 100644 index 000000000..0098c14df --- /dev/null +++ b/docs/source/overview.rst @@ -0,0 +1,23 @@ +.. _overview_label: + +================== +TorchRec Overview +================== + +TorchRec is the PyTorch recommendation system library, designed to provide common primitives +for creating state-of-the-art personalization models and a path to production. TorchRec is +widely adopted in many Meta production recommendation system models for training and inference workflows. + +Why TorchRec? +------------------ + +TorchRec is designed to address the unique challenges of building, scaling and deploying massive, +large-scale recommendation system models, which is not a focus of regular PyTorch. More specifically, +TorchRec provides the following primitives for general recommendation systems: + +- **Specialized Components**: TorchRec provides simplistic, specialized modules that are common in authoring recommendation systems, with a focus on embedding tables +- **Advanced Sharding Techniques**: TorchRec provides flexible and customizable methods for sharding massive embedding tables: Row-Wise, Column-Wise, Table-Wise, and so on. TorchRec can automatically determine the best plan for a device topology for efficient training and memory balance +- **Distributed Training**: While PyTorch supports basic distributed training, TorchRec extends these capabilities with more sophisticated model parallelism techniques specifically designed for the massive scale of recommendation systems +- **Incredibly Optimized**: TorchRec training and inference components are incredibly optimized on top of FBGEMM. After all, TorchRec powers some of the largest recommendation system models at Meta +- **Frictionless Path to Deployment**: TorchRec provides simple APIs for transforming a trained model for inference and loading it into a C++ environment for the most optimal inference model +- **Integration with PyTorch Ecosystem**: TorchRec is built on top of PyTorch, meaning it integrates seamlessly with existing PyTorch code, tools, and workflows. This allows developers to leverage their existing knowledge and codebase while utilizing advanced features for recommendation systems. By being a part of the PyTorch ecosystem, TorchRec benefits from the robust community support, continuous updates, and improvements that come with PyTorch. 
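The overview above notes that TorchRec can automatically determine a sharding plan for a given device topology; the planner API reference that follows documents the classes involved. Below is a minimal sketch of how they fit together, assuming a module that contains an ``EmbeddingBagCollection``; the world size, compute device, and sharder choice here are illustrative assumptions rather than recommended settings.

.. code-block:: python

   import torch.nn as nn

   from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
   from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology


   def make_plan(model: nn.Module, world_size: int = 2):
       # Describe the device topology the plan should target.
       planner = EmbeddingShardingPlanner(
           topology=Topology(world_size=world_size, compute_device="cuda"),
       )
       # Produce a ShardingPlan that maps each embedding table to shards and
       # devices; DistributedModelParallel accepts it through its ``plan``
       # argument.
       return planner.plan(module=model, sharders=[EmbeddingBagCollectionSharder()])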
diff --git a/docs/source/planner-api-reference.rst b/docs/source/planner-api-reference.rst new file mode 100644 index 000000000..e6cc7f4d2 --- /dev/null +++ b/docs/source/planner-api-reference.rst @@ -0,0 +1,50 @@ +Planner +---------------------------------- + +The TorchRec Planner is responsible for determining the most performant, balanced +sharding plan for distributed training and inference. + +The main API for generating a sharding plan is ``EmbeddingShardingPlanner.plan``. + +.. automodule:: torchrec.distributed.types + +.. autoclass:: ShardingPlan + :members: + +.. automodule:: torchrec.distributed.planner.planners + +.. autoclass:: EmbeddingShardingPlanner + :members: + +.. automodule:: torchrec.distributed.planner.enumerators + +.. autoclass:: EmbeddingEnumerator + :members: + +.. automodule:: torchrec.distributed.planner.partitioners + +.. autoclass:: GreedyPerfPartitioner + :members: + + +.. automodule:: torchrec.distributed.planner.storage_reservations + +.. autoclass:: HeuristicalStorageReservation + :members: + +.. automodule:: torchrec.distributed.planner.proposers + +.. autoclass:: GreedyProposer + :members: + + +.. automodule:: torchrec.distributed.planner.shard_estimators + +.. autoclass:: EmbeddingPerfEstimator + :members: + + +.. automodule:: torchrec.distributed.planner.shard_estimators + +.. autoclass:: EmbeddingStorageEstimator + :members: diff --git a/docs/source/setup-torchrec.rst b/docs/source/setup-torchrec.rst new file mode 100644 index 000000000..7a2cbf969 --- /dev/null +++ b/docs/source/setup-torchrec.rst @@ -0,0 +1,139 @@ + +=================== +Setting up TorchRec +=================== + +In this section, we will: + +* Understand requirements for using TorchRec +* Set up an environment to integrate TorchRec +* Run basic TorchRec code + + +System Requirements +------------------- + +TorchRec is routinely tested on AWS Linux only and should work in similar environments. +The compatibility matrix below shows what is currently tested: + +.. list-table:: + :widths: 25 75 + :header-rows: 0 + + * - Python Version + - 3.9, 3.10, 3.11, 3.12 + * - Compute Platform + - CPU, CUDA 11.8, CUDA 12.1, CUDA 12.4 + +Aside from those requirements, TorchRec's core dependencies are PyTorch and FBGEMM. +If your system is generally compatible with both libraries, it should be sufficient for TorchRec. + +* `PyTorch requirements `_ +* `FBGEMM requirements `_ + + +Version Compatibility +--------------------- + +TorchRec and FBGEMM have matching version numbers that are tested together upon release: + +* TorchRec 1.0 is compatible with FBGEMM 1.0 +* TorchRec 0.8 is compatible with FBGEMM 0.8 +* TorchRec 0.8 may not be compatible with FBGEMM 0.7 + +Furthermore, TorchRec and FBGEMM are released only when a new PyTorch release happens. +Therefore, specific versions of TorchRec and FBGEMM should correspond to a specific PyTorch version: + +* TorchRec 1.0 is compatible with PyTorch 2.5 +* TorchRec 0.8 is compatible with PyTorch 2.4 +* TorchRec 0.8 may not be compatible with PyTorch 2.3 + +Installation +------------ +Below we show installations for CUDA 12.1 as an example. For CPU, CUDA 11.8, or CUDA 12.4, swap ``cu121`` for ``cpu``, ``cu118``, or ``cu124`` respectively. + +.. tab-set:: + + .. tab-item:: **Stable via pytorch.org** + + .. 
code-block:: bash + + pip install torch --index-url https://download.pytorch.org/whl/cu121 + pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/cu121 + pip install torchmetrics==1.0.3 + pip install torchrec --index-url https://download.pytorch.org/whl/cu121 + + .. tab-item:: **Stable via PyPI (Only for CUDA 12.4)** + + .. code-block:: bash + + pip install torch + pip install fbgemm-gpu + pip install torchrec + + .. tab-item:: **Nightly** + + .. code-block:: bash + + pip install torch --index-url https://download.pytorch.org/whl/nightly/cu121 + pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu121 + pip install torchmetrics==1.0.3 + pip install torchrec --index-url https://download.pytorch.org/whl/nightly/cu121 + + .. tab-item:: **Building From Source** + + You also have the ability to build TorchRec from source to develop with the latest + changes in TorchRec. To build from source, check out this `reference `_. + + +Run a Simple TorchRec Example +------------------------------ +Now that we have TorchRec properly set up, let's run some TorchRec code! +Below, we'll run a simple forward pass with TorchRec data types: ``KeyedJaggedTensor`` and ``EmbeddingBagCollection``: + +.. code-block:: python + + import torch + + import torchrec + from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor + + ebc = torchrec.EmbeddingBagCollection( + device="cpu", + tables=[ + torchrec.EmbeddingBagConfig( + name="product_table", + embedding_dim=16, + num_embeddings=4096, + feature_names=["product"], + pooling=torchrec.PoolingType.SUM, + ), + torchrec.EmbeddingBagConfig( + name="user_table", + embedding_dim=16, + num_embeddings=4096, + feature_names=["user"], + pooling=torchrec.PoolingType.SUM, + ) + ] + ) + + product_jt = JaggedTensor( + values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1]) + ) + user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2])) + + # Q1: How many batches are there, and which values are in the first batch for product_jt and user_jt? + kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt}) + + print("Call EmbeddingBagCollection Forward: ", ebc(kjt)) + +Save the above code to a file named ``torchrec_example.py``. Then, you should be able to +execute it from your terminal with: + +.. code-block:: bash + + python torchrec_example.py + +You should see the output ``KeyedTensor`` with the resulting embeddings. +Congrats! You have correctly installed and ran your first TorchRec program! diff --git a/docs/source/torchrec.datasets.rst b/docs/source/torchrec.datasets.rst deleted file mode 100644 index a16be3c4c..000000000 --- a/docs/source/torchrec.datasets.rst +++ /dev/null @@ -1,36 +0,0 @@ -torchrec.datasets -================= - -.. automodule:: torchrec.datasets - -torchrec.datasets.criteo ------------------------- - -.. automodule:: torchrec.datasets.criteo - :members: - :undoc-members: - :show-inheritance: - -torchrec.datasets.movielens ---------------------------- - -.. automodule:: torchrec.datasets.movielens - :members: - :undoc-members: - :show-inheritance: - -torchrec.datasets.random ------------------------- - -.. automodule:: torchrec.datasets.random - :members: - :undoc-members: - :show-inheritance: - -torchrec.datasets.utils ------------------------ - -.. 
automodule:: torchrec.datasets.utils - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchrec.datasets.scripts.rst b/docs/source/torchrec.datasets.scripts.rst deleted file mode 100644 index 947250d67..000000000 --- a/docs/source/torchrec.datasets.scripts.rst +++ /dev/null @@ -1,22 +0,0 @@ -torchrec.datasets.scripts -========================= - -.. automodule:: torchrec.datasets.scripts - - - -torchrec.datasets.scripts.contiguous\_preproc\_criteo ------------------------------------------------------ - -.. automodule:: torchrec.datasets.scripts.contiguous_preproc_criteo - :members: - :undoc-members: - :show-inheritance: - -torchrec.datasets.scripts.npy\_preproc\_criteo ----------------------------------------------- - -.. automodule:: torchrec.datasets.scripts.npy_preproc_criteo - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchrec.distributed.planner.rst b/docs/source/torchrec.distributed.planner.rst deleted file mode 100644 index 82200125e..000000000 --- a/docs/source/torchrec.distributed.planner.rst +++ /dev/null @@ -1,93 +0,0 @@ -torchrec.distributed.planner -============================ - -.. automodule:: torchrec.distributed.planner - - -torchrec.distributed.planner.constants --------------------------------------- - -.. automodule:: torchrec.distributed.planner.constants - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.planner.enumerators ----------------------------------------- - -.. automodule:: torchrec.distributed.planner.enumerators - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.planner.partitioners ------------------------------------------ - -.. automodule:: torchrec.distributed.planner.partitioners - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.planner.perf\_models ------------------------------------------ - -.. automodule:: torchrec.distributed.planner.perf_models - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.planner.planners -------------------------------------- - -.. automodule:: torchrec.distributed.planner.planners - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.planner.proposers --------------------------------------- - -.. automodule:: torchrec.distributed.planner.proposers - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.planner.shard\_estimators ----------------------------------------------- - -.. automodule:: torchrec.distributed.planner.shard_estimators - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.planner.stats ----------------------------------- - -.. automodule:: torchrec.distributed.planner.stats - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.planner.storage\_reservations --------------------------------------------------- - -.. automodule:: torchrec.distributed.planner.storage_reservations - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.planner.types ----------------------------------- - -.. automodule:: torchrec.distributed.planner.types - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.planner.utils ----------------------------------- - -.. 
automodule:: torchrec.distributed.planner.utils - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchrec.distributed.rst b/docs/source/torchrec.distributed.rst deleted file mode 100644 index 9fdcb4282..000000000 --- a/docs/source/torchrec.distributed.rst +++ /dev/null @@ -1,126 +0,0 @@ -torchrec.distributed -==================== - -.. automodule:: torchrec.distributed - - -torchrec.distributed.collective\_utils --------------------------------------- - -.. automodule:: torchrec.distributed.collective_utils - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.comm -------------------------- - -.. automodule:: torchrec.distributed.comm - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.comm\_ops ------------------------------- - -.. automodule:: torchrec.distributed.comm_ops - :members: - :undoc-members: - :show-inheritance: - - -torchrec.distributed.dist\_data -------------------------------- - -.. automodule:: torchrec.distributed.dist_data - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.embedding ------------------------------- - -.. automodule:: torchrec.distributed.embedding - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.embedding\_lookup --------------------------------------- - -.. automodule:: torchrec.distributed.embedding_lookup - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.embedding\_sharding ----------------------------------------- - -.. automodule:: torchrec.distributed.embedding_sharding - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.embedding\_types -------------------------------------- - -.. automodule:: torchrec.distributed.embedding_types - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.embeddingbag ---------------------------------- - -.. automodule:: torchrec.distributed.embeddingbag - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.grouped\_position\_weighted ------------------------------------------------- - -.. automodule:: torchrec.distributed.grouped_position_weighted - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.model\_parallel ------------------------------------- - -.. automodule:: torchrec.distributed.model_parallel - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.quant\_embeddingbag ----------------------------------------- - -.. automodule:: torchrec.distributed.quant_embeddingbag - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.train\_pipeline ------------------------------------- - -.. automodule:: torchrec.distributed.train_pipeline - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.types --------------------------- - -.. automodule:: torchrec.distributed.types - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.utils --------------------------- - -.. automodule:: torchrec.distributed.utils - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchrec.distributed.sharding.rst b/docs/source/torchrec.distributed.sharding.rst deleted file mode 100644 index d0e7162ae..000000000 --- a/docs/source/torchrec.distributed.sharding.rst +++ /dev/null @@ -1,63 +0,0 @@ -torchrec.distributed.sharding -============================= - -.. automodule:: torchrec.distributed.sharding - - -torchrec.distributed.sharding.cw\_sharding ------------------------------------------- - -.. 
automodule:: torchrec.distributed.sharding.cw_sharding - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.dist\_data -------------------------------- - -.. automodule:: torchrec.distributed.dist_data - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.sharding.dp\_sharding ------------------------------------------- - -.. automodule:: torchrec.distributed.sharding.dp_sharding - :members: - :undoc-members: - :show-inheritance: - - -torchrec.distributed.sharding.rw\_sharding ------------------------------------------- - -.. automodule:: torchrec.distributed.sharding.rw_sharding - :members: - :undoc-members: - :show-inheritance: - - -torchrec.distributed.sharding.tw\_sharding ------------------------------------------- - -.. automodule:: torchrec.distributed.sharding.tw_sharding - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.sharding.twcw\_sharding --------------------------------------------- - -.. automodule:: torchrec.distributed.sharding.twcw_sharding - :members: - :undoc-members: - :show-inheritance: - -torchrec.distributed.sharding.twrw\_sharding --------------------------------------------- - -.. automodule:: torchrec.distributed.sharding.twrw_sharding - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchrec.fx.rst b/docs/source/torchrec.fx.rst deleted file mode 100644 index b06656fbc..000000000 --- a/docs/source/torchrec.fx.rst +++ /dev/null @@ -1,20 +0,0 @@ -torchrec.fx -=========== - -.. automodule:: torchrec.fx - -torchrec.fx.tracer ------------------- - -.. automodule:: torchrec.fx.tracer - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchrec.fx - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchrec.inference.rst b/docs/source/torchrec.inference.rst deleted file mode 100644 index 846a520ec..000000000 --- a/docs/source/torchrec.inference.rst +++ /dev/null @@ -1,28 +0,0 @@ -torchrec.inference -=========== - -.. automodule:: torchrec.inference - -torchrec.inference.model_packager ------------------- - -.. automodule:: torchrec.inference.model_packager - :members: - :undoc-members: - :show-inheritance: - -torchrec.inference.modules ------------------- - -.. automodule:: torchrec.inference.modules - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchrec.inference - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchrec.models.rst b/docs/source/torchrec.models.rst deleted file mode 100644 index 4d3ad3040..000000000 --- a/docs/source/torchrec.models.rst +++ /dev/null @@ -1,28 +0,0 @@ -torchrec.models -=============== - -.. automodule:: torchrec.models - -torchrec.models.deepfm ----------------------- - -.. automodule:: torchrec.models.deepfm - :members: - :undoc-members: - :show-inheritance: - -torchrec.models.dlrm --------------------- - -.. automodule:: torchrec.models.dlrm - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchrec.models - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchrec.modules.rst b/docs/source/torchrec.modules.rst deleted file mode 100644 index 95dcf35db..000000000 --- a/docs/source/torchrec.modules.rst +++ /dev/null @@ -1,85 +0,0 @@ -torchrec.modules -================ - -.. automodule:: torchrec.modules - -torchrec.modules.activation ---------------------------- - -.. 
automodule:: torchrec.modules.activation - :members: - :undoc-members: - :show-inheritance: - -torchrec.modules.crossnet -------------------------- - -.. automodule:: torchrec.modules.crossnet - :members: - :undoc-members: - :show-inheritance: - -torchrec.modules.deepfm ------------------------ - -.. automodule:: torchrec.modules.deepfm - :members: - :undoc-members: - :show-inheritance: - -torchrec.modules.embedding\_configs ------------------------------------ - -.. automodule:: torchrec.modules.embedding_configs - :members: - :undoc-members: - :show-inheritance: - -torchrec.modules.embedding\_modules ------------------------------------ - -.. automodule:: torchrec.modules.embedding_modules - :members: - :undoc-members: - :show-inheritance: - -torchrec.modules.feature\_processor ------------------------------------ - -.. automodule:: torchrec.modules.feature_processor - :members: - :undoc-members: - :show-inheritance: - -torchrec.modules.lazy\_extension --------------------------------- - -.. automodule:: torchrec.modules.lazy_extension - :members: - :undoc-members: - :show-inheritance: - -torchrec.modules.mlp --------------------- - -.. automodule:: torchrec.modules.mlp - :members: - :undoc-members: - :show-inheritance: - - -torchrec.modules.utils ----------------------- - -.. automodule:: torchrec.modules.utils - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchrec.modules - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchrec.optim.rst b/docs/source/torchrec.optim.rst deleted file mode 100644 index df3a147d0..000000000 --- a/docs/source/torchrec.optim.rst +++ /dev/null @@ -1,44 +0,0 @@ -torchrec.optim -============== - -.. automodule:: torchrec.optim - -torchrec.optim.clipping ------------------------ - -.. automodule:: torchrec.optim.clipping - :members: - :undoc-members: - :show-inheritance: - -torchrec.optim.fused --------------------- - -.. automodule:: torchrec.optim.fused - :members: - :undoc-members: - :show-inheritance: - -torchrec.optim.keyed --------------------- - -.. automodule:: torchrec.optim.keyed - :members: - :undoc-members: - :show-inheritance: - -torchrec.optim.warmup ---------------------- - -.. automodule:: torchrec.optim.warmup - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchrec.optim - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchrec.quant.rst b/docs/source/torchrec.quant.rst deleted file mode 100644 index df4b34bd4..000000000 --- a/docs/source/torchrec.quant.rst +++ /dev/null @@ -1,20 +0,0 @@ -torchrec.quant -============== - -.. automodule:: torchrec.quant - -torchrec.quant.embedding\_modules ---------------------------------- - -.. automodule:: torchrec.quant.embedding_modules - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: torchrec.quant - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/torchrec.sparse.rst b/docs/source/torchrec.sparse.rst deleted file mode 100644 index efbd3eaec..000000000 --- a/docs/source/torchrec.sparse.rst +++ /dev/null @@ -1,20 +0,0 @@ -torchrec.sparse -=============== - -.. automodule:: torchrec.sparse - -torchrec.sparse.jagged\_tensor ------------------------------- - -.. automodule:: torchrec.sparse.jagged_tensor - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. 
automodule:: torchrec.sparse - :members: - :undoc-members: - :show-inheritance: diff --git a/examples/bert4rec/bert4rec_main.py b/examples/bert4rec/bert4rec_main.py index 1f3f1319b..f5c8ccff5 100644 --- a/examples/bert4rec/bert4rec_main.py +++ b/examples/bert4rec/bert4rec_main.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + #!/usr/bin/env python3 @@ -26,6 +28,7 @@ from torchrec.distributed.model_parallel import DistributedModelParallel as DMP from torchrec.distributed.types import ModuleSharder from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper +from torchrec.optim.optimizers import in_backward_optimizer_filter from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from tqdm import tqdm @@ -497,11 +500,12 @@ def main(argv: List[str]) -> None: ], ) dense_optimizer = KeyedOptimizerWrapper( - dict(model.named_parameters()), + dict(in_backward_optimizer_filter(model.named_parameters())), lambda params: optim.Adam( params, lr=args.lr, weight_decay=args.weight_decay ), ) + optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer]) else: device_ids = [rank] if backend == "nccl" else None diff --git a/examples/bert4rec/bert4rec_metrics.py b/examples/bert4rec/bert4rec_metrics.py index 09140d62e..6ad244fdc 100644 --- a/examples/bert4rec/bert4rec_metrics.py +++ b/examples/bert4rec/bert4rec_metrics.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Dict, List import torch diff --git a/examples/bert4rec/data/bert4rec_movielens_datasets.py b/examples/bert4rec/data/bert4rec_movielens_datasets.py index fb912126d..18144cc48 100644 --- a/examples/bert4rec/data/bert4rec_movielens_datasets.py +++ b/examples/bert4rec/data/bert4rec_movielens_datasets.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import random from collections import Counter from pathlib import Path diff --git a/examples/bert4rec/data/tests/test_bert4rec_movielens_datasets.py b/examples/bert4rec/data/tests/test_bert4rec_movielens_datasets.py index daca1c4e7..06c8deef8 100644 --- a/examples/bert4rec/data/tests/test_bert4rec_movielens_datasets.py +++ b/examples/bert4rec/data/tests/test_bert4rec_movielens_datasets.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest from ..bert4rec_movielens_datasets import Bert4RecPreprocsser, get_raw_dataframe diff --git a/examples/bert4rec/dataloader/bert4rec_movielens_dataloader.py b/examples/bert4rec/dataloader/bert4rec_movielens_dataloader.py index af7a3b22f..a9303d218 100644 --- a/examples/bert4rec/dataloader/bert4rec_movielens_dataloader.py +++ b/examples/bert4rec/dataloader/bert4rec_movielens_dataloader.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + from typing import Any, Dict, Tuple import pandas as pd diff --git a/examples/bert4rec/dataloader/tests/test_bert4rec_movielens_dataloader.py b/examples/bert4rec/dataloader/tests/test_bert4rec_movielens_dataloader.py index 0a5526af3..a47c0fe7d 100644 --- a/examples/bert4rec/dataloader/tests/test_bert4rec_movielens_dataloader.py +++ b/examples/bert4rec/dataloader/tests/test_bert4rec_movielens_dataloader.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest from ...data.bert4rec_movielens_datasets import Bert4RecPreprocsser, get_raw_dataframe diff --git a/examples/bert4rec/models/bert4rec.py b/examples/bert4rec/models/bert4rec.py index 0ca8b83b0..b90ac4305 100644 --- a/examples/bert4rec/models/bert4rec.py +++ b/examples/bert4rec/models/bert4rec.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import copy import math diff --git a/examples/bert4rec/models/tests/test_bert4rec.py b/examples/bert4rec/models/tests/test_bert4rec.py index ba46c299d..5149f188b 100644 --- a/examples/bert4rec/models/tests/test_bert4rec.py +++ b/examples/bert4rec/models/tests/test_bert4rec.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + #!/usr/bin/env python3 import unittest diff --git a/examples/bert4rec/tests/test_bert4rec_main.py b/examples/bert4rec/tests/test_bert4rec_main.py index d53198c6e..705f4b64c 100644 --- a/examples/bert4rec/tests/test_bert4rec_main.py +++ b/examples/bert4rec/tests/test_bert4rec_main.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import os import tempfile import unittest diff --git a/examples/datasets/README b/examples/datasets/README deleted file mode 100644 index 74210008a..000000000 --- a/examples/datasets/README +++ /dev/null @@ -1 +0,0 @@ -Datasets under this directory are prototyping with libaries under active development (e.g. TorchArrow DataFrame) which may not have best performance yet, and subject to change. diff --git a/examples/datasets/criteo_dataframes.py b/examples/datasets/criteo_dataframes.py deleted file mode 100644 index dec708602..000000000 --- a/examples/datasets/criteo_dataframes.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from typing import Iterable, List, Tuple, Union - -import numpy as np -import torch.utils.data.datapipes as dp -import torcharrow as ta -import torcharrow.dtypes as dt -from torch.utils.data import IterDataPipe -from torchrec.datasets.criteo import ( - CAT_FEATURE_COUNT, - CriteoIterDataPipe, - DEFAULT_CAT_NAMES, - DEFAULT_INT_NAMES, - INT_FEATURE_COUNT, -) -from torchrec.datasets.utils import safe_cast - -DTYPE = dt.Struct( - [ - dt.Field("labels", dt.int8), - dt.Field( - "dense_features", - dt.Struct( - [ - dt.Field(int_name, dt.Int32(nullable=True)) - for int_name in DEFAULT_INT_NAMES - ] - ), - ), - dt.Field( - "sparse_features", - dt.Struct( - [ - dt.Field(cat_name, dt.Int32(nullable=True)) - for cat_name in DEFAULT_CAT_NAMES - ] - ), - ), - ] -) - - -def _torcharrow_row_mapper( - row: List[str], -) -> Tuple[int, Tuple[int, ...], Tuple[int, ...]]: - # TODO: Fix safe_cast type annotation - label = int(safe_cast(row[0], int, 0)) - dense = tuple( - (int(safe_cast(row[i], int, 0)) for i in range(1, 1 + INT_FEATURE_COUNT)) - ) - sparse = tuple( - ( - int(safe_cast(row[i], str, "0") or "0", 16) - for i in range( - 1 + INT_FEATURE_COUNT, 1 + INT_FEATURE_COUNT + CAT_FEATURE_COUNT - ) - ) - ) - # TorchArrow doesn't support uint32, but we can save memory - # by not using int64. Numpy will automatically handle sparse values >= 2 ** 31. - sparse = tuple(np.array(sparse, dtype=np.int32).tolist()) - - return (label, dense, sparse) - - -def criteo_dataframes_from_tsv( - paths: Union[str, Iterable[str]], - *, - batch_size: int = 128, -) -> IterDataPipe: - """ - Load Criteo dataset (Kaggle or Terabyte) as TorchArrow DataFrame streams from TSV file(s) - - This implementaiton is inefficient and is used for prototype and test only. - - Args: - paths (str or Iterable[str]): local paths to TSV files that constitute - the Kaggle or Criteo 1TB dataset. - - Example:: - - datapipe = criteo_dataframes_from_tsv( - ["/home/datasets/criteo/day_0.tsv", "/home/datasets/criteo/day_1.tsv"] - ) - for df in datapipe: - print(df) - """ - if isinstance(paths, str): - paths = [paths] - - datapipe = CriteoIterDataPipe(paths, row_mapper=_torcharrow_row_mapper) - datapipe = dp.iter.Batcher(datapipe, batch_size) - datapipe = dp.iter.Mapper(datapipe, lambda batch: ta.dataframe(batch, dtype=DTYPE)) - - return datapipe diff --git a/examples/datasets/tests/test_criteo_dataframes.py b/examples/datasets/tests/test_criteo_dataframes.py deleted file mode 100644 index 18cd5079b..000000000 --- a/examples/datasets/tests/test_criteo_dataframes.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import contextlib - -import torcharrow as ta -from torch.utils.data import IterDataPipe -from torchrec.datasets.test_utils.criteo_test_utils import CriteoTest - -from ..criteo_dataframes import criteo_dataframes_from_tsv - - -class CriteoDataFramesTest(CriteoTest): - BATCH_SIZE = 4 - - def test_single_file(self) -> None: - with self._create_dataset_tsv() as dataset_pathname: - dataset = criteo_dataframes_from_tsv( - dataset_pathname, batch_size=self.BATCH_SIZE - ) - - self._validate_dataset(dataset, 10) - - def test_multiple_files(self) -> None: - with contextlib.ExitStack() as stack: - pathnames = [ - stack.enter_context(self._create_dataset_tsv()) for _ in range(3) - ] - dataset = criteo_dataframes_from_tsv(pathnames, batch_size=self.BATCH_SIZE) - - self._validate_dataset(dataset, 30) - - def _validate_dataset( - self, dataset: IterDataPipe, expected_total_length: int - ) -> None: - last_batch = False - total_length = 0 - - for df in dataset: - self.assertFalse(last_batch) - self.assertTrue(isinstance(df, ta.DataFrame)) - self.assertLessEqual(len(df), self.BATCH_SIZE) - - total_length += len(df) - if len(df) < self.BATCH_SIZE: - last_batch = True - - self._validate_dataframe(df) - self.assertEqual(total_length, expected_total_length) - - def _validate_dataframe(self, df: ta.DataFrame, train: bool = True) -> None: - if train: - self.assertEqual(len(df.columns), 3) - labels = df["labels"] - for label_val in labels: - self.assertTrue( - self.LABEL_VAL_RANGE[0] <= label_val <= self.LABEL_VAL_RANGE[1] - ) - else: - self.assertEqual(len(df.columns), 2) - - # Validations for both train and test - dense_features = df["dense_features"] - for idx in range(self.INT_FEATURE_COUNT): - int_vals = dense_features[f"int_{idx}"] - for int_val in int_vals: - self.assertTrue( - self.INT_VAL_RANGE[0] <= int_val <= self.INT_VAL_RANGE[1] - ) - - sparse_features = df["sparse_features"] - for idx in range(self.CAT_FEATURE_COUNT): - cat_vals = sparse_features[f"cat_{idx}"] - for cat_val in cat_vals: - # stored as int32 - self.assertTrue(-(2**31) <= cat_val <= 2**31 - 1) diff --git a/examples/golden_training/tests/test_train_dlrm.py b/examples/golden_training/tests/test_train_dlrm.py index 691d52937..25e1cc7fc 100644 --- a/examples/golden_training/tests/test_train_dlrm.py +++ b/examples/golden_training/tests/test_train_dlrm.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import os import tempfile import unittest diff --git a/examples/golden_training/train_dlrm.py b/examples/golden_training/train_dlrm.py index 511b71bce..55514c53d 100644 --- a/examples/golden_training/train_dlrm.py +++ b/examples/golden_training/train_dlrm.py @@ -5,12 +5,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import os from typing import List, Optional import torch from torch import distributed as dist from torch.distributed.elastic.multiprocessing.errors import record +from torch.distributed.optim import ( + _apply_optimizer_in_backward as apply_optimizer_in_backward, +) from torch.utils.data import IterableDataset from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES from torchrec.datasets.random import RandomRecDataset @@ -25,8 +30,8 @@ from torchrec.models.dlrm import DLRM, DLRMTrain from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.modules.embedding_modules import EmbeddingBagCollection -from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward from torchrec.optim.keyed import KeyedOptimizerWrapper +from torchrec.optim.optimizers import in_backward_optimizer_filter from torchrec.optim.rowwise_adagrad import RowWiseAdagrad from tqdm import tqdm @@ -132,7 +137,7 @@ def train( ) non_fused_optimizer = KeyedOptimizerWrapper( - dict(model.named_parameters()), + dict(in_backward_optimizer_filter(model.named_parameters())), lambda params: torch.optim.Adagrad(params, lr=learning_rate), ) # Overlap comm/compute/device transfer during training through train_pipeline diff --git a/examples/golden_training/train_dlrm_data_parallel.py b/examples/golden_training/train_dlrm_data_parallel.py new file mode 100644 index 000000000..0d57f9617 --- /dev/null +++ b/examples/golden_training/train_dlrm_data_parallel.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import os +from typing import Iterator, List, Optional, Tuple + +import torch +from torch import distributed as dist, nn +from torch.distributed.elastic.multiprocessing.errors import record +from torch.distributed.optim import ( + _apply_optimizer_in_backward as apply_optimizer_in_backward, +) +from torch.utils.data import IterableDataset +from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES +from torchrec.datasets.random import RandomRecDataset +from torchrec.distributed import TrainPipelineSparseDist +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.fbgemm_qcomm_codec import ( + CommType, + get_qcomm_codecs_registry, + QCommsConfig, +) +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.planner import EmbeddingShardingPlanner +from torchrec.distributed.planner.types import ParameterConstraints, ShardingPlan +from torchrec.models.dlrm import DLRM, DLRMTrain +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper +from torchrec.optim.optimizers import in_backward_optimizer_filter +from torchrec.optim.rowwise_adagrad import RowWiseAdagrad +from tqdm import tqdm + + +def _get_random_dataset( + num_embeddings: int, + batch_size: int = 32, +) -> IterableDataset: + return RandomRecDataset( + keys=DEFAULT_CAT_NAMES, + batch_size=batch_size, + hash_size=num_embeddings, + ids_per_feature=1, + num_dense=len(DEFAULT_INT_NAMES), + ) + + +@record +def main() -> None: + train() + + +def train( + num_embeddings: int = 1024**2, + embedding_dim: int = 128, + dense_arch_layer_sizes: 
Optional[List[int]] = None, + over_arch_layer_sizes: Optional[List[int]] = None, + learning_rate: float = 0.1, + num_iterations: int = 1000, + qcomm_forward_precision: Optional[CommType] = CommType.FP16, + qcomm_backward_precision: Optional[CommType] = CommType.BF16, +) -> None: + """ + Duplicate of train_dlrm.py, but manually forces one table to be data_parallel. + We then optimize this table with RWAdagrad, and optimize the rest of the dense params (i.e dense, inter and over archs) with vanilla Adagrad. + + Constructs and trains a DLRM model (using random dummy data). Each script is run on each process (rank) in SPMD fashion. + The embedding layers will be sharded across available ranks + + qcomm_forward_precision: Compression used in forwards pass. FP16 is the recommended usage. INT8 and FP8 are in development, but feel free to try them out. + qcomm_backward_precision: Compression used in backwards pass. We recommend using BF16 to ensure training stability. + + The effects of quantized comms will be most apparent in large training jobs across multiple nodes where inter host communication is expensive. + """ + if dense_arch_layer_sizes is None: + dense_arch_layer_sizes = [64, embedding_dim] + if over_arch_layer_sizes is None: + over_arch_layer_sizes = [64, 1] + + # Init process_group , device, rank, backend + rank = int(os.environ["LOCAL_RANK"]) + if torch.cuda.is_available(): + device: torch.device = torch.device(f"cuda:{rank}") + backend = "nccl" + torch.cuda.set_device(device) + else: + device: torch.device = torch.device("cpu") + backend = "gloo" + dist.init_process_group(backend=backend) + + # Construct DLRM module + eb_configs = [ + EmbeddingBagConfig( + name=f"t_{feature_name}", + embedding_dim=embedding_dim, + num_embeddings=num_embeddings, + feature_names=[feature_name], + ) + for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES) + ] + dlrm_model = DLRM( + embedding_bag_collection=EmbeddingBagCollection( + tables=eb_configs, device=torch.device("meta") + ), + dense_in_features=len(DEFAULT_INT_NAMES), + dense_arch_layer_sizes=dense_arch_layer_sizes, + over_arch_layer_sizes=over_arch_layer_sizes, + dense_device=device, + ) + train_model = DLRMTrain(dlrm_model) + # Optional: force some tables (table 10) to be data_parallel through planner ParameterConstraints + planner = EmbeddingShardingPlanner( + constraints={ + "t_cat_10": ParameterConstraints(compute_kernels=["dense"]), + } + ) + plan: ShardingPlan = planner.collective_plan(train_model) + + apply_optimizer_in_backward( + RowWiseAdagrad, + train_model.model.sparse_arch.parameters(), + {"lr": learning_rate}, + ) + qcomm_codecs_registry = ( + get_qcomm_codecs_registry( + qcomms_config=QCommsConfig( + # pyre-ignore + forward_precision=qcomm_forward_precision, + # pyre-ignore + backward_precision=qcomm_backward_precision, + ) + ) + if backend == "nccl" + else None + ) + sharder = EmbeddingBagCollectionSharder(qcomm_codecs_registry=qcomm_codecs_registry) + + model = DistributedModelParallel( + plan=plan, + module=train_model, + device=device, + # pyre-ignore + sharders=[sharder], + ) + + # non fused (dense) embeddings are data parallel and use RowWiseAdagrad optimizer + non_fused_embedding_optimizer = KeyedOptimizerWrapper( + dict( + in_backward_optimizer_filter( + model.module.model.sparse_arch.named_parameters() + ) + ), + lambda params: RowWiseAdagrad(params, lr=learning_rate), + ) + + def dense_filter( + named_parameters: Iterator[Tuple[str, nn.Parameter]] + ) -> Iterator[Tuple[str, nn.Parameter]]: + for fqn, param in 
named_parameters: + if "sparse" not in fqn: + yield fqn, param + + # DLRM dense (over, inter and dense archs) are optimized with vanilla Adagrad + dense_optimizer = KeyedOptimizerWrapper( + dict(dense_filter(model.named_parameters())), + lambda params: torch.optim.Adagrad(params, lr=learning_rate), + ) + combined_optimizer = CombinedOptimizer( + [ + ("non_fused_embedding_optim", non_fused_embedding_optimizer), + ("dense_optim", dense_optimizer), + ("fused_embedding_optim", model.fused_optimizer), + ] + ) + # Overlap comm/compute/device transfer during training through train_pipeline + train_pipeline = TrainPipelineSparseDist( + model, + combined_optimizer, + device, + ) + + # train model + train_iterator = iter( + _get_random_dataset( + num_embeddings=num_embeddings, + ) + ) + for _ in tqdm(range(int(num_iterations)), mininterval=5.0): + train_pipeline.progress(train_iterator) + + +if __name__ == "__main__": + main() diff --git a/examples/inference/README.md b/examples/inference_legacy/README.md similarity index 100% rename from examples/inference/README.md rename to examples/inference_legacy/README.md diff --git a/examples/inference/dlrm_client.py b/examples/inference_legacy/dlrm_client.py similarity index 100% rename from examples/inference/dlrm_client.py rename to examples/inference_legacy/dlrm_client.py diff --git a/examples/inference/dlrm_packager.py b/examples/inference_legacy/dlrm_packager.py similarity index 100% rename from examples/inference/dlrm_packager.py rename to examples/inference_legacy/dlrm_packager.py diff --git a/examples/inference/dlrm_predict.py b/examples/inference_legacy/dlrm_predict.py similarity index 96% rename from examples/inference/dlrm_predict.py rename to examples/inference_legacy/dlrm_predict.py index c36bb613d..e2bdcef9d 100644 --- a/examples/inference/dlrm_predict.py +++ b/examples/inference_legacy/dlrm_predict.py @@ -139,9 +139,11 @@ def create_predict_module(self, world_size: int) -> torch.nn.Module: EmbeddingBagConfig( name=f"t_{feature_name}", embedding_dim=self.model_config.embedding_dim, - num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx] - if self.model_config.num_embeddings is None - else self.model_config.num_embeddings, + num_embeddings=( + self.model_config.num_embeddings_per_feature[feature_idx] + if self.model_config.num_embeddings is None + else self.model_config.num_embeddings + ), feature_names=[feature_name], ) for feature_idx, feature_name in enumerate( diff --git a/examples/inference/dlrm_predict_single_gpu.py b/examples/inference_legacy/dlrm_predict_single_gpu.py similarity index 94% rename from examples/inference/dlrm_predict_single_gpu.py rename to examples/inference_legacy/dlrm_predict_single_gpu.py index 753cb0dcc..ba5323247 100644 --- a/examples/inference/dlrm_predict_single_gpu.py +++ b/examples/inference_legacy/dlrm_predict_single_gpu.py @@ -50,9 +50,11 @@ def create_predict_module(self, world_size: int) -> torch.nn.Module: EmbeddingBagConfig( name=f"t_{feature_name}", embedding_dim=self.model_config.embedding_dim, - num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx] - if self.model_config.num_embeddings is None - else self.model_config.num_embeddings, + num_embeddings=( + self.model_config.num_embeddings_per_feature[feature_idx] + if self.model_config.num_embeddings is None + else self.model_config.num_embeddings + ), feature_names=[feature_name], ) for feature_idx, feature_name in enumerate( diff --git a/examples/nvt_dataloader/nvt_binary_dataloader.py 
b/examples/nvt_dataloader/nvt_binary_dataloader.py index 2286b5610..f88ff1feb 100644 --- a/examples/nvt_dataloader/nvt_binary_dataloader.py +++ b/examples/nvt_dataloader/nvt_binary_dataloader.py @@ -94,7 +94,8 @@ def __getitem__(self, idx: int): """Numerical features are returned in the order they appear in the channel spec section For performance reasons, this is required to be the order they are saved in, as specified by the relevant chunk in source spec. - Categorical features are returned in the order they appear in the channel spec section""" + Categorical features are returned in the order they appear in the channel spec section + """ if idx >= self._num_entries: raise IndexError() diff --git a/examples/nvt_dataloader/train_torchrec.py b/examples/nvt_dataloader/train_torchrec.py index ddf53c89b..abdf3a67d 100644 --- a/examples/nvt_dataloader/train_torchrec.py +++ b/examples/nvt_dataloader/train_torchrec.py @@ -19,8 +19,6 @@ import torchrec import torchrec.distributed as trec_dist import torchrec.optim as trec_optim - -from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType from nvt_binary_dataloader import NvtBinaryDataloader from pyre_extensions import none_throws from torchrec import EmbeddingBagCollection @@ -40,6 +38,7 @@ from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.modules.fused_embedding_modules import fuse_embedding_optimizer from torchrec.optim.keyed import KeyedOptimizerWrapper +from torchrec.optim.optimizers import in_backward_optimizer_filter def parse_args(argv: List[str]) -> argparse.Namespace: @@ -209,9 +208,11 @@ def main(argv: List[str]): EmbeddingBagConfig( name=f"t_{feature_name}", embedding_dim=args.embedding_dim, - num_embeddings=none_throws(num_embeddings_per_feature)[feature_idx] - if num_embeddings_per_feature is not None - else args.num_embeddings, + num_embeddings=( + none_throws(num_embeddings_per_feature)[feature_idx] + if num_embeddings_per_feature is not None + else args.num_embeddings + ), feature_names=[feature_name], ) for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES) @@ -233,9 +234,9 @@ def main(argv: List[str]): train_model = fuse_embedding_optimizer( train_model, - optimizer_type=torchrec.optim.RowWiseAdagrad - if args.adagrad - else torch.optim.SGD, + optimizer_type=( + torchrec.optim.RowWiseAdagrad if args.adagrad else torch.optim.SGD + ), optimizer_kwargs={"learning_rate": args.learning_rate}, device=torch.device("meta"), ) @@ -270,10 +271,12 @@ def main(argv: List[str]): ) non_fused_optimizer = KeyedOptimizerWrapper( - dict(model.named_parameters()), - lambda params: torch.optim.Adagrad(params, lr=args.learning_rate) - if args.adagrad - else torch.optim.SGD(params, lr=args.learning_rate), + dict(in_backward_optimizer_filter(model.named_parameters())), + lambda params: ( + torch.optim.Adagrad(params, lr=args.learning_rate) + if args.adagrad + else torch.optim.SGD(params, lr=args.learning_rate) + ), ) opt = trec_optim.keyed.CombinedOptimizer( diff --git a/examples/ray/train_torchrec.py b/examples/ray/train_torchrec.py index 2406c6796..4ee5fdeb9 100644 --- a/examples/ray/train_torchrec.py +++ b/examples/ray/train_torchrec.py @@ -24,6 +24,7 @@ from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.optim.keyed import KeyedOptimizerWrapper +from torchrec.optim.optimizers import in_backward_optimizer_filter from tqdm import tqdm @@ -110,7 +111,7 @@ def train( # Overlap comm/compute/device 
transfer during training through train_pipeline non_fused_optimizer = KeyedOptimizerWrapper( - dict(model.named_parameters()), + dict(in_backward_optimizer_filter(model.named_parameters())), lambda params: torch.optim.Adagrad(params, lr=learning_rate), ) train_pipeline = TrainPipelineSparseDist( diff --git a/examples/retrieval/data/dataloader.py b/examples/retrieval/data/dataloader.py index 5727910c7..39e2ca1c4 100644 --- a/examples/retrieval/data/dataloader.py +++ b/examples/retrieval/data/dataloader.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from torch.utils.data import DataLoader from torchrec.datasets.movielens import DEFAULT_RATINGS_COLUMN_NAMES from torchrec.datasets.random import RandomRecDataset diff --git a/examples/retrieval/knn_index.py b/examples/retrieval/knn_index.py index a1323f1ca..9db4a6d7d 100644 --- a/examples/retrieval/knn_index.py +++ b/examples/retrieval/knn_index.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Optional, Union import faiss # @manual=//faiss/python:pyfaiss_gpu diff --git a/examples/retrieval/modules/tests/test_two_tower.py b/examples/retrieval/modules/tests/test_two_tower.py index 9ac08b10f..237ae7ee8 100644 --- a/examples/retrieval/modules/tests/test_two_tower.py +++ b/examples/retrieval/modules/tests/test_two_tower.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + #!/usr/bin/env python3 import unittest diff --git a/examples/retrieval/modules/two_tower.py b/examples/retrieval/modules/two_tower.py index 5ee2b22af..704e02b63 100644 --- a/examples/retrieval/modules/two_tower.py +++ b/examples/retrieval/modules/two_tower.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + from typing import Any, List, Mapping, Optional, OrderedDict, Tuple, Union import faiss # @manual=//faiss/python:pyfaiss_gpu @@ -71,12 +73,12 @@ def __init__( embedding_dim: int = embedding_bag_collection.embedding_bag_configs()[ 0 ].embedding_dim - self._feature_names_query: List[ - str - ] = embedding_bag_collection.embedding_bag_configs()[0].feature_names - self._candidate_feature_names: List[ - str - ] = embedding_bag_collection.embedding_bag_configs()[1].feature_names + self._feature_names_query: List[str] = ( + embedding_bag_collection.embedding_bag_configs()[0].feature_names + ) + self._candidate_feature_names: List[str] = ( + embedding_bag_collection.embedding_bag_configs()[1].feature_names + ) self.ebc = embedding_bag_collection self.query_proj = MLP( in_size=embedding_dim, layer_sizes=layer_sizes, device=device @@ -174,6 +176,7 @@ def __init__( layer_sizes: List[int], k: int, device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() self.embedding_dim: int = query_ebc.embedding_bag_configs()[0].embedding_dim @@ -186,10 +189,16 @@ def __init__( self.query_ebc = query_ebc self.candidate_ebc = candidate_ebc self.query_proj = MLP( - in_size=self.embedding_dim, layer_sizes=layer_sizes, device=device + in_size=self.embedding_dim, + layer_sizes=layer_sizes, + device=device, + dtype=dtype, ) self.candidate_proj = MLP( - in_size=self.embedding_dim, layer_sizes=layer_sizes, device=device + in_size=self.embedding_dim, + layer_sizes=layer_sizes, + device=device, + dtype=dtype, ) self.faiss_index: Union[faiss.GpuIndexIVFPQ, faiss.IndexIVFPQ] = faiss_index self.k = k @@ -212,6 +221,7 @@ def forward(self, query_kjt: KeyedJaggedTensor) -> torch.Tensor: candidates = torch.empty( (batch_size, self.k), device=self.device, dtype=torch.int64 ) + query_embedding = query_embedding.to(torch.float32) # required by faiss self.faiss_index.search(query_embedding, self.k, distances, candidates) # candidate lookup @@ -227,5 +237,7 @@ def forward(self, query_kjt: KeyedJaggedTensor) -> torch.Tensor: # return logit (dot product) return (query_embedding * candidate_embedding).sum(dim=1).squeeze() + # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module` + # inconsistently. def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool) -> None: super().load_state_dict(state_dict, strict) diff --git a/examples/retrieval/tests/test_two_tower_retrieval.py b/examples/retrieval/tests/test_two_tower_retrieval.py index eef1ef455..77952e0fa 100644 --- a/examples/retrieval/tests/test_two_tower_retrieval.py +++ b/examples/retrieval/tests/test_two_tower_retrieval.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import unittest import torch @@ -19,11 +21,12 @@ class InferTest(unittest.TestCase): @skip_if_asan # pyre-ignore[56] @unittest.skipIf( - not torch.cuda.is_available(), - "this test requires a GPU", + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", ) def test_infer_function(self) -> None: infer( embedding_dim=16, layer_sizes=[16], + world_size=2, ) diff --git a/examples/retrieval/tests/test_two_tower_train.py b/examples/retrieval/tests/test_two_tower_train.py index 2b78b8426..8db7684cf 100644 --- a/examples/retrieval/tests/test_two_tower_train.py +++ b/examples/retrieval/tests/test_two_tower_train.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import os import tempfile import unittest diff --git a/examples/retrieval/two_tower_retrieval.py b/examples/retrieval/two_tower_retrieval.py index b1b4ccb49..08a2de4b6 100644 --- a/examples/retrieval/two_tower_retrieval.py +++ b/examples/retrieval/two_tower_retrieval.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import List, Optional import click @@ -18,16 +20,13 @@ from torchrec.distributed.model_parallel import DistributedModelParallel from torchrec.distributed.planner.types import ParameterConstraints from torchrec.distributed.types import ShardingEnv, ShardingType -from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.sparse.jagged_tensor import KeyedJaggedTensor # OSS import try: - # pyre-ignore[21] - # @manual=//torchrec/github/examples/retrieval/data:dataloader - from data.dataloader import get_dataloader # pyre-ignore[21] # @manual=//torchrec/github/examples/retrieval:knn_index @@ -78,6 +77,7 @@ def infer( faiss_device_idx: int = 0, batch_size: int = 32, load_dir: Optional[str] = None, + world_size: int = 2, ) -> None: """ Loads the serialized model and FAISS index from `two_tower_train.py`. 
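The hunks that follow switch this inference path to fp16 embedding tables, a `weights_only=True` state-dict load, and a configurable `world_size`. For orientation, a condensed sketch of the resulting quantize-then-shard pattern; `retrieval_model`, `device`, `world_size`, and `rank` are assumed stand-ins for objects built earlier in `infer()`.

```python
# Sketch only: `retrieval_model`, `device`, `world_size`, and `rank` are
# assumed stand-ins, not values defined here.
import torch
from torchrec import inference as trec_infer
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.types import ShardingEnv

# Quantize the embedding tables to int8 while keeping fp16 lookup outputs.
quant_model = trec_infer.modules.quantize_embeddings(
    retrieval_model,
    dtype=torch.qint8,
    inplace=True,
    output_dtype=torch.float16,
)

# Shard the quantized model for local (single-host) inference.
dmp = DistributedModelParallel(
    module=quant_model,
    device=device,
    env=ShardingEnv.from_local(world_size=world_size, rank=rank),
    init_data_parallel=False,
)
```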
@@ -116,6 +116,7 @@ def infer( embedding_dim=embedding_dim, num_embeddings=num_embeddings, feature_names=[feature_name], + data_type=DataType.FP16, ) ebcs.append( EmbeddingBagCollection( @@ -135,7 +136,7 @@ def infer( # pyre-ignore[16] faiss.read_index(f"{load_dir}/faiss.index"), ) - two_tower_sd = torch.load(f"{load_dir}/model.pt") + two_tower_sd = torch.load(f"{load_dir}/model.pt", weights_only=True) retrieval_sd = convert_TwoTower_to_TwoTowerRetrieval( two_tower_sd, [f"t_{two_tower_column_names[0]}"], @@ -156,7 +157,9 @@ def infer( index.train(embeddings) index.add(embeddings) - retrieval_model = TwoTowerRetrieval(index, ebcs[0], ebcs[1], layer_sizes, k, device) + retrieval_model = TwoTowerRetrieval( + index, ebcs[0], ebcs[1], layer_sizes, k, device, dtype=torch.float16 + ) constraints = {} for feature_name in two_tower_column_names: @@ -166,13 +169,16 @@ def infer( ) quant_model = trec_infer.modules.quantize_embeddings( - retrieval_model, dtype=torch.qint8, inplace=True + retrieval_model, + dtype=torch.qint8, + inplace=True, + output_dtype=torch.float16, ) dmp = DistributedModelParallel( module=quant_model, device=device, - env=ShardingEnv.from_local(world_size=2, rank=model_device_idx), + env=ShardingEnv.from_local(world_size=world_size, rank=model_device_idx), init_data_parallel=False, ) if retrieval_sd is not None: diff --git a/examples/retrieval/two_tower_train.py b/examples/retrieval/two_tower_train.py index 141c1ca7d..5772a9069 100644 --- a/examples/retrieval/two_tower_train.py +++ b/examples/retrieval/two_tower_train.py @@ -5,25 +5,26 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + #!/usr/bin/env python3 import os -from typing import cast, List, Optional +from typing import List, Optional import click import faiss # @manual=//faiss/python:pyfaiss_gpu import faiss.contrib.torch_utils # @manual=//faiss/contrib:faiss_contrib_gpu import torch -from fbgemm_gpu.split_embedding_configs import EmbOptimType -from torch import distributed as dist, nn +from torch import distributed as dist +from torch.distributed.optim import ( + _apply_optimizer_in_backward as apply_optimizer_in_backward, +) from torchrec import inference as trec_infer from torchrec.datasets.movielens import DEFAULT_RATINGS_COLUMN_NAMES from torchrec.distributed import TrainPipelineSparseDist -from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.model_parallel import DistributedModelParallel -from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology -from torchrec.distributed.types import ModuleSharder from torchrec.inference.state_dict_transform import ( state_dict_gather, state_dict_to_device, @@ -31,6 +32,7 @@ from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.optim.keyed import KeyedOptimizerWrapper +from torchrec.optim.rowwise_adagrad import RowWiseAdagrad from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @@ -108,7 +110,6 @@ def train( layer_sizes = [128, 64] rank = int(os.environ["LOCAL_RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) if torch.cuda.is_available(): device: torch.device = torch.device(f"cuda:{rank}") backend = "nccl" @@ -138,36 +139,14 @@ def train( device=device, ) two_tower_train_task = TwoTowerTrainTask(two_tower_model) - - fused_params = { - "learning_rate": learning_rate, - "optimizer": EmbOptimType.ROWWISE_ADAGRAD, - } 
- sharders = cast( - List[ModuleSharder[nn.Module]], - [EmbeddingBagCollectionSharder(fused_params=fused_params)], - ) - - # TODO: move pg to the EmbeddingShardingPlanner (out of collective_plan) and make optional - # TODO: make Topology optional argument to EmbeddingShardingPlanner - # TODO: give collective_plan a default sharders - # TODO: once this is done, move defaults out of DMP and just get from ShardingPlan (eg _sharding_map should not exist - just use the plan) - plan = EmbeddingShardingPlanner( - topology=Topology( - world_size=world_size, - compute_device=device.type, - ), - ).collective_plan( - module=two_tower_model, - sharders=sharders, - # pyre-fixme[6]: For 3rd param expected `ProcessGroup` but got - # `Optional[ProcessGroup]`. - pg=dist.GroupMember.WORLD, + apply_optimizer_in_backward( + RowWiseAdagrad, + two_tower_train_task.two_tower.ebc.parameters(), + {"lr": learning_rate}, ) model = DistributedModelParallel( module=two_tower_train_task, device=device, - plan=plan, ) optimizer = KeyedOptimizerWrapper( diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md deleted file mode 100644 index cc2613a38..000000000 --- a/examples/torcharrow/README.md +++ /dev/null @@ -1,33 +0,0 @@ -# Description - -This shows a prototype of integrating a TorchRec based training loop utilizing TorchArrow's on-the-fly preprocessing. The main motivation is to show the utilization of TorchArrow's specialized domain UDFs. Here we use `bucketize`, `firstx`, as well as `sigrid_hash` to do some last-mile preprocessing over the criteo dataset in parquet format. More recommendation domain functions can be found at [torcharrow.functional Doc](https://pytorch.org/torcharrow/beta/functional.html#recommendation-operations). - -These three UDFs are extensively used in Meta's RecSys preprocessing stack. Notably, these UDFs can be used to easily adjust the proprocessing script to any model changes. For example, if we wish to change the size of our embedding tables, without sigrid_hash, we would need to rerun a bulk offline preproc to ensure that all indicies are within bounds. Bucketize lets us easily convert dense features into sparse features, with flexibility of what the bucket borders are. firstx lets us easily prune sparse ids (note, that this doesn't provide any functional features, but is in the preproc script as demonstration). - - -## Installations and Usage - -Download the criteo tsv files (see the README in the main DLRM example). Use the nvtabular script (in torchrec/datasets/scripts/nvt/) to convert the TSV files to parquet. - -To start, install torcharrow-nightly and torchdata: -``` -pip install --pre torcharrow -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html -pip install torchdata -``` -You can also build TorchArrow from source, following https://github.com/pytorch/torcharrow - -Usage - -``` -torchx run -s local_cwd dist.ddp -j 1x4 --script examples/torcharrow/run.py -- --parquet_directory /home/criteo_parquet -``` - -The preprocessing logic is in ```dataloader.py``` - -## Extentions/Future work - -* We will eventually integrate with the up and coming DataLoader2, which will allow us to utilize a prebuilt solution to collate our dataframe batches to dense tensors, or TorchRec's KeyedJaggedTensors (rather than doing this by hand). -* Building an easier solution/more performant to convert parquet -> IterableDataPipe[torcharrow.DataFrame] (aka ArrowDataPipe). Also currently batch sizes are not available. 
-* Some functional abilities are not yet available (such as make_named_row, etc). -* Support collation/conversion for ArrowDataPipe -* More RecSys UDFs to come! Please let us know if you have any suggestions. diff --git a/examples/torcharrow/dataloader.py b/examples/torcharrow/dataloader.py deleted file mode 100644 index 3cca4666d..000000000 --- a/examples/torcharrow/dataloader.py +++ /dev/null @@ -1,116 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torcharrow as ta -import torcharrow.dtypes as dt -import torcharrow.pytorch as tap -from torch.utils.data import DataLoader -from torcharrow import functional -from torchdata.datapipes.iter import FileLister -from torchrec.datasets.criteo import ( - DEFAULT_CAT_NAMES, - DEFAULT_INT_NAMES, - DEFAULT_LABEL_NAME, -) - -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor - - -class _JaggedTensorConversion(tap.TensorConversion): - # pyre-fixme[14]: `to_tensor` overrides method defined in `TensorConversion` - # inconsistently. - def to_tensor(self, df: ta.DataFrame): - - kjt_keys = df.columns - kjt_values = [] - kjt_lengths = [] - for row in df: - for idx, _column in enumerate(df.columns): - value = row[idx] - kjt_values.extend(value) - kjt_lengths.append(len(value)) - kjt = KeyedJaggedTensor.from_lengths_sync( - keys=kjt_keys, - values=torch.tensor(kjt_values), - lengths=torch.tensor(kjt_lengths), - ) - return kjt - - -class _Scalar(tap.TensorConversion): - def to_tensor(self, df: ta.DataFrame): - labels = torch.tensor(df) - return labels - - -def get_dataloader( - parquet_directory, world_size, rank, num_embeddings=4096, salt=0, batch_size=16 -): - source_dp = FileLister(parquet_directory, masks="*.parquet") - # TODO support batch_size for load_parquet_as_df. 
- # TODO use OSSArrowDataPipe once it is ready - parquet_df_dp = source_dp.load_parquet_as_df() - - def preproc(df, max_idx=num_embeddings, salt=salt): - - for feature_name in DEFAULT_INT_NAMES: - df[feature_name] = df[feature_name].fill_null(0) - for feature_name in DEFAULT_CAT_NAMES: - df[feature_name] = df[feature_name].fill_null(0) - df[feature_name] = df[feature_name].cast(dt.int64) - - # construct a sprase index from a dense one - df["bucketize_int_0"] = functional.bucketize(df["int_0"], [0.5, 1.0, 1.5]).cast( - dt.int64 - ) - - # flatten several columns into one - df["dense_features"] = ta.dataframe( - {int_name: df[int_name] for int_name in DEFAULT_INT_NAMES} - ) - - df["dense_features"] = (df["dense_features"] + 3).log() - - for cat_name in DEFAULT_CAT_NAMES + ["bucketize_int_0"]: - # hash our embedding index into our embedding tables - df[cat_name] = functional.sigrid_hash(df[cat_name], salt, max_idx) - df[cat_name] = functional.array_constructor(df[cat_name]) - df[cat_name] = functional.firstx(df[cat_name], 1) - - df["sparse_features"] = ta.dataframe( - { - cat_name: df[cat_name] - for cat_name in DEFAULT_CAT_NAMES + ["bucketize_int_0"] - } - ) - - df = df[["dense_features", "sparse_features", DEFAULT_LABEL_NAME]] - - return df - - parquet_df_dp = parquet_df_dp.map(preproc).sharding_filter() - parquet_df_dp.apply_sharding(world_size, rank) - - def criteo_collate(df): - dense_features, kjt, labels = df.to_tensor( - { - "dense_features": tap.rec.Dense(batch_first=True), - "sparse_features": _JaggedTensorConversion(), - "label": _Scalar(), - } - ) - - return dense_features, kjt, labels - - return DataLoader( - parquet_df_dp, - batch_size=None, - collate_fn=criteo_collate, - drop_last=False, - pin_memory=True, - ) diff --git a/examples/torcharrow/requirements.txt b/examples/torcharrow/requirements.txt deleted file mode 100644 index 3e3c4c950..000000000 --- a/examples/torcharrow/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -click -torcharrow -torchdata diff --git a/examples/torcharrow/run.py b/examples/torcharrow/run.py deleted file mode 100644 index 1a3dca47d..000000000 --- a/examples/torcharrow/run.py +++ /dev/null @@ -1,126 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import os - -import click -import dataloader as torcharrow_dataloader -import torch -import torch.distributed as dist -from fbgemm_gpu.split_embedding_configs import EmbOptimType -from torch.distributed.elastic.multiprocessing.errors import record -from torchrec import EmbeddingBagCollection -from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, INT_FEATURE_COUNT -from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder -from torchrec.distributed.model_parallel import DistributedModelParallel -from torchrec.models.dlrm import DLRM -from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.optim.keyed import KeyedOptimizerWrapper - - -@record -@click.command() -@click.option("--batch_size", default=256) -@click.option("--num_embeddings", default=2048) -@click.option("--sigrid_hash_salt", default=0) -@click.option("--parquet_directory", default="/data/criteo_preproc") -def main( - batch_size, - num_embeddings, - sigrid_hash_salt, - parquet_directory, -) -> None: - rank = int(os.environ["LOCAL_RANK"]) - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - backend = "nccl" - torch.cuda.set_device(device) - else: - device = torch.device("cpu") - backend = "gloo" - print( - "\033[92m" - + f"WARNING: Running in CPU mode. cuda availablility {torch.cuda.is_available()}." - ) - - dist.init_process_group(backend=backend) - - world_size = dist.get_world_size() - - dataloader = torcharrow_dataloader.get_dataloader( - parquet_directory, - world_size, - rank, - batch_size=batch_size, - num_embeddings=num_embeddings, - salt=sigrid_hash_salt, - ) - it = iter(dataloader) - - model = DLRM( - embedding_bag_collection=EmbeddingBagCollection( - tables=[ - EmbeddingBagConfig( - name=f"table_{cat_name}", - embedding_dim=64, - num_embeddings=num_embeddings, - feature_names=[cat_name], - ) - for cat_name in DEFAULT_CAT_NAMES + ["bucketize_int_0"] - ], - device=torch.device("meta"), - ), - dense_in_features=INT_FEATURE_COUNT, - dense_arch_layer_sizes=[64], - over_arch_layer_sizes=[32, 1], - dense_device=device, - ) - - fused_params = { - "learning_rate": 0.02, - "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD, - } - - sharded_model = DistributedModelParallel( - module=model, - device=device, - sharders=[ - EmbeddingBagCollectionSharder(fused_params=fused_params), - ], - ) - - optimizer = KeyedOptimizerWrapper( - dict(model.named_parameters()), - lambda params: torch.optim.SGD(params, lr=0.01), - ) - - loss_fn = torch.nn.BCEWithLogitsLoss() - - print_example = dist.get_rank() == 0 - for (dense_features, kjt, labels) in it: - if print_example: - print("Example dense_features", dense_features) - print("Example KJT input", kjt) - print_example = False - - dense_features = dense_features.to(device) - kjt = kjt.to(device) - labels = labels.to(device) - - optimizer.zero_grad() - - preds = sharded_model(dense_features, kjt) - loss = loss_fn(preds.squeeze(), labels.squeeze()) - loss.sum().backward() - - optimizer.step() - - print("\033[92m" + "DLRM run with torcharrow last-mile preprocessing finished!") - - -if __name__ == "__main__": - main() diff --git a/examples/transfer_learning/train_from_pretrained_embedding.py b/examples/transfer_learning/train_from_pretrained_embedding.py index 6d659f67d..ac7db53b4 100644 --- a/examples/transfer_learning/train_from_pretrained_embedding.py +++ b/examples/transfer_learning/train_from_pretrained_embedding.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root 
directory of this source tree. +# pyre-strict + import copyreg import io import os @@ -222,7 +224,7 @@ def main() -> None: dist.barrier() -if __name__ == "__main__": +def invoke_main() -> None: lc = pet.LaunchConfig( min_nodes=1, max_nodes=1, @@ -234,3 +236,7 @@ def main() -> None: monitor_interval=1, ) pet.elastic_launch(lc, entrypoint=main)() + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/install-requirements.txt b/install-requirements.txt new file mode 100644 index 000000000..ed3c6aced --- /dev/null +++ b/install-requirements.txt @@ -0,0 +1,6 @@ +fbgemm-gpu +tensordict +torchmetrics==1.0.3 +tqdm +pyre-extensions +iopath diff --git a/requirements.txt b/requirements.txt index daa72b622..6b17aeac6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,43 +1,19 @@ -arrow -attrs -certifi -charset-normalizer +black cmake -Cython -distro -docker -docstring-parser -fbgemm-gpu-nightly -filelock -fsspec -hypothesis -idna +fbgemm-gpu +hypothesis==6.70.1 iopath -Jinja2 -MarkupSafe -mypy-extensions -ninja numpy -packaging pandas -portalocker -pyarrow -pyDeprecate -pyparsing -pyre-extensions==0.0.27 -python-dateutil -pytz -PyYAML -requests +pyre-extensions scikit-build -six -sortedcontainers -tabulate -torchmetrics +tensordict +torchmetrics==1.0.3 torchx tqdm -typing-inspect -typing_extensions -urllib3 usort -websocket-client +parameterized + +# for tests +# https://github.com/pytorch/pytorch/blob/b96b1e8cff029bb0a73283e6e7f6cc240313f1dc/requirements.txt#L3 +expecttest diff --git a/rfc/RFC-0002-Flexible-Collision-Free-Embedding-Table.md b/rfc/RFC-0002-Flexible-Collision-Free-Embedding-Table.md new file mode 100644 index 000000000..09e9069fb --- /dev/null +++ b/rfc/RFC-0002-Flexible-Collision-Free-Embedding-Table.md @@ -0,0 +1,146 @@ +# RFC: Flexible Collision Free Embedding Table + +| Status | Published | +| :---- | :---- | +| Author(s) | Emma Lin, Joe Wang, Kaustubh Vartak, Dennis van der Staay, Huanyu He| +| Updated | 04-13-2025 | + +## Motivation + +PyTorch and FBGemm utilize fixed-size continuous embedding memory to handle sparse features, employing a uniform hash function to map raw IDs to a limited index space. This approach has resulted in a high rate of hash collisions and inefficient storage. + +The Zero Collision Hash (ZCH) technique allows models to train individual IDs uniquely, leading to notable enhancements in model freshness and user engagement. When properly configured, we've seen improvements in freshness late stage ranking and early stage ranking models. However, the remapping-based ZCH solutions have presented several scalability challenges. + +## Objective + +This RFC proposes a new embedding table format that natively supports collision-free features, enhancing the scalability and usability of embedding tables. + +The approach involves utilizing an extremely large hash size (e.g., 2^63) to map raw IDs to a vast ID space, significantly reducing the likelihood of collisions during hashing. + +Instead of mapping IDs to a fixed table size (e.g., 1 billion rows), this format reserves memory spaces for each ID, effectively eliminating collisions and achieving native collision-free results. + +Notably, this design eliminates the need for a remapper to find available slots for colliding IDs, streamlining the process and improving overall efficiency, embedding scalability and usability. 
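The Design Proposal below describes two bucket-assignment schemes (interleave-based and chunk-based) and a sequential bucket-to-trainer layout. A minimal arithmetic sketch of both, using the RFC's example sizes (2^63 ID space, 1000 buckets, 100 trainers) as assumed values:

```python
# Sketch of the bucket math described in the Design Proposal below.
# Sizes are the RFC's illustrative examples, not prescribed defaults.
TABLE_SIZE = 2**63      # extremely large logical ID space
TOTAL_BUCKETS = 1000
WORLD_SIZE = 100        # number of trainers

def interleave_bucket(hash_id: int) -> int:
    # bucket_id = hash_id % total_bucket_number
    return hash_id % TOTAL_BUCKETS

def chunk_bucket(hash_id: int) -> int:
    # bucket_size = table_size / total_bucket_number; bucket_id = id / bucket_size
    bucket_size = TABLE_SIZE // TOTAL_BUCKETS
    return hash_id // bucket_size

def owning_trainer(bucket_id: int) -> int:
    # Buckets are laid out sequentially: trainer 0 owns b0-b9, trainer 1 owns
    # b10-b19, and so on; resharding moves whole buckets, never single rows.
    buckets_per_trainer = TOTAL_BUCKETS // WORLD_SIZE
    return bucket_id // buckets_per_trainer

raw_id = 123_456_789_123
print(owning_trainer(interleave_bucket(raw_id)))
print(owning_trainer(chunk_bucket(raw_id)))
```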
+ +## Design Proposal + +### Bucket Aware Sharding and Resharding Algorithm + +To address the dynamic nature of sparse IDs and their distribution, we propose introducing a bucket concept. We will provide a large default bucket number configuration and an extremely large table size. The mapping from ID to bucket ID can be done in two ways: + +* Interleave-based: bucket\_id \= hash\_id % total\_bucket\_number + This approach is similar to the sharding solution used in MPZCH, allowing for seamless migration without requiring ID resharding. +* Chunk-based: + * bucket\_size \= table\_size / total\_bucket\_number, + * bucket\_id \= id / bucket\_size + + +Both options will be configurable. + +After sharding IDs into buckets, we will distribute the buckets sequentially across trainers. For example, with 1000 buckets and 100 trainers, each trainer would own 10 consecutive buckets. + +T1: b0-b9 + +T2: b10-b19 + +... + +When resharding is necessary, we will move buckets around instead of individual rows. For instance, reducing the number of trainers from 100 to 50 would result in each trainer owning 20 consecutive buckets. + +T1: b0-b19 + +T2: b20-b39 + +... + +The row count within each bucket can vary from 0 to the maximum bucket size, depending on the ID distribution. However, using a larger bucket number should lead to a more even distribution of IDs. + +#### Benefit + +The bucket number remains unchanged when scaling the model or adjusting the number of trainers, making it easier to move buckets around without introducing additional overhead to reshard every ID's new location. + +Resharding every ID can be an expensive operation, especially for large tables (e.g., over 1 billion rows). + +### Bucketized Torchrec Sharding and Input Distribution + +Based on the proposed sharding definition, the TorchRec sharder needs to be aware of the bucket configuration from the embedding table. + +Input distribution needs to take into account the bucket configuration, and then distribute input to the corresponding trainer. + +Here is the code [reference](https://github.com/pytorch/torchrec/blob/f36d26db4432bd7335f6df9e7e75d8643b7ffb04/torchrec/distributed/sharding/rw_sequence_sharding.py#L129C16-L129C36). + +### FBGemm Operators Optimization for Collision Free EmbeddingTable + +FBGEMM\_GPU (FBGEMM GPU Kernels Library) is highly optimized for fixed sized tables, with continuous memory space, including in HBM, UVM or CPU memory. +However, when we apply collision free idea, there are several assumptions of FBGEMM are broken: + +* Table size is not fixed. It could grow over training iterations or shrink after eviction. +* Embedding lookup input is not embedding offset anymore, so we need to maintain an explicit mapping from input to the embedding value. +* Table size could exceed memory limitation, but actual trained id size is finite, so we cannot preserve memory based on table configuration. + +We’re looking for an optimized K/V FBGemm version to support flexible memory management. + +#### Training Operator (from [Joe Wang](mailto:wangj@meta.com)) + +* Optimized CPU memory management with K/V format + * Reduce memory fragmentation + * efficient memory utilization + * Fast lookup performance + * Flexible eviction policy +* Collision free LXU cache to avoid extra memory copy from CPU to UVM and UVM memory read during embedding lookup. + * The current LXU cache used by FBGemm could cause id collisions. 
When collision happens, prefetch won’t be able to load embedding value to HBM, which will fallback to UVM memory read during embedding lookup. This can impact training QPS in two ways: + * Introduce one extra CPU memory copy, since data needs to be copied from CPU to UVM, since the CPU embedding data in k/v format might not be accessible from the GPU card. + * Introduced H2D data copy in embedding lookup. + +[Here](https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py) is the code reference of a k/v format SSD offloading operator, and provides a backend interface to hook up other k/v store implementations. +We propose to implement a new k/v store backend, to decouple SSD and rocksDB dependency, but the SSD backend operator can be used for extra large embeddings which do not fit into host memory. + +#### Inference Operator + +On top of training operators functionality, the inference operator needs to support dequantization from nbit int value after embedding is queried out from the embedding store. We’d like to have a fast inference operator with additional requirements: + +* Optimized CPU memory management with k/v format +* Collision free LXU cache +* Keep the fast nbit Int data format support, with pooling, dequantization features. +* Support decoupled large tensor loading and reset, to allow model state in-place update. + [Here](https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py) is the current inference operator which only supports offset based access for now. + +### Enhancing TorchRec's Prefetch Pipeline for Synchronized Training + +TorchRec offers multiple training [pipelines](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/train_pipeline/train_pipelines.py) that help overlap communication and computation, reducing embedding lookup latency and improving training QPS. Specifically, PrefetchTrainPipelineSparseDist supports synchronized training, while TrainPipelineSemiSync supports asynchronous training. + +Our goal is to further enhance the prefetch pipeline to enable prefetching multiple training data batches ahead of time while maintaining synchronized training and avoiding step misalignment. + +**Design Principles:** + +* Zero Collision Cache: Enable zero collision cache in GPU to cache embeddings for multiple batches without collision-induced cache misses. +* Forward Pass: Perform embedding lookups only in HBM during forward passes to improve performance. +* Backward Pass: Update embeddings in HBM synchronously during backward passes to ensure all embedding lookup results are up-to-date. +* Asynchronous UVM Embedding Update: Update UVM embeddings asynchronously after embedding updates in HBM. + +**End Goal:** + +Achieve on-par performance with GPU HBM-based training while scaling up sparse embedding tables in CPU memory. + +### Warm Start and Transfer Learning with Collision-Free Embedding Tables + +Warm start, or transfer learning, is a widely used technique in industry to facilitate model iteration while maintaining on-par topline metrics. However, the introduction of collision-free embedding tables poses challenges to this process. + +With the ability to increase table size and feature hash size to \~2^60, collision-free embedding tables offer improved efficiency. 
However, since id hash size is changed and sharding solution is different, when resuming training from a non-zero collision table to a zero-collision table, the redistribution of IDs across trainers becomes computationally expensive. + +#### Solution: Backfilling Embedding Values + +To address this challenge, we propose the following solution: + +* Create Two Embedding Tables: One table is copied from the old model, and the other is the new table. + +* Freeze Old Embedding Table: The old embedding table is set to read-only mode in the new model. + +* Training Forward Loop: During the forward pass, if an embedding value is not found in the new table, the underlying embedding lookup operator searches the old embedding table for a pre-trained value. + + * This requires an additional all-to-all call using TorchRec to retrieve the old embedding value. + + * We need to leverage the prefetch process to hide the extra latency. + +* Stop Backfilling Process: Once the new table is sufficiently populated, the backfilling process can be stopped. + +This approach enables efficient warm start and transfer learning with collision-free embedding tables, reducing the computational overhead associated with ID redistribution. diff --git a/setup.py b/setup.py index 63d3aa30d..34987b905 100644 --- a/setup.py +++ b/setup.py @@ -7,97 +7,114 @@ import argparse import os -import random -import re +import subprocess import sys -from datetime import date +from pathlib import Path from typing import List from setuptools import find_packages, setup +ROOT_DIR = Path(__file__).parent.resolve() -def get_version(): - # get version string from version.py - # TODO: ideally the version.py should be generated when setup is run - version_file = os.path.join(os.path.dirname(__file__), "version.py") - version_regex = r"__version__ = ['\"]([^'\"]*)['\"]" - with open(version_file, "r") as f: - version = re.search(version_regex, f.read(), re.M).group(1) - return version +def _get_version(): + try: + cmd = ["git", "rev-parse", "HEAD"] + sha = subprocess.check_output(cmd, cwd=str(ROOT_DIR)).decode("ascii").strip() + except Exception: + sha = None -def get_nightly_version(): - today = date.today() - return f"{today.year}.{today.month}.{today.day}" + if "BUILD_VERSION" in os.environ: + version = os.environ["BUILD_VERSION"] + else: + with open(os.path.join(ROOT_DIR, "version.txt"), "r") as f: + version = f.readline().strip() + if sha is not None and "OFFICIAL_RELEASE" not in os.environ: + version += "+" + sha[:7] + + if sha is None: + sha = "Unknown" + return version, sha + + +def _export_version(version, sha): + version_path = ROOT_DIR / "torchrec" / "version.py" + with open(version_path, "w") as fileobj: + fileobj.write("__version__ = '{}'\n".format(version)) + fileobj.write("git_version = {}\n".format(repr(sha))) def parse_args(argv: List[str]) -> argparse.Namespace: parser = argparse.ArgumentParser(description="torchrec setup") - parser.add_argument( - "--package_name", - type=str, - default="torchrec", - help="the name of this output wheel", - ) return parser.parse_known_args(argv) def main(argv: List[str]) -> None: args, unknown = parse_args(argv) - # Set up package name and version - name = args.package_name - is_nightly = "nightly" in name - is_test = "test" in name - with open( os.path.join(os.path.dirname(__file__), "README.MD"), encoding="utf8" ) as f: readme = f.read() with open( - os.path.join(os.path.dirname(__file__), "requirements.txt"), encoding="utf8" + os.path.join(os.path.dirname(__file__), 
"install-requirements.txt"), + encoding="utf8", ) as f: reqs = f.read() install_requires = reqs.strip().split("\n") - version = get_nightly_version() if is_nightly else get_version() - - if not is_nightly: - if "fbgemm-gpu-nightly" in install_requires: - install_requires.remove("fbgemm-gpu-nightly") - install_requires.append("fbgemm-gpu") - - if is_test: - version = (f"0.0.{random.randint(0, 1000)}",) - print(f"-- {name} building version: {version}") - - packages = find_packages(exclude=("*tests",)) + version, sha = _get_version() + _export_version(version, sha) + + print(f"-- torchrec building version: {version}") + + packages = find_packages( + exclude=( + "*tests", + "*test", + "examples", + "*examples.*", + "*benchmarks", + "*build", + "*rfc", + ) + ) sys.argv = [sys.argv[0]] + unknown setup( # Metadata - name=name, + name="torchrec", version=version, author="TorchRec Team", author_email="packages@pytorch.org", - description="Pytorch domain library for recommendation systems", + maintainer="PaulZhang12", + maintainer_email="paulzhan@meta.com", + description="TorchRec: Pytorch library for recommendation systems", long_description=readme, long_description_content_type="text/markdown", url="https://github.com/pytorch/torchrec", license="BSD-3", - keywords=["pytorch", "recommendation systems", "sharding"], - python_requires=">=3.7", + keywords=[ + "pytorch", + "recommendation systems", + "sharding", + "distributed training", + ], + python_requires=">=3.9", install_requires=install_requires, packages=packages, zip_safe=False, # PyPI package information. classifiers=[ - "Development Status :: 4 - Beta", + "Development Status :: 5 - Stable", "Intended Audience :: Developers", "Intended Audience :: Science/Research", "License :: OSI Approved :: BSD License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], ) diff --git a/test/cpp/dynamic_embedding/CMakeLists.txt b/test/cpp/dynamic_embedding/CMakeLists.txt index 056bfd985..53d2c7c02 100644 --- a/test/cpp/dynamic_embedding/CMakeLists.txt +++ b/test/cpp/dynamic_embedding/CMakeLists.txt @@ -14,3 +14,8 @@ add_tde_test(bits_op_test bits_op_test.cpp) add_tde_test(naive_id_transformer_test naive_id_transformer_test.cpp) add_tde_test(random_bits_generator_test random_bits_generator_test.cpp) add_tde_test(mixed_lfu_lru_strategy_test mixed_lfu_lru_strategy_test.cpp) +add_tde_test(notification_test notification_test.cpp) + +if (BUILD_REDIS_IO) + add_subdirectory(redis) +endif() diff --git a/test/cpp/dynamic_embedding/mixed_lfu_lru_strategy_test.cpp b/test/cpp/dynamic_embedding/mixed_lfu_lru_strategy_test.cpp index f8526e575..456d07e1e 100644 --- a/test/cpp/dynamic_embedding/mixed_lfu_lru_strategy_test.cpp +++ b/test/cpp/dynamic_embedding/mixed_lfu_lru_strategy_test.cpp @@ -12,8 +12,8 @@ namespace torchrec { TEST(TDE, order) { MixedLFULRUStrategy::Record a; - a.time_ = 1; - a.freq_power_ = 31; + a.time = 1; + a.freq_power = 31; uint32_t i32 = a.ToUint32(); ASSERT_EQ(0xF8000001, i32); } @@ -23,42 +23,41 @@ TEST(TDE, MixedLFULRUStrategy_Evict) { { records.emplace_back(); records.back().first = 1; - records.back().second.time_ = 100; - records.back().second.freq_power_ = 2; + records.back().second.time = 100; + records.back().second.freq_power = 2; } { records.emplace_back(); 
records.back().first = 2; - records.back().second.time_ = 10; - records.back().second.freq_power_ = 2; + records.back().second.time = 10; + records.back().second.freq_power = 2; } { records.emplace_back(); records.back().first = 3; - records.back().second.time_ = 100; - records.back().second.freq_power_ = 1; + records.back().second.time = 100; + records.back().second.freq_power = 1; } { records.emplace_back(); records.back().first = 4; - records.back().second.time_ = 150; - records.back().second.freq_power_ = 2; + records.back().second.time = 150; + records.back().second.freq_power = 2; } size_t offset_{0}; - auto ids = MixedLFULRUStrategy::evict( - [&offset_, - &records]() -> std::optional { + MixedLFULRUStrategy strategy; + auto ids = strategy.evict( + [&offset_, &records]() -> std::optional { if (offset_ == records.size()) { return std::nullopt; } auto record = records[offset_++]; - MixedLFULRUStrategy::lxu_record_t ext_type = - *reinterpret_cast( - &record.second); - return MixedLFULRUStrategy::transformer_record_t{ - .global_id_ = record.first, - .cache_id_ = 0, - .lxu_record_ = ext_type, + lxu_record_t ext_type = + *reinterpret_cast(&record.second); + return record_t{ + .global_id = record.first, + .cache_id = 0, + .lxu_record = ext_type, }; }, 3); @@ -73,12 +72,12 @@ TEST(TDE, MixedLFULRUStrategy_Transform) { constexpr static size_t n_iter = 1000000; MixedLFULRUStrategy strategy; strategy.update_time(10); - MixedLFULRUStrategy::lxu_record_t val; + lxu_record_t val; { val = strategy.update(0, 0, std::nullopt); auto record = reinterpret_cast(&val); - ASSERT_EQ(record->freq_power_, 5); - ASSERT_EQ(record->time_, 10); + ASSERT_EQ(record->freq_power, 5); + ASSERT_EQ(record->time, 10); } uint32_t freq_power_5_cnt = 0; @@ -87,13 +86,13 @@ TEST(TDE, MixedLFULRUStrategy_Transform) { for (size_t i = 0; i < n_iter; ++i) { auto tmp = strategy.update(0, 0, val); auto record = reinterpret_cast(&tmp); - ASSERT_EQ(record->time_, 10); - if (record->freq_power_ == 5) { + ASSERT_EQ(record->time, 10); + if (record->freq_power == 5) { ++freq_power_5_cnt; - } else if (record->freq_power_ == 6) { + } else if (record->freq_power == 6) { ++freq_power_6_cnt; } else { - ASSERT_TRUE(record->freq_power_ == 5 || record->freq_power_ == 6); + ASSERT_TRUE(record->freq_power == 5 || record->freq_power == 6); } } diff --git a/test/cpp/dynamic_embedding/naive_id_transformer_test.cpp b/test/cpp/dynamic_embedding/naive_id_transformer_test.cpp index a037e7b89..f8ba6e82a 100644 --- a/test/cpp/dynamic_embedding/naive_id_transformer_test.cpp +++ b/test/cpp/dynamic_embedding/naive_id_transformer_test.cpp @@ -12,8 +12,7 @@ namespace torchrec { TEST(tde, NaiveThreadedIDTransformer_NoFilter) { - using Tag = int32_t; - NaiveIDTransformer> transformer(16); + NaiveIDTransformer> transformer(16); const int64_t global_ids[5] = {100, 101, 100, 102, 101}; int64_t cache_ids[5]; int64_t expected_cache_ids[5] = {0, 1, 0, 2, 1}; @@ -24,8 +23,7 @@ TEST(tde, NaiveThreadedIDTransformer_NoFilter) { } TEST(tde, NaiveThreadedIDTransformer_Full) { - using Tag = int32_t; - NaiveIDTransformer> transformer(4); + NaiveIDTransformer> transformer(4); const int64_t global_ids[5] = {100, 101, 102, 103, 104}; int64_t cache_ids[5]; int64_t expected_cache_ids[5] = {0, 1, 2, 3, -1}; @@ -37,8 +35,7 @@ TEST(tde, NaiveThreadedIDTransformer_Full) { } TEST(tde, NaiveThreadedIDTransformer_Evict) { - using Tag = int32_t; - NaiveIDTransformer> transformer(4); + NaiveIDTransformer> transformer(4); const int64_t global_ids[5] = {100, 101, 102, 103, 104}; int64_t 
cache_ids[5]; @@ -60,8 +57,7 @@ TEST(tde, NaiveThreadedIDTransformer_Evict) { } TEST(tde, NaiveThreadedIDTransformer_Iterator) { - using Tag = int32_t; - NaiveIDTransformer> transformer(16); + NaiveIDTransformer> transformer(16); const int64_t global_ids[5] = {100, 101, 100, 102, 101}; int64_t cache_ids[5]; int64_t expected_cache_ids[5] = {3, 4, 3, 5, 4}; diff --git a/test/cpp/dynamic_embedding/notification_test.cpp b/test/cpp/dynamic_embedding/notification_test.cpp new file mode 100644 index 000000000..f916678d8 --- /dev/null +++ b/test/cpp/dynamic_embedding/notification_test.cpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace torchrec { +TEST(TDE, notification) { + Notification notification; + std::thread th([&] { notification.done(); }); + notification.wait(); + th.join(); +} +} // namespace torchrec diff --git a/test/cpp/dynamic_embedding/redis/CMakeLists.txt b/test/cpp/dynamic_embedding/redis/CMakeLists.txt new file mode 100644 index 000000000..b1c2f0c7b --- /dev/null +++ b/test/cpp/dynamic_embedding/redis/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +function(add_redis_test NAME) + add_executable(${NAME} ${ARGN}) + target_link_libraries(${NAME} redis_io tde_cpp_objs gtest gtest_main) + add_test(NAME ${NAME} COMMAND ${NAME}) +endfunction() + +# TODO: Need start a empty redis-server +# on 127.0.0.1:6379 before run *redis*_test. +add_redis_test(redis_io_test redis_io_test.cpp) +add_redis_test(url_test url_test.cpp) diff --git a/test/cpp/dynamic_embedding/redis/redis_io_test.cpp b/test/cpp/dynamic_embedding/redis/redis_io_test.cpp new file mode 100644 index 000000000..f5c77c257 --- /dev/null +++ b/test/cpp/dynamic_embedding/redis/redis_io_test.cpp @@ -0,0 +1,116 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +namespace torchrec::redis { + +TEST(TDE, redis_Option) { + auto opt = parse_option( + "192.168.3.1:3948/?db=3&&num_threads=2&&timeout=3s&&chunk_size=3000"); + ASSERT_EQ(opt.host, "192.168.3.1"); + ASSERT_EQ(opt.port, 3948); + ASSERT_EQ(opt.db, 3); + ASSERT_EQ(opt.num_io_threads, 2); + ASSERT_EQ(opt.chunk_size, 3000); + ASSERT_EQ(opt.timeout_ms, 3000); + ASSERT_TRUE(opt.prefix.empty()); +} + +TEST(TDE, redis_Option_ParseError) { + ASSERT_ANY_THROW( + parse_option("192.168.3.1:3948/?db=3&&no_opt=3000&&num_threads=2")); + ASSERT_ANY_THROW(parse_option("192.168.3.1:3948/?timeout=3d")); +} + +struct FetchContext { + Notification* notification_; + std::function on_data_; +}; + +TEST(TDE, redis_push_fetch) { + auto opt = parse_option("127.0.0.1:6379"); + Redis redis(opt); + + constexpr static int64_t global_ids[] = {1, 3, 4}; + constexpr static uint32_t os_ids[] = {0}; + constexpr static float params[] = {1, 2, 3, 4, 5, 9, 8, 1}; + constexpr static uint64_t offsets[] = { + 0 * sizeof(float), + 2 * sizeof(float), + 4 * sizeof(float), + 6 * sizeof(float), + 8 * sizeof(float)}; + + Notification notification; + + IOPushParameter push{ + .table_name = "table", + .num_global_ids = sizeof(global_ids) / sizeof(global_ids[0]), + .global_ids = global_ids, + .num_optimizer_states = sizeof(os_ids) / sizeof(os_ids[0]), + .optimizer_state_ids = os_ids, + .num_offsets = sizeof(offsets) / sizeof(offsets[0]), + .offsets = offsets, + .data = params, + .on_complete_context = ¬ification, + .on_push_complete = + +[](void* ctx) { + auto* notification = reinterpret_cast(ctx); + notification->done(); + }, + }; + redis.push(push); + + notification.wait(); + + notification.clear(); + + FetchContext ctx{ + .notification_ = ¬ification, + .on_data_ = + [&](uint32_t offset, uint32_t os_id, void* data, uint32_t len) { + ASSERT_EQ(os_id, 0); + uint32_t param_len = 2; + ASSERT_EQ(len, sizeof(float) * param_len); + auto actual = + std::span(reinterpret_cast(data), 2); + + auto expect = std::span( + reinterpret_cast(¶ms[offset * param_len]), 2); + + ASSERT_EQ(expect[0], actual[0]); + ASSERT_EQ(expect[1], actual[1]); + }}; + + IOFetchParameter fetch{ + .table_name = "table", + .num_global_ids = sizeof(global_ids) / sizeof(global_ids[0]), + .global_ids = global_ids, + .num_optimizer_states = sizeof(os_ids) / sizeof(os_ids[0]), + .on_complete_context = &ctx, + .on_global_id_fetched = + +[](void* ctx, + uint32_t offset, + uint32_t os_id, + void* data, + uint32_t len) { + auto c = reinterpret_cast(ctx); + c->on_data_(offset, os_id, data, len); + }, + .on_all_fetched = + +[](void* ctx) { + auto c = reinterpret_cast(ctx); + c->notification_->done(); + }}; + redis.fetch(fetch); + notification.wait(); +} +} // namespace torchrec::redis diff --git a/test/cpp/dynamic_embedding/redis/url_test.cpp b/test/cpp/dynamic_embedding/redis/url_test.cpp new file mode 100644 index 000000000..6f37744c7 --- /dev/null +++ b/test/cpp/dynamic_embedding/redis/url_test.cpp @@ -0,0 +1,21 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +namespace torchrec::url_parser::rules { + +TEST(TDE, url) { + auto url = parse_url("http://github.com/www.qq.com/?a=b&&c=d"); + ASSERT_EQ(url.host, "www.qq.com"); + ASSERT_TRUE(url.param.has_value()); + ASSERT_EQ("a=b&&c=d", url.param.value()); +} + +} // namespace torchrec::url_parser::rules diff --git a/test_installation.py b/test_installation.py index cbef42766..48328a389 100644 --- a/test_installation.py +++ b/test_installation.py @@ -18,6 +18,7 @@ from torchrec.models.dlrm import DLRM from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.optim.keyed import KeyedOptimizerWrapper +from torchrec.optim.optimizers import in_backward_optimizer_filter if sys.platform not in ["linux", "linux2"]: raise EnvironmentError( @@ -118,7 +119,7 @@ def main(argv: List[str]) -> None: device=device, ) optimizer = KeyedOptimizerWrapper( - dict(model.named_parameters()), + dict(in_backward_optimizer_filter(model.named_parameters())), lambda params: torch.optim.SGD(params, lr=0.01), ) diff --git a/tools/lint/black_linter.py b/tools/lint/black_linter.py index cfdc3d4e8..7c9a75f9c 100644 --- a/tools/lint/black_linter.py +++ b/tools/lint/black_linter.py @@ -176,11 +176,11 @@ def main() -> None: logging.basicConfig( format="<%(threadName)s:%(levelname)s> %(message)s", - level=logging.NOTSET - if args.verbose - else logging.DEBUG - if len(args.filenames) < 1000 - else logging.INFO, + level=( + logging.NOTSET + if args.verbose + else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO + ), stream=sys.stderr, ) diff --git a/torchrec/__init__.py b/torchrec/__init__.py index ddead68c9..29258f6f0 100644 --- a/torchrec/__init__.py +++ b/torchrec/__init__.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import torchrec.distributed # noqa import torchrec.quant # noqa from torchrec.fx import tracer # noqa @@ -25,3 +27,10 @@ KeyedTensor, ) from torchrec.streamable import Multistreamable, Pipelineable # noqa + +try: + # pyre-ignore[21] + # @manual=//torchrec/fb:version + from .version import __version__, github_version # noqa +except ImportError: + pass diff --git a/torchrec/csrc/dynamic_embedding/CMakeLists.txt b/torchrec/csrc/dynamic_embedding/CMakeLists.txt index ca6fa20b5..f3bd73ad5 100644 --- a/torchrec/csrc/dynamic_embedding/CMakeLists.txt +++ b/torchrec/csrc/dynamic_embedding/CMakeLists.txt @@ -6,12 +6,19 @@ add_library(tde_cpp_objs OBJECT + bind.cpp + id_transformer_wrapper.cpp + ps.cpp details/clz_impl.cpp details/ctz_impl.cpp details/random_bits_generator.cpp - details/mixed_lfu_lru_strategy.cpp details/io_registry.cpp - details/io.cpp) + details/io.cpp + details/notification.cpp) + +if (BUILD_REDIS_IO) + add_subdirectory(details/redis) +endif() target_include_directories(tde_cpp_objs PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../../) target_include_directories(tde_cpp_objs PUBLIC ${TORCH_INCLUDE_DIRS}) diff --git a/torchrec/csrc/dynamic_embedding/bind.cpp b/torchrec/csrc/dynamic_embedding/bind.cpp new file mode 100644 index 000000000..e2faa9bbd --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/bind.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include + +namespace torchrec { +TORCH_LIBRARY(tde, m) { + m.def("register_io", [](const std::string& name) { + IORegistry::Instance().register_plugin(name.c_str()); + }); + + m.class_("TransformResult") + .def_readonly("success", &TransformResult::success) + .def_readonly("ids_to_fetch", &TransformResult::ids_to_fetch); + + m.class_("IDTransformer") + .def(torch::init()) + .def("transform", &IDTransformerWrapper::transform) + .def("evict", &IDTransformerWrapper::evict) + .def("save", &IDTransformerWrapper::save); + + m.class_("LocalShardList") + .def(torch::init([]() { return c10::make_intrusive(); })) + .def("append", &LocalShardList::emplace_back); + + m.class_("FetchHandle").def("wait", &FetchHandle::wait); + + m.class_("PS") + .def(torch::init< + std::string, + c10::intrusive_ptr, + int64_t, + int64_t, + std::string, + int64_t>()) + .def("fetch", &PS::fetch) + .def("evict", &PS::evict); +} +} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/details/id_transformer.h b/torchrec/csrc/dynamic_embedding/details/id_transformer.h new file mode 100644 index 000000000..3c826c840 --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/details/id_transformer.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include +#include + +namespace torchrec { + +namespace transform_default { +inline lxu_record_t no_update( + int64_t global_id, + int64_t cache_id, + std::optional record) { + return record.value_or(lxu_record_t{}); +}; + +inline void no_fetch(int64_t global_id, int64_t cache_id) {} +} // namespace transform_default + +class IDTransformer { + public: + /** + * Transform global ids to cache ids + * + * @tparam Update Update the eviction strategy tag type. Update LXU Record + * @tparam Fetch Fetch the not existing global-id/cache-id pair. It is used + * by dynamic embedding parameter server. + * + * @param global_ids Global ID vector + * @param cache_ids [out] Cache ID vector + * @param update update lambda. See `Update` doc. + * @param fetch fetch lambda. See `Fetch` doc. + * @return true if all transformed, otherwise need eviction. + */ + virtual bool transform( + std::span global_ids, + std::span cache_ids, + update_t update = transform_default::no_update, + fetch_t fetch = transform_default::no_fetch) = 0; + + /** + * Evict global ids from the transformer + * + * @param global_ids Global IDs to evict. + */ + virtual void evict(std::span global_ids) = 0; + + /** + * Create an iterator of the id transformer, a possible usecase is: + * + * auto iterator = transformer.iterator(); + * auto record = iterator(); + * while (record.has_value()) { + * // do sth with the record + * // ... + * // get next record + * auto record = iterator(); + * } + * + * @return the iterator created. 
+ */ + virtual iterator_t iterator() const = 0; +}; + +} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/details/io.cpp b/torchrec/csrc/dynamic_embedding/details/io.cpp index cf9556d7e..1c205254b 100644 --- a/torchrec/csrc/dynamic_embedding/details/io.cpp +++ b/torchrec/csrc/dynamic_embedding/details/io.cpp @@ -8,7 +8,7 @@ #include -namespace tde::details { +namespace torchrec { static constexpr std::string_view k_schema_separator = "://"; @@ -77,7 +77,7 @@ static void on_all_fetched(void* ctx) { delete c; } -void IO::pull( +void IO::fetch( const std::string& table_name, std::span global_ids, std::span col_ids, @@ -93,7 +93,7 @@ void IO::pull( ctx->tensors.resize( global_ids.size() * std::max(col_ids.size(), static_cast(1))); - IOPullParameter param{ + IOFetchParameter param{ .table_name = table_name.c_str(), .num_cols = static_cast(col_ids.size()), .num_global_ids = static_cast(global_ids.size()), @@ -104,7 +104,7 @@ void IO::pull( .on_all_fetched = on_all_fetched, }; param.on_complete_context = ctx.release(); - provider_.pull(instance_, param); + provider_.fetch(instance_, param); } struct PushContext { @@ -135,7 +135,7 @@ void IO::push( .col_ids = col_ids.data(), .global_ids = global_ids.data(), .num_optimizer_states = static_cast(os_ids.size()), - .optimizer_stats_ids = os_ids.data(), + .optimizer_state_ids = os_ids.data(), .num_offsets = static_cast(offsets.size()), .offsets = offsets.data(), .data = data.data(), @@ -145,4 +145,4 @@ void IO::push( provider_.push(instance_, param); } -} // namespace tde::details +} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/details/io.h b/torchrec/csrc/dynamic_embedding/details/io.h index ae2d19195..2f4ec64ac 100644 --- a/torchrec/csrc/dynamic_embedding/details/io.h +++ b/torchrec/csrc/dynamic_embedding/details/io.h @@ -12,7 +12,7 @@ #include #include -namespace tde::details { +namespace torchrec { class IO { public: @@ -44,7 +44,7 @@ class IO { * copied inside, so it is safe to free `col_ids`/`global_ids` before * `on_fetch_complete`. */ - void pull( + void fetch( const std::string& table_name, std::span global_ids, std::span col_ids, @@ -89,4 +89,4 @@ class IO { void* instance_{}; }; -} // namespace tde::details +} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/details/io_parameter.h b/torchrec/csrc/dynamic_embedding/details/io_parameter.h new file mode 100644 index 000000000..ab92db9cd --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/details/io_parameter.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +namespace torchrec { + +using GlobalIDFetchCallback = void (*)( + void* ctx, + uint32_t offset, + uint32_t optimizer_state, + void* data, + uint32_t data_len); + +struct IOFetchParameter { + const char* table_name; + uint32_t num_cols; + uint32_t num_global_ids; + const int64_t* col_ids; + const int64_t* global_ids; + uint32_t num_optimizer_states; + void* on_complete_context; + GlobalIDFetchCallback on_global_id_fetched; + void (*on_all_fetched)(void* ctx); +}; + +struct IOPushParameter { + const char* table_name; + uint32_t num_cols; + uint32_t num_global_ids; + const int64_t* col_ids; + const int64_t* global_ids; + uint32_t num_optimizer_states; + const uint32_t* optimizer_state_ids; + uint32_t num_offsets; + const uint64_t* offsets; + const void* data; + void* on_complete_context; + void (*on_push_complete)(void* ctx); +}; + +} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/details/io_registry.cpp b/torchrec/csrc/dynamic_embedding/details/io_registry.cpp index 614da6d2d..b3744dedf 100644 --- a/torchrec/csrc/dynamic_embedding/details/io_registry.cpp +++ b/torchrec/csrc/dynamic_embedding/details/io_registry.cpp @@ -10,7 +10,7 @@ #include #include -namespace tde::details { +namespace torchrec { void IORegistry::register_provider(IOProvider provider) { std::string type = provider.type; @@ -41,9 +41,9 @@ void IORegistry::register_plugin(const char* filename) { provider.finalize = reinterpret_cast(finalize_ptr); - auto pull_ptr = dlsym(ptr.get(), "IO_Pull"); - TORCH_CHECK(pull_ptr != nullptr, "cannot find IO_Pull symbol"); - provider.pull = reinterpret_cast(pull_ptr); + auto fetch_ptr = dlsym(ptr.get(), "IO_Fetch"); + TORCH_CHECK(fetch_ptr != nullptr, "cannot find IO_Fetch symbol"); + provider.fetch = reinterpret_cast(fetch_ptr); auto push_ptr = dlsym(ptr.get(), "IO_Push"); TORCH_CHECK(push_ptr != nullptr, "cannot find IO_Push symbol"); @@ -72,4 +72,4 @@ IORegistry& IORegistry::Instance() { return instance; } -} // namespace tde::details +} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/details/io_registry.h b/torchrec/csrc/dynamic_embedding/details/io_registry.h index af727e84a..3feb39f4a 100644 --- a/torchrec/csrc/dynamic_embedding/details/io_registry.h +++ b/torchrec/csrc/dynamic_embedding/details/io_registry.h @@ -10,48 +10,17 @@ #include #include #include +#include #include #include #include -namespace tde::details { - -struct IOPullParameter { - const char* table_name; - uint32_t num_cols; - uint32_t num_global_ids; - const int64_t* col_ids; - const int64_t* global_ids; - uint32_t num_optimizer_states; - void* on_complete_context; - void (*on_global_id_fetched)( - void* ctx, - uint32_t offset, - uint32_t optimizer_state, - void* data, - uint32_t data_len); - void (*on_all_fetched)(void* ctx); -}; - -struct IOPushParameter { - const char* table_name; - uint32_t num_cols; - uint32_t num_global_ids; - const int64_t* col_ids; - const int64_t* global_ids; - uint32_t num_optimizer_states; - const uint32_t* optimizer_stats_ids; - uint32_t num_offsets; - const uint64_t* offsets; - const void* data; - void* on_complete_context; - void (*on_push_complete)(void* ctx); -}; +namespace torchrec { struct IOProvider { const char* type; void* (*initialize)(const char* cfg); - void (*pull)(void* instance, IOPullParameter cfg); + void (*fetch)(void* instance, IOFetchParameter cfg); void (*push)(void* instance, IOPushParameter cfg); void (*finalize)(void*); }; @@ -75,4 +44,4 @@ class IORegistry { std::vector dls_; }; -} // namespace tde::details 
+} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/details/lxu_strategy.h b/torchrec/csrc/dynamic_embedding/details/lxu_strategy.h new file mode 100644 index 000000000..e18107770 --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/details/lxu_strategy.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include + +namespace torchrec { + +class LXUStrategy { + public: + LXUStrategy() = default; + LXUStrategy(const LXUStrategy&) = delete; + LXUStrategy(LXUStrategy&& o) noexcept = default; + + virtual void update_time(uint32_t time) = 0; + virtual int64_t time(lxu_record_t record) = 0; + + virtual lxu_record_t update( + int64_t global_id, + int64_t cache_id, + std::optional val) = 0; + + /** + * Analysis all ids and returns the num_elems that are most need to evict. + * @param iterator Returns each global_id to ExtValue pair. Returns nullopt + * when at ends. + * @param num_to_evict + * @return + */ + virtual std::vector evict( + iterator_t iterator, + uint64_t num_to_evict) = 0; +}; + +} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.cpp b/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.cpp deleted file mode 100644 index f6c486726..000000000 --- a/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.cpp +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -namespace torchrec { -MixedLFULRUStrategy::MixedLFULRUStrategy(uint16_t min_used_freq_power) - : min_lfu_power_(min_used_freq_power), time_(new std::atomic()) {} - -void MixedLFULRUStrategy::update_time(uint32_t time) { - time_->store(time); -} - -MixedLFULRUStrategy::lxu_record_t MixedLFULRUStrategy::update( - int64_t global_id, - int64_t cache_id, - std::optional val) { - Record r{}; - r.time_ = time_->load(); - - if (C10_UNLIKELY(!val.has_value())) { - r.freq_power_ = min_lfu_power_; - } else { - auto freq_power = reinterpret_cast(&val.value())->freq_power_; - bool should_carry = generator_.is_next_n_bits_all_zero(freq_power); - if (should_carry) { - ++freq_power; - } - r.freq_power_ = freq_power; - } - return *reinterpret_cast(&r); -} - -} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.h b/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.h index 85f7bbd75..beb015c18 100644 --- a/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.h +++ b/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.h @@ -7,7 +7,8 @@ */ #pragma once -#include +#include +#include #include #include #include @@ -37,49 +38,56 @@ namespace torchrec { * * Use `update` to update extended value when every time global id that used. */ -class MixedLFULRUStrategy { +class MixedLFULRUStrategy : public LXUStrategy { public: - using lxu_record_t = uint32_t; - using transformer_record_t = TransformerRecord; - - static constexpr std::string_view type_ = "mixed_lru_lfu"; - /** * @param min_used_freq_power min usage is 2^min_used_freq_power. Set this to * avoid recent values evict too fast. 
*/ - explicit MixedLFULRUStrategy(uint16_t min_used_freq_power = 5); + explicit MixedLFULRUStrategy(uint16_t min_used_freq_power = 5) + : min_lfu_power_(min_used_freq_power), + time_(new std::atomic()) {} MixedLFULRUStrategy(const MixedLFULRUStrategy&) = delete; MixedLFULRUStrategy(MixedLFULRUStrategy&& o) noexcept = default; - void update_time(uint32_t time); - template - static int64_t time(T record) { - static_assert(sizeof(T) == sizeof(Record)); - return static_cast(reinterpret_cast(&record)->time_); + void update_time(lxu_record_t time) override { + time_->store(time); + } + + int64_t time(lxu_record_t record) override { + return static_cast(reinterpret_cast(&record)->time); } - lxu_record_t - update(int64_t global_id, int64_t cache_id, std::optional val); + lxu_record_t update( + int64_t global_id, + int64_t cache_id, + std::optional val) override { + Record r{}; + r.time = time_->load(); + + if (!val.has_value()) [[unlikely]] { + r.freq_power = min_lfu_power_; + } else { + auto freq_power = reinterpret_cast(&val.value())->freq_power; + bool should_carry = generator_.is_next_n_bits_all_zero(freq_power); + if (should_carry) { + ++freq_power; + } + r.freq_power = freq_power; + } + return *reinterpret_cast(&r); + } struct EvictItem { - int64_t global_id_; - lxu_record_t record_; + int64_t global_id; + lxu_record_t record; bool operator<(const EvictItem& item) const { - return record_ < item.record_; + return record < item.record; } }; - /** - * Analysis all ids and returns the num_elems that are most need to evict. - * @param iterator Returns each global_id to ExtValue pair. Returns nullopt - * when at ends. - * @param num_to_evict - * @return - */ - template - static std::vector evict(Iterator iterator, uint64_t num_to_evict) { + std::vector evict(iterator_t iterator, uint64_t num_to_evict) { std::priority_queue items; while (true) { auto val = iterator(); @@ -87,8 +95,8 @@ class MixedLFULRUStrategy { break; } EvictItem item{ - .global_id_ = val->global_id_, - .record_ = reinterpret_cast(&val->lxu_record_)->ToUint32(), + .global_id = val->global_id, + .record = reinterpret_cast(&val->lxu_record)->ToUint32(), }; if (items.size() == num_to_evict) { if (!(item < items.top())) { @@ -105,7 +113,7 @@ class MixedLFULRUStrategy { result.reserve(items.size()); while (!items.empty()) { auto item = items.top(); - result.emplace_back(item.global_id_); + result.emplace_back(item.global_id); items.pop(); } std::reverse(result.begin(), result.end()); @@ -114,17 +122,16 @@ class MixedLFULRUStrategy { // Record should only be used in unittest or internally. 
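As a standalone illustration of why this eviction order works (editor's sketch, not part of the patch), the snippet below packs the Record bit-field defined immediately below the same way ToUint32() does: freq_power sits in the top five bits, so a less frequently used entry always compares smaller than a hotter one and is selected for eviction first, with the 27-bit timestamp breaking ties in favour of older entries.

#include <cassert>
#include <cstdint>

struct Record {
  uint32_t time : 27;
  uint16_t freq_power : 5;
  uint32_t ToUint32() const { return time | (freq_power << (32 - 5)); }
};

int main() {
  Record cold_old{.time = 10, .freq_power = 1};
  Record cold_new{.time = 99, .freq_power = 1};
  Record hot_old{.time = 10, .freq_power = 4};
  // Within the same frequency tier, the older timestamp is smaller and
  // therefore evicted first.
  assert(cold_old.ToUint32() < cold_new.ToUint32());
  // Across tiers, frequency dominates recency.
  assert(cold_new.ToUint32() < hot_old.ToUint32());
  return 0;
}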
struct Record { - uint32_t time_ : 27; - uint16_t freq_power_ : 5; + uint32_t time : 27; + uint16_t freq_power : 5; [[nodiscard]] uint32_t ToUint32() const { - return time_ | (freq_power_ << (32 - 5)); + return time | (freq_power << (32 - 5)); } }; - - private: static_assert(sizeof(Record) == sizeof(lxu_record_t)); + private: RandomBitsGenerator generator_; uint16_t min_lfu_power_; std::unique_ptr> time_; diff --git a/torchrec/csrc/dynamic_embedding/details/naive_id_transformer.h b/torchrec/csrc/dynamic_embedding/details/naive_id_transformer.h index 9cf2e694d..8f7718340 100644 --- a/torchrec/csrc/dynamic_embedding/details/naive_id_transformer.h +++ b/torchrec/csrc/dynamic_embedding/details/naive_id_transformer.h @@ -9,33 +9,13 @@ #pragma once #include #include +#include #include #include #include namespace torchrec { -namespace transform_default { - -template -inline LXURecord no_update( - std::optional record, - int64_t global_id, - int64_t cache_id) { - return record.value_or(LXURecord{}); -}; - -inline void no_fetch(int64_t global_id, int64_t cache_id) {} - -} // namespace transform_default - -template -struct TransformerRecord { - int64_t global_id_; - int64_t cache_id_; - LXURecord lxu_record_; -}; - /** * NaiveIDTransformer * @@ -43,67 +23,27 @@ struct TransformerRecord { * @tparam LXURecord The extension type used for eviction strategy. * @tparam Bitmap The bitmap class to record the free cache ids. */ -template > -class NaiveIDTransformer { +template > +class NaiveIDTransformer : public IDTransformer { public: - using lxu_record_t = LXURecord; - using record_t = TransformerRecord; - static constexpr std::string_view type_ = "naive"; - explicit NaiveIDTransformer(int64_t num_embedding); - NaiveIDTransformer(const NaiveIDTransformer&) = delete; - NaiveIDTransformer(NaiveIDTransformer&&) noexcept = - default; + NaiveIDTransformer(const NaiveIDTransformer&) = delete; + NaiveIDTransformer(NaiveIDTransformer&&) noexcept = default; - /** - * Transform global ids to cache ids - * - * @tparam Update Update the eviction strategy tag type. Update LXU Record - * @tparam Fetch Fetch the not existing global-id/cache-id pair. It is used - * by dynamic embedding parameter server. - * - * @param global_ids Global ID vector - * @param cache_ids [out] Cache ID vector - * @param update update lambda. See `Update` doc. - * @param fetch fetch lambda. See `Fetch` doc. - * @return true if all transformed, otherwise need eviction. - */ - template < - typename Update = decltype(transform_default::no_update), - typename Fetch = decltype(transform_default::no_fetch)> bool transform( std::span global_ids, std::span cache_ids, - Update update = transform_default::no_update, - Fetch fetch = transform_default::no_fetch); + update_t update = transform_default::no_update, + fetch_t fetch = transform_default::no_fetch) override; - /** - * Evict global ids from the transformer - * - * @param global_ids Global IDs to evict. - */ - void evict(std::span global_ids); + void evict(std::span global_ids) override; - /** - * Create an iterator of the id transformer, a possible usecase is: - * - * auto iterator = transformer.iterator(); - * auto record = iterator(); - * while (record.has_value()) { - * // do sth with the record - * // ... - * // get next record - * auto record = iterator(); - * } - * - * @return the iterator created. 
- */ - std::function()> iterator() const; + iterator_t iterator() const override; private: struct CacheValue { - int64_t cache_id_; - LXURecord lxu_record_; + int64_t cache_id; + lxu_record_t lxu_record; }; ska::flat_hash_map global_id2cache_value_; diff --git a/torchrec/csrc/dynamic_embedding/details/naive_id_transformer_impl.h b/torchrec/csrc/dynamic_embedding/details/naive_id_transformer_impl.h index cd40f87d7..e0e38b2f7 100644 --- a/torchrec/csrc/dynamic_embedding/details/naive_id_transformer_impl.h +++ b/torchrec/csrc/dynamic_embedding/details/naive_id_transformer_impl.h @@ -6,43 +6,40 @@ * LICENSE file in the root directory of this source tree. */ -#pragma once #include #include namespace torchrec { -template -inline NaiveIDTransformer::NaiveIDTransformer( - int64_t num_embedding) +template +NaiveIDTransformer::NaiveIDTransformer(int64_t num_embedding) : bitmap_(num_embedding) { global_id2cache_value_.reserve(num_embedding); } -template -template -inline bool NaiveIDTransformer::transform( +template +bool NaiveIDTransformer::transform( std::span global_ids, std::span cache_ids, - Update update, - Fetch fetch) { + update_t update, + fetch_t fetch) { for (size_t i = 0; i < global_ids.size(); ++i) { int64_t global_id = global_ids[i]; auto iter = global_id2cache_value_.find(global_id); // cache_id is in [0, num_embedding) int64_t cache_id; if (iter != global_id2cache_value_.end()) { - cache_id = iter->second.cache_id_; - iter->second.lxu_record_ = - update(iter->second.lxu_record_, global_id, cache_id); + cache_id = iter->second.cache_id; + iter->second.lxu_record = + update(global_id, cache_id, iter->second.lxu_record); } else { // The transformer is full. - if (C10_UNLIKELY(bitmap_.full())) { + if (bitmap_.full()) [[unlikely]] { return false; } auto stored_cache_id = bitmap_.next_free_bit(); cache_id = stored_cache_id; - LXURecord record = update(std::nullopt, global_id, cache_id); + lxu_record_t record = update(global_id, cache_id, std::nullopt); global_id2cache_value_.emplace( global_id, CacheValue{stored_cache_id, record}); fetch(global_id, cache_id); @@ -52,30 +49,28 @@ inline bool NaiveIDTransformer::transform( return true; } -template -inline void NaiveIDTransformer::evict( - std::span global_ids) { +template +void NaiveIDTransformer::evict(std::span global_ids) { for (const int64_t global_id : global_ids) { auto iter = global_id2cache_value_.find(global_id); if (iter == global_id2cache_value_.end()) { continue; } - int64_t cache_id = iter->second.cache_id_; + int64_t cache_id = iter->second.cache_id; global_id2cache_value_.erase(iter); bitmap_.free_bit(cache_id); } } -template -inline auto NaiveIDTransformer::iterator() const - -> std::function()> { +template +iterator_t NaiveIDTransformer::iterator() const { auto iter = global_id2cache_value_.begin(); return [iter, this]() mutable -> std::optional { if (iter != global_id2cache_value_.end()) { auto record = record_t{ - .global_id_ = iter->first, - .cache_id_ = iter->second.cache_id_, - .lxu_record_ = iter->second.lxu_record_, + .global_id = iter->first, + .cache_id = iter->second.cache_id, + .lxu_record = iter->second.lxu_record, }; iter++; return record; diff --git a/torchrec/csrc/dynamic_embedding/details/notification.cpp b/torchrec/csrc/dynamic_embedding/details/notification.cpp new file mode 100644 index 000000000..297b05393 --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/details/notification.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "notification.h" + +namespace torchrec { +void Notification::done() { + { + std::lock_guard guard(mu_); + set_ = true; + } + cv_.notify_all(); +} +void Notification::wait() { + std::unique_lock lock(mu_); + cv_.wait(lock, [this] { return set_; }); +} + +void Notification::clear() { + set_ = false; +} +} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/details/notification.h b/torchrec/csrc/dynamic_embedding/details/notification.h new file mode 100644 index 000000000..5d6c04f03 --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/details/notification.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include + +namespace torchrec { + +/** + * Multi-thread notification + */ +class Notification : public torch::CustomClassHolder { + public: + Notification() = default; + + void done(); + void wait(); + + /** + * Clear the set status. + * + * NOTE: Clear is not thread-safe. + */ + void clear(); + + private: + bool set_{false}; + std::mutex mu_; + std::condition_variable cv_; +}; + +} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/details/random_bits_generator.cpp b/torchrec/csrc/dynamic_embedding/details/random_bits_generator.cpp index 4b307842d..0fada452d 100644 --- a/torchrec/csrc/dynamic_embedding/details/random_bits_generator.cpp +++ b/torchrec/csrc/dynamic_embedding/details/random_bits_generator.cpp @@ -6,17 +6,16 @@ * LICENSE file in the root directory of this source tree. */ -#include #include #include namespace torchrec { bool BitScanner::is_next_n_bits_all_zero(uint16_t& n_bits) { - if (C10_UNLIKELY((n_bits == 0))) { + if ((n_bits == 0)) [[unlikely]] { return true; } - if (C10_UNLIKELY(array_idx_ == size_)) { + if (array_idx_ == size_) [[unlikely]] { return true; } @@ -88,7 +87,7 @@ bool RandomBitsGenerator::is_next_n_bits_all_zero(uint16_t n_bits) { return false; } - if (C10_UNLIKELY(n_bits != 0)) { + if (n_bits != 0) [[unlikely]] { return is_next_n_bits_all_zero(n_bits); } else { return true; diff --git a/torchrec/csrc/dynamic_embedding/details/redis/CMakeLists.txt b/torchrec/csrc/dynamic_embedding/details/redis/CMakeLists.txt new file mode 100644 index 000000000..a9771391c --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/details/redis/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
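Stepping back to the Notification class added above: the sketch below (editor's illustration, not part of the patch; the function name, the worker body, and the include path are assumptions) shows the pattern PS uses later in this change. One thread blocks in wait() until a worker calls done(), and clear() re-arms the object between rounds, which is only safe once no other thread can still be waiting on it.

#include <thread>
#include <torchrec/csrc/dynamic_embedding/details/notification.h>  // assumed include path

void wait_for_worker(torchrec::Notification& note) {
  std::thread worker([&note] {
    // ... produce or copy some data ...
    note.done();  // wakes every thread blocked in wait()
  });
  note.wait();    // returns once done() has been called
  worker.join();
  note.clear();   // not thread-safe: re-arm only after all waiters have returned
}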
+ +FetchContent_Declare( + hiredis + GIT_REPOSITORY https://github.com/redis/hiredis.git + GIT_TAG 06be7ff312a78f69237e5963cc7d24bc84104d3b +) + +FetchContent_GetProperties(hiredis) +if(NOT hiredis_POPULATED) + # Do not include hiredis in install targets + FetchContent_Populate(hiredis) + set(DISABLE_TESTS ON CACHE BOOL "Disable tests for hiredis") + add_subdirectory( + ${hiredis_SOURCE_DIR} ${hiredis_BINARY_DIR} EXCLUDE_FROM_ALL) +endif() + +add_library(redis_io SHARED redis_io.cpp) +target_include_directories( + redis_io PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../) +target_include_directories(redis_io PUBLIC ${TORCH_INCLUDE_DIRS}) +target_compile_options(redis_io PUBLIC -fPIC) +target_link_libraries(redis_io PUBLIC hiredis::hiredis_static) diff --git a/torchrec/csrc/dynamic_embedding/details/redis/redis_io.cpp b/torchrec/csrc/dynamic_embedding/details/redis/redis_io.cpp new file mode 100644 index 000000000..71f04fd08 --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/details/redis/redis_io.cpp @@ -0,0 +1,505 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +namespace torchrec::redis { + +int parse_integer(std::string_view param_str, std::string_view param_key) { + return std::stoi(std::string(param_str.substr(param_key.size()))); +} + +uint32_t parse_duration( + std::string_view param_str, + std::string_view param_key) { + auto param_value = param_str.substr(param_key.size()); + if (param_value.empty()) { + throw std::invalid_argument("no value for " + std::string(param_str)); + } + double duration; + if (param_value.ends_with("ms")) { + duration = + std::stod(std::string(param_value.substr(0, param_value.size() - 2))); + } else if (param_value.ends_with("s")) { + duration = + std::stod(std::string(param_value.substr(0, param_value.size() - 1))) * + 1000; + } else if (param_value.ends_with("m")) { + duration = + std::stod(std::string(param_value.substr(0, param_value.size() - 1))) * + 1000 * 60; + } else { + throw std::invalid_argument( + "no supported time unit (ms, s, m) in " + std::string(param_str)); + } + return static_cast(duration); +} + +Option parse_option(std::string_view config_str) { + Option option; + url_parser::Url url = url_parser::parse_url(http://github.com/config_str); + + if (url.authority.has_value()) { + option.username = std::move(url.authority->username); + option.password = std::move(url.authority->password); + } + + option.host = std::move(url.host); + + if (url.port.has_value()) { + option.port = url.port.value(); + } + + if (url.param.has_value()) { + std::string_view param_str = url.param.value(); + while (!param_str.empty()) { + auto and_pos = param_str.find("&&"); + std::string_view single_param_str; + if (and_pos != std::string_view::npos) { + single_param_str = param_str.substr(0, and_pos); + param_str = param_str.substr(and_pos + 2); + } else { + single_param_str = param_str; + param_str = ""; + } + + if (single_param_str.starts_with("num_threads=")) { + option.num_io_threads = parse_integer(single_param_str, "num_threads="); + } else if (single_param_str.starts_with("db=")) { + option.db = parse_integer(single_param_str, "db="); + } else if (single_param_str.starts_with("prefix=")) { + option.prefix = single_param_str.substr(std::string("prefix=").size()); + } else if (single_param_str.starts_with("timeout=")) { + 
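// Editor's note (illustrative): parse_duration() above accepts "ms", "s" and
// "m" suffixes and always stores milliseconds, so timeout=500ms, timeout=2s
// and timeout=1m are all valid. A complete config string for this parser
// looks like
//   user:secret@127.0.0.1:6379/?num_threads=4&&db=2&&prefix=emb&&timeout=500ms&&heartbeat=1m&&retry_limit=3&&chunk_size=1000
// where parameters are separated by "&&" rather than a single '&'.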
option.timeout_ms = parse_duration(single_param_str, "timeout="); + } else if (single_param_str.starts_with("heartbeat=")) { + option.heart_beat_interval_ms = + parse_duration(single_param_str, "heartbeat="); + } else if (single_param_str.starts_with("retry_limit=")) { + option.retry_limit = parse_integer(single_param_str, "retry_limit="); + } else if (single_param_str.starts_with("chunk_size=")) { + option.chunk_size = parse_integer(single_param_str, "chunk_size="); + } else { + throw std::invalid_argument( + "unknown parameter: " + std::string(single_param_str)); + } + } + } + + return option; +} + +Redis::Redis(Option opt) : opt_(std::move(opt)) { + TORCH_CHECK(opt_.num_io_threads != 0, "num_io_threads must not be empty"); + TORCH_CHECK( + opt_.heart_beat_interval_ms != 0, + "heart beat interval must not be zero."); + for (size_t i = 0; i < opt_.num_io_threads; ++i) { + start_thread(); + } +} + +void Redis::start_thread() { + auto connection = connect(); + heartbeat(connection); + + io_threads_.emplace_back( + [connection = std::move(connection), this]() mutable { + std::chrono::milliseconds heart_beat(opt_.heart_beat_interval_ms); + while (true) { + std::function todo; + bool heartbeat_timeout; + { + std::unique_lock lock(this->jobs_mutex_); + heartbeat_timeout = !jobs_not_empty_.wait_for( + lock, heart_beat, [this] { return !jobs_.empty(); }); + if (!heartbeat_timeout) { + todo = std::move(jobs_.front()); + jobs_.pop_front(); + } + } + + if (heartbeat_timeout) { + heartbeat(connection); + continue; + } + + if (!todo) { + break; + } + todo(connection); + } + }); +} + +void Redis::heartbeat(helper::ContextPtr& connection) { + for (uint32_t retry = 0; retry < opt_.retry_limit; ++retry) { + try { + auto reply = helper::ReplyPtr(reinterpret_cast( + redisCommand(connection.get(), "PING"))); + TORCH_CHECK( + reply && reply->type == REDIS_REPLY_STRING, + "Ping should return string"); + auto rsp = std::string_view(reply->str, reply->len); + TORCH_CHECK(rsp == "PONG", "ping/pong error"); + } catch (...) 
{ + // reconnect if heart beat error + connection = connect(); + } + } +} + +helper::ContextPtr Redis::connect() const { + helper::ContextPtr connection; + if (opt_.timeout_ms == 0) { + connection = helper::ContextPtr(redisConnect(opt_.host.c_str(), opt_.port)); + } else { + struct timeval interval {}; + interval.tv_sec = opt_.timeout_ms / 1000; + interval.tv_usec = opt_.timeout_ms % 1000 * 1000; + connection = helper::ContextPtr( + redisConnectWithTimeout(opt_.host.c_str(), opt_.port, interval)); + } + TORCH_CHECK( + !connection->err, + "connect to %s:%d error occurred %s", + opt_.host, + opt_.port, + connection->errstr); + + if (!opt_.password.empty()) { + helper::ReplyPtr reply; + if (opt_.username.empty()) { + reply = helper::ReplyPtr(reinterpret_cast( + redisCommand(connection.get(), "AUTH %s", opt_.password.c_str()))); + } else { + reply = helper::ReplyPtr(reinterpret_cast(redisCommand( + connection.get(), + "AUTH %s %s", + opt_.username.c_str(), + opt_.password.c_str()))); + } + check_status("auth error", connection, reply); + } + + if (opt_.db != 0) { + auto reply = helper::ReplyPtr(reinterpret_cast( + redisCommand(connection.get(), "SELECT %d", opt_.db))); + check_status("select db error", connection, reply); + } + + return connection; +} + +Redis::~Redis() { + for (uint32_t i = 0; i < opt_.num_io_threads; ++i) { + jobs_.emplace_back(); + } + jobs_not_empty_.notify_all(); + for (auto& th : io_threads_) { + th.join(); + } +} + +static uint32_t CalculateChunkSizeByGlobalIDs( + uint32_t chunk_size, + uint32_t num_cols, + uint32_t num_os) { + static constexpr uint32_t low = 1; + return std::max( + chunk_size / std::max(num_cols, low) / std::max(num_os, low), low); +} + +struct RedisFetchContext { + std::atomic num_complete_ids{0}; + uint32_t chunk_size; + std::string table_name; + std::vector global_ids; + std::vector col_ids; + uint32_t num_optimizer_states; + void* on_complete_context; + void (*on_global_id_fetched)( + void* ctx, + uint32_t gid_offset, + uint32_t optimizer_state, + void* data, + uint32_t data_len); + void (*on_all_fetched)(void* ctx); + + explicit RedisFetchContext(uint32_t chunk_size, IOFetchParameter param) + : chunk_size(CalculateChunkSizeByGlobalIDs( + chunk_size, + param.num_cols, + param.num_optimizer_states)), + table_name(param.table_name), + global_ids(param.global_ids, param.global_ids + param.num_global_ids), + num_optimizer_states(param.num_optimizer_states), + on_complete_context(param.on_complete_context), + on_global_id_fetched(param.on_global_id_fetched), + on_all_fetched(param.on_all_fetched) { + if (param.num_cols == 0) { + col_ids.emplace_back(-1); + } else { + col_ids = + std::vector(param.col_ids, param.col_ids + param.num_cols); + } + } +}; + +void Redis::fetch(IOFetchParameter param) { + auto* fetch_param = new RedisFetchContext(opt_.chunk_size, param); + { + std::lock_guard guard(this->jobs_mutex_); + for (uint32_t i = 0; i < param.num_global_ids; + i += fetch_param->chunk_size) { + jobs_.emplace_back( + [i, fetch_param, this](helper::ContextPtr& connection) { + do_fetch(i, fetch_param, connection); + }); + } + } + jobs_not_empty_.notify_all(); +} + +void Redis::do_fetch( + uint32_t gid_offset, + void* fetch_param_void, + helper::ContextPtr& connection) const { + auto& fetch_param = *reinterpret_cast(fetch_param_void); + + uint32_t end = std::min( + gid_offset + fetch_param.chunk_size, + static_cast(fetch_param.global_ids.size())); + + auto loop = [&](auto&& callback) { + for (uint32_t i = gid_offset; i < end; ++i) { + int64_t gid = 
fetch_param.global_ids[i]; + for (uint32_t j = 0; j < fetch_param.col_ids.size(); ++j) { + auto& col_id = fetch_param.col_ids[j]; + for (uint32_t os_id = 0; os_id < fetch_param.num_optimizer_states; + ++os_id) { + callback(i * fetch_param.col_ids.size() + j, gid, col_id, os_id); + } + } + } + }; + + loop([&](uint32_t offset, int64_t gid, uint32_t col_id, uint32_t os_id) { + redisAppendCommand( + connection.get(), + "GET %s_table_%s_gid_%d_cid_%d_osid_%d", + opt_.prefix.c_str(), + fetch_param.table_name.c_str(), + gid, + col_id, + os_id); + }); + + void* reply; + loop([&](uint32_t offset, int64_t gid, uint32_t col_id, uint32_t os_id) { + int status = redisGetReply(connection.get(), &reply); + TORCH_CHECK( + status != REDIS_ERR, + "get reply error: %s, from redis %s, %d", + connection->errstr, + opt_.host, + opt_.port); + auto reply_ptr = helper::ReplyPtr(reinterpret_cast(reply)); + + if (reply_ptr->type == REDIS_REPLY_NIL) { + fetch_param.on_global_id_fetched( + fetch_param.on_complete_context, offset, os_id, nullptr, 0); + } else { + fetch_param.on_global_id_fetched( + fetch_param.on_complete_context, + offset, + os_id, + reply_ptr->str, + reply_ptr->len); + } + }); + + uint32_t n = end - gid_offset; + uint32_t target = fetch_param.global_ids.size(); + + if (fetch_param.num_complete_ids.fetch_add(n) + n == + target) { // last fetch complete + fetch_param.on_all_fetched(fetch_param.on_complete_context); + delete &fetch_param; + } +} + +struct RedisPushContext { + std::atomic num_complete_ids{0}; + uint32_t chunk_size; + std::string table_name; + std::span global_ids; + std::vector col_ids; + std::span os_ids; + std::span offsets; + const void* data; + void* on_complete_context; + void (*on_push_complete)(void*); + + RedisPushContext(uint32_t chunk_size, IOPushParameter param) + : chunk_size(CalculateChunkSizeByGlobalIDs( + chunk_size, + param.num_cols, + param.num_optimizer_states)), + table_name(param.table_name), + global_ids(param.global_ids, param.num_global_ids), + os_ids(param.optimizer_state_ids, param.num_optimizer_states), + offsets(param.offsets, param.num_offsets), + data(param.data), + on_complete_context(param.on_complete_context), + on_push_complete(param.on_push_complete) { + if (param.num_cols != 0) { + col_ids = + std::vector(param.col_ids, param.col_ids + param.num_cols); + } else { + col_ids.emplace_back(-1); + } + } +}; + +void Redis::push(IOPushParameter param) { + auto* ctx = new RedisPushContext(opt_.chunk_size, param); + { + std::lock_guard guard(this->jobs_mutex_); + for (uint32_t i = 0; i < param.num_global_ids; i += ctx->chunk_size) { + jobs_.emplace_back([i, ctx, this](helper::ContextPtr& connection) { + do_push(i, ctx, connection); + }); + } + } + jobs_not_empty_.notify_all(); +} +void Redis::do_push( + uint32_t gid_offset, + void* push_ctx_ptr, + helper::ContextPtr& connection) const { + auto& push_ctx = *reinterpret_cast(push_ctx_ptr); + + uint32_t end = gid_offset + push_ctx.chunk_size; + if (end > push_ctx.global_ids.size()) { + end = push_ctx.global_ids.size(); + } + + auto loop = [&](auto&& callback) { + for (uint32_t i = gid_offset; i < end; ++i) { + int64_t gid = push_ctx.global_ids[i]; + for (uint32_t j = 0; j < push_ctx.col_ids.size(); ++j) { + int64_t cid = push_ctx.col_ids[j]; + for (uint32_t k = 0; k < push_ctx.os_ids.size(); ++k) { + uint32_t os_id = push_ctx.os_ids[k]; + + uint32_t offset = k + j * push_ctx.os_ids.size() + + i * push_ctx.col_ids.size() * push_ctx.os_ids.size(); + callback(offset, gid, cid, os_id); + } + } + } + }; + + 
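// Editor's note: the slot index `o` computed by the lambda above enumerates
// (global id, col, optimizer state) with the optimizer state varying fastest.
// This must match the layout of `offsets` supplied by the caller (PS::evict
// builds it in exactly this order); the bytes for slot o live in
// data[offsets[o], offsets[o + 1]).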
loop([&](uint32_t o, int64_t gid, int64_t cid, uint32_t os_id) { + uint64_t beg = push_ctx.offsets[o]; + uint64_t end = push_ctx.offsets[o + 1]; + + redisAppendCommand( + connection.get(), + "SET %s_table_%s_gid_%d_cid_%d_osid_%d %b", + opt_.prefix.c_str(), + push_ctx.table_name.c_str(), + gid, + cid, + os_id, + reinterpret_cast(push_ctx.data) + beg, + static_cast(end - beg)); + }); + + void* replay_ptr; + loop([&](...) { + int status = redisGetReply(connection.get(), &replay_ptr); + TORCH_CHECK( + status != REDIS_ERR, + "get reply error: %s, from redis %s, %d", + connection->errstr, + opt_.host, + opt_.port); + helper::ReplyPtr reply(reinterpret_cast(replay_ptr)); + check_status("reply should be ok", connection, reply); + }); + + uint32_t n = end - gid_offset; + uint32_t target = push_ctx.global_ids.size(); + if (push_ctx.num_complete_ids.fetch_add(n) + n == target) { + push_ctx.on_push_complete(push_ctx.on_complete_context); + delete &push_ctx; + } +} +void Redis::check_status( + std::string_view label, + helper::ContextPtr& connection, + helper::ReplyPtr& reply) const { + TORCH_CHECK( + connection->err == 0, + label, + " connection error: (", + connection->errstr, + "), from redis://", + opt_.host, + ":", + opt_.port); + + TORCH_CHECK( + reply->type == REDIS_REPLY_STATUS, + label, + " reply should be status, but actual type is ", + reply->type, + ". from redis://", + opt_.host, + ":", + opt_.port); + + auto status = std::string_view{reply->str, reply->len}; + TORCH_CHECK( + status == "OK", + label, + " reply status should be OK, but actual is ", + status, + ". from redis://", + opt_.host, + ":", + opt_.port); +} + +extern "C" { + +const char* IO_type = "redis"; + +void* IO_Initialize(const char* cfg) { + auto opt = parse_option(cfg); + return new Redis(opt); +} + +void IO_Finalize(void* instance) { + delete reinterpret_cast(instance); +} + +void IO_Fetch(void* instance, IOFetchParameter param) { + reinterpret_cast(instance)->fetch(param); +} + +void IO_Push(void* instance, IOPushParameter param) { + reinterpret_cast(instance)->push(param); +} +} + +} // namespace torchrec::redis diff --git a/torchrec/csrc/dynamic_embedding/details/redis/redis_io.h b/torchrec/csrc/dynamic_embedding/details/redis/redis_io.h new file mode 100644 index 000000000..0c644db3c --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/details/redis/redis_io.h @@ -0,0 +1,110 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torchrec::redis { + +struct Option { + public: + std::string host; + std::string username; + std::string password; + uint16_t port{6379}; + uint16_t db{0}; + uint32_t num_io_threads{1}; + std::string prefix; + uint32_t timeout_ms{10000}; + uint32_t heart_beat_interval_ms{100000}; + uint32_t retry_limit{3}; + uint32_t chunk_size{100}; +}; + +Option parse_option(std::string_view config_str); + +namespace helper { +struct ContextDeleter { + void operator()(void* ctx) { + if (ctx == nullptr) { + return; + } + redisFree(reinterpret_cast(ctx)); + } +}; +using ContextPtr = std::unique_ptr; + +struct ReplyDeleter { + void operator()(void* cmd) { + if (cmd == nullptr) { + return; + } + freeReplyObject(cmd); + } +}; + +using ReplyPtr = std::unique_ptr; + +} // namespace helper + +class Redis { + public: + explicit Redis(Option opt); + + ~Redis(); + + void fetch(IOFetchParameter param); + + void push(IOPushParameter param); + + private: + void start_thread(); + void heartbeat(helper::ContextPtr& connection); + [[nodiscard]] helper::ContextPtr connect() const; + + void do_fetch( + uint32_t gid_offset, + void* fetch_param, + helper::ContextPtr& connection) const; + + void do_push( + uint32_t gid_offset, + void* push_ctx, + helper::ContextPtr& connection) const; + + void check_status( + std::string_view label, + helper::ContextPtr& connection, + helper::ReplyPtr& reply) const; + + Option opt_; + std::vector io_threads_; + std::deque> jobs_; + std::condition_variable jobs_not_empty_; + std::mutex jobs_mutex_; +}; + +extern "C" { + +extern const char* IO_type; + +void* IO_Initialize(const char* cfg); +void IO_Finalize(void* instance); +void IO_Fetch(void* instance, IOFetchParameter param); +void IO_Push(void* instance, IOPushParameter param); +} + +} // namespace torchrec::redis diff --git a/torchrec/csrc/dynamic_embedding/details/redis/url.h b/torchrec/csrc/dynamic_embedding/details/redis/url.h new file mode 100644 index 000000000..7907864bc --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/details/redis/url.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include + +namespace torchrec::url_parser { + +struct Authority { + std::string username; + std::string password; +}; + +struct Url { + std::optional authority; + std::string host; + std::optional port; + std::optional param; +}; + +inline Authority parse_authority(std::string_view authority_str) { + Authority authority; + + auto colon_pos = authority_str.find(':'); + if (colon_pos != std::string_view::npos) { + authority.username = authority_str.substr(0, colon_pos); + authority.password = authority_str.substr(colon_pos + 1); + } else { + // only username + authority.username = authority_str; + } + return authority; +} + +inline Url parse_url(/service/std::string_view url_str) { + Url url; + // (username (":" password)? "@")? host ":" port ("/" | "/?" param)? + // Assume there will only be one '@' + auto at_pos = url_str.find('@'); + if (at_pos != std::string_view::npos) { + Authority authority = parse_authority(url_str.substr(0, at_pos)); + url.authority = authority; + url_str = url_str.substr(at_pos + 1); + } + // There should be no '/' in host:port. 
+ auto slash_pos = url_str.find('/'); + std::string_view host_port_str; + if (slash_pos != std::string_view::npos) { + host_port_str = url_str.substr(0, slash_pos); + url_str = url_str.substr(slash_pos + 1); + } else { + host_port_str = url_str; + url_str = ""; + } + + auto colon_pos = host_port_str.find(':'); + if (colon_pos != std::string_view::npos) { + url.host = host_port_str.substr(0, colon_pos); + auto port_str = host_port_str.substr(colon_pos + 1); + url.port = std::stoi(std::string(port_str)); + } else { + url.host = host_port_str; + } + + if (!url_str.empty()) { + if (url_str[0] != '?') { + throw std::invalid_argument("invalid parameter: " + std::string(url_str)); + } else { + url.param = url_str.substr(1); + } + } + + return url; +} + +} // namespace torchrec::url_parser diff --git a/torchrec/csrc/dynamic_embedding/details/types.h b/torchrec/csrc/dynamic_embedding/details/types.h new file mode 100644 index 000000000..571bfd3ba --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/details/types.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace torchrec { + +using lxu_record_t = uint32_t; + +struct record_t { + int64_t global_id; + int64_t cache_id; + lxu_record_t lxu_record; +}; + +using iterator_t = std::function()>; +using update_t = + std::function)>; +using fetch_t = std::function; + +} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/id_transformer_wrapper.cpp b/torchrec/csrc/dynamic_embedding/id_transformer_wrapper.cpp new file mode 100644 index 000000000..b8aece800 --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/id_transformer_wrapper.cpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +namespace torchrec { + +IDTransformerWrapper::IDTransformerWrapper( + int64_t num_embedding, + const std::string& id_transformer_type, + const std::string& lxu_strategy_type, + int64_t min_used_freq_power) + : time_(-1), last_save_time_(-1) { + TORCH_CHECK(id_transformer_type == "naive"); + TORCH_CHECK(lxu_strategy_type == "mixed_lru_lfu"); + transformer_ = + std::unique_ptr(new NaiveIDTransformer(num_embedding)); + strategy_ = std::unique_ptr( + new MixedLFULRUStrategy(min_used_freq_power)); +} + +c10::intrusive_ptr IDTransformerWrapper::transform( + std::vector global_id_list, + std::vector cache_id_list, + int64_t time) { + std::lock_guard lock(mu_); + torch::NoGradGuard no_grad; + TORCH_CHECK(time >= 0); + TORCH_CHECK(time >= time_, "Time cannot go backward"); + time_ = time; + TORCH_CHECK(global_id_list.size() == cache_id_list.size()); + strategy_->update_time(static_cast(time)); + { + int64_t total_num_embeddings = std::accumulate( + global_id_list.begin(), + global_id_list.end(), + int64_t(0), + [](int64_t v, auto&& tensor) -> int64_t { return v + tensor.numel(); }); + ids_to_fetch_.resize(2 * total_num_embeddings); + } + + update_t update = [this]( + int64_t global_id, + int64_t cache_id, + std::optional lxu_record) { + return strategy_->update(global_id, cache_id, lxu_record); + }; + std::atomic next_fetch_offset{0}; + fetch_t fetch = [&, this](int64_t global_id, int64_t cache_id) { + int64_t offset = next_fetch_offset.fetch_add(1); + ids_to_fetch_[2 * offset] = global_id; + ids_to_fetch_[2 * offset + 1] = cache_id; + }; + + bool ok = true; + for (int64_t i = 0; i < global_id_list.size(); ++i) { + auto& global_ids = global_id_list[i]; + auto& cache_ids = cache_id_list[i]; + ok = transformer_->transform( + std::span{ + global_ids.data_ptr(), + static_cast(global_ids.numel())}, + std::span{ + cache_ids.data_ptr(), + static_cast(cache_ids.numel())}, + update, + fetch); + if (!ok) { + break; + } + } + + return c10::make_intrusive( + ok, + at::from_blob( + ids_to_fetch_.data(), + {next_fetch_offset.load(), 2}, + torch::TensorOptions().dtype(c10::kLong).device(c10::kCPU))); +} + +torch::Tensor IDTransformerWrapper::evict(int64_t num_to_evict) { + std::lock_guard lock(mu_); + torch::NoGradGuard no_grad; + // get the global ids to evict. + std::vector global_ids_to_evict = + strategy_->evict(transformer_->iterator(), num_to_evict); + int64_t num_ids_to_evict = global_ids_to_evict.size(); + // get the cache id from transformer_ + std::vector cache_ids_to_evict(num_ids_to_evict); + transformer_->transform(global_ids_to_evict, cache_ids_to_evict); + // evict the global ids from transformer_ + transformer_->evict(global_ids_to_evict); + + std::vector ids_to_evict(num_ids_to_evict * 2); + for (int64_t i = 0; i < num_ids_to_evict; ++i) { + ids_to_evict[2 * i] = global_ids_to_evict[i]; + ids_to_evict[2 * i + 1] = cache_ids_to_evict[i]; + } + return torch::tensor(ids_to_evict, torch::dtype(torch::kLong)) + .reshape({num_ids_to_evict, 2}); +} + +torch::Tensor IDTransformerWrapper::save() { + std::lock_guard lock(mu_); + torch::NoGradGuard no_grad; + // traverse transformer_ and get the id with new timestamp. 
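// Editor's note: save() returns an [N, 2] int64 tensor of (global_id,
// cache_id) pairs for every entry whose strategy timestamp is newer than the
// previous save(), then advances last_save_time_ so that the next call only
// reports rows touched after this point.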
+ std::vector ids; + iterator_t iterator = transformer_->iterator(); + while (true) { + auto val = iterator(); + if (!val.has_value()) [[unlikely]] { + break; + } + if (strategy_->time(val->lxu_record) > last_save_time_) { + ids.emplace_back(val->global_id); + ids.emplace_back(val->cache_id); + } + } + + last_save_time_ = time_; + int64_t num_ids = ids.size() / 2; + return torch::tensor(ids, torch::dtype(torch::kLong)).reshape({num_ids, 2}); +} + +} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/id_transformer_wrapper.h b/torchrec/csrc/dynamic_embedding/id_transformer_wrapper.h new file mode 100644 index 000000000..6b07252cb --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/id_transformer_wrapper.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include +#include + +namespace torchrec { + +struct TransformResult : public torch::CustomClassHolder { + TransformResult(bool success, torch::Tensor ids_to_fetch) + : success(success), ids_to_fetch(ids_to_fetch) {} + + // Whether the fetch succeeded (if evicted is not necessary) + bool success; + // new ids to fetch from PS. + // shape of [num_to_fetch, 2], where each row is consist of + // the global id and cache id of each ID. + torch::Tensor ids_to_fetch; +}; + +class IDTransformerWrapper : public torch::CustomClassHolder { + public: + IDTransformerWrapper( + int64_t num_embedding, + const std::string& id_transformer_type, + const std::string& lxu_strategy_type, + int64_t min_used_freq_power = 5); + + c10::intrusive_ptr transform( + std::vector global_ids, + std::vector cache_ids, + int64_t time); + torch::Tensor evict(int64_t num_to_evict); + torch::Tensor save(); + + private: + std::mutex mu_; + std::unique_ptr transformer_; + std::unique_ptr strategy_; + std::vector ids_to_fetch_; + int64_t time_; + int64_t last_save_time_; +}; + +} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/ps.cpp b/torchrec/csrc/dynamic_embedding/ps.cpp new file mode 100644 index 000000000..a99477c3a --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/ps.cpp @@ -0,0 +1,183 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace torchrec { + +c10::intrusive_ptr PS::fetch( + torch::Tensor ids_to_fetch, + int64_t time, + bool reinit, + double weight_init_min, + double weight_init_max) { + std::lock_guard lock(mu_); + torch::NoGradGuard no_grad; + + auto [local_global_ids, local_cache_ids] = filter_local_ids(ids_to_fetch); + if (local_global_ids.empty()) { + return c10::make_intrusive(time, c10::intrusive_ptr()); + } + + c10::intrusive_ptr notification; + { + std::unique_lock lock_fetch(fetch_notifications_mutex_); + fetch_notifications_.emplace_back( + time, c10::make_intrusive()); + notification = fetch_notifications_.back().second; + } + // Does not support multiple col ids at the moment. 
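// Editor's note: the completion lambda below is invoked asynchronously once
// the IO layer returns the data. Rows that come back undefined are either
// re-initialized (when reinit is true) or left untouched, and the final
// notification->done() is what FetchHandle::wait() / synchronize_fetch()
// ultimately block on.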
+ std::vector col_ids{0}; + uint32_t num_os_ids = os_ids_.size(); + io_.fetch( + table_name_, + std::move(local_global_ids), + col_ids, + num_os_ids, + torch::kF32, + [=, this, cache_ids_to_fetch = std::move(local_cache_ids)](auto&& val) { + TORCH_CHECK(val.size() == cache_ids_to_fetch.size()); + for (uint32_t i = 0; i < cache_ids_to_fetch.size(); ++i) { + int64_t cache_id = cache_ids_to_fetch[i]; + auto& fetched = val[i]; + if (!fetched.defined()) { + if (reinit) { + std::vector tensors = get_tensor_views(cache_id); + tensors[0].uniform_(weight_init_min, weight_init_max); + // optimizer states will be set to zero + for (uint32_t j = 1; j < num_os_ids; ++j) { + tensors[j].zero_(); + } + } + continue; + } + + std::vector tensors = get_tensor_views(cache_id); + for (uint32_t j = 0; j < num_os_ids; ++j) { + tensors[j].copy_(fetched.slice(0, j, j + 1)); + } + } + notification->done(); + }); + // `unsafe_reclain_from_nonowning` is the `instrusive_ptr` version of + // `enable_shared_from_this` + return c10::make_intrusive( + time, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this)); +} + +void PS::evict(torch::Tensor ids_to_evict) { + std::lock_guard lock(mu_); + torch::NoGradGuard no_grad; + // make sure all previous fetches are done. + synchronize_fetch(); + + auto [local_global_ids, local_cache_ids] = filter_local_ids(ids_to_evict); + if (local_global_ids.empty()) { + return; + } + + // Does not support multiple col ids at the moment. + std::vector col_ids{0}; + uint32_t num_os_ids = os_ids_.size(); + uint32_t num_ids_to_fetch = local_global_ids.size(); + + Notification notification; + // Done first so that the Wait after preparing the first chunk won't stuck. + notification.done(); + // The shared data for all chunks. + std::vector offsets; + offsets.resize(num_ids_per_chunk_ * num_os_ids * col_ids.size() + 1); + // Evict by chunks + for (uint32_t i = 0; i < num_ids_to_fetch; i += num_ids_per_chunk_) { + uint32_t num_ids_in_chunk = std::min( + static_cast(num_ids_per_chunk_), num_ids_to_fetch - i); + uint32_t data_size = num_ids_in_chunk * num_os_ids * col_ids.size(); + uint32_t offsets_size = num_ids_in_chunk * num_os_ids * col_ids.size() + 1; + + std::vector all_tensors; + for (uint32_t j = i; j < i + num_ids_in_chunk; ++j) { + int64_t cache_id = local_cache_ids[j]; + std::vector tensors = get_tensor_views(cache_id); + all_tensors.insert(all_tensors.end(), tensors.begin(), tensors.end()); + } + torch::Tensor data = torch::cat(all_tensors, 0).cpu(); + TORCH_CHECK(data.numel() == data_size * col_size_); + + offsets[0] = 0; + for (uint32_t j = 0; j < all_tensors.size(); ++j) { + offsets[j + 1] = + offsets[j] + all_tensors[j].numel() * all_tensors[j].element_size(); + } + // waiting for the Push of last chunk finishes. 
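// Editor's note: the notification was pre-signalled with done() before this
// loop, so the first iteration does not block here; on later iterations this
// wait serializes the pushes so that at most one chunk is in flight, and the
// trailing wait() after the loop makes evict() return only once the last
// chunk has been pushed.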
+ notification.wait(); + notification.clear(); + io_.push( + table_name_, + std::span{local_global_ids.data() + i, num_ids_in_chunk}, + col_ids, + os_ids_, + std::span{ + reinterpret_cast(data.data_ptr()), + data_size * sizeof(float)}, + std::span{offsets.data(), offsets_size}, + [¬ification] { notification.done(); }); + } + notification.wait(); +} + +void PS::synchronize_fetch(int64_t time) { + std::unique_lock lock( + fetch_notifications_mutex_, std::defer_lock); + + while (true) { + lock.lock(); + if (fetch_notifications_.empty() || + fetch_notifications_.front().first != time && time >= 0) { + lock.unlock(); + break; + } + auto notification = fetch_notifications_.front().second; + fetch_notifications_.pop_front(); + lock.unlock(); + + notification->wait(); + } +} + +std::vector PS::get_tensor_views(int64_t cache_id) { + for (auto& shard : *shards_) { + if (shard.has(cache_id)) { + return shard.get_tensor_view(cache_id); + } + } + TORCH_CHECK(false, "all local shards do not contain cache id ", cache_id); +} + +std::tuple, std::vector> PS::filter_local_ids( + const torch::Tensor& ids) { + std::vector local_global_ids; + std::vector local_cache_ids; + TORCH_CHECK(ids.is_contiguous()); + TORCH_CHECK(ids.dim() == 2); + auto* ids_ptr = ids.data_ptr(); + int64_t numel = ids.numel(); + for (int64_t i = 0; i < numel; i += 2) { + auto cache_id = ids_ptr[i + 1]; + if (std::any_of(shards_->begin(), shards_->end(), [&](auto&& shard) { + return shard.has(cache_id); + })) { + auto global_id = ids_ptr[i]; + local_global_ids.emplace_back(global_id); + local_cache_ids.emplace_back(cache_id); + } + } + return {std::move(local_global_ids), std::move(local_cache_ids)}; +} + +} // namespace torchrec diff --git a/torchrec/csrc/dynamic_embedding/ps.h b/torchrec/csrc/dynamic_embedding/ps.h new file mode 100644 index 000000000..81f2f0d68 --- /dev/null +++ b/torchrec/csrc/dynamic_embedding/ps.h @@ -0,0 +1,170 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include + +#include +#include +#include +#include + +namespace torchrec { + +/** + * @brief A local shard of embedding tensor with its range of row. + * It not only stores the parameter tensor of the shard, but also + * the tensor of this optimizer states. + * + */ +struct LocalShard { + int64_t row_start; + int64_t row_size; + std::vector tensors; + + /** + * @brief Check if a certain cache id is in this Shard + * + */ + [[nodiscard]] bool has(int64_t cache_id) const { + return row_start <= cache_id && cache_id < row_start + row_size; + } + + [[nodiscard]] std::vector get_tensor_view( + int64_t cache_id) const { + std::vector result; + result.reserve(tensors.size()); + for (auto& tensor : tensors) { + result.emplace_back( + tensor.slice(0, cache_id - row_start, cache_id - row_start + 1)); + } + return result; + } +}; + +/** + * @brief A helper class to store all the local shard on the current rank, + * basically `std::vecotr`. The reason for this class is that all + * shards could share the same refcount. + * + */ +class LocalShardList : public torch::CustomClassHolder { + using Container = std::vector; + + public: + void emplace_back( + int64_t row_start, + int64_t col_start, + int64_t row_size, + int64_t col_size, + std::vector tensors) { + // col_start/col_size not supported now. 
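// Editor's note: col_start and col_size are currently ignored; every
// LocalShard spans all columns of rows [row_start, row_start + row_size), as
// has() and get_tensor_view() above only consider the row range.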
+ shards_.emplace_back(LocalShard{ + .row_start = row_start, + .row_size = row_size, + .tensors = std::move(tensors)}); + } + + Container::const_iterator begin() const { + return shards_.begin(); + } + + Container::const_iterator end() const { + return shards_.end(); + } + + Container shards_; +}; + +class FetchHandle; + +class PS : public torch::CustomClassHolder { + public: + PS(std::string table_name, + c10::intrusive_ptr shards, + int64_t col_size, + int64_t num_optimizer_stats, + const std::string& io_config, + int64_t chunk_size) + : table_name_(std::move(table_name)), + shards_(std::move(shards)), + col_size_(col_size), + os_ids_(num_optimizer_stats), + io_(io_config), + num_ids_per_chunk_(chunk_size / col_size_ / num_optimizer_stats) { + TORCH_CHECK(num_ids_per_chunk_ > 0, "chunk size too small"); + for (int64_t i = 0; i < num_optimizer_stats; ++i) { + os_ids_[i] = i; + } + } + + /** + * @brief Fetch the embedding from remote PS into local GPU embedding + * asynchronously. + * + * @param ids_to_fetch ids to fetch, pairs of global id and cache id. + * @param time the timestamp of the fetch + * @param reinit whether to re-initialize the parameter and optimizer states + * if the id to fetch is not stored in PS. The parameter will be re-initialize + * with `uniform(weight_init_min, weight_init_max)` and the optimizer states + * will be re-initialized with 0. + * @return The handle used to synchronize the fetch. + */ + c10::intrusive_ptr fetch( + torch::Tensor ids_to_fetch, + int64_t time, + bool reinit, + double weight_init_min, + double weight_init_max); + /** + * @brief Synchronize all the fetches till timestamp `time`, + * if `time` is -1, then synchronize all previous fetches. + * + */ + void synchronize_fetch(int64_t time = -1); + + /** + * @brief Evict ids back to PS synchronously. + * + */ + void evict(torch::Tensor ids_to_evict); + + private: + std::vector get_tensor_views(int64_t cache_id); + std::tuple, std::vector> filter_local_ids( + const torch::Tensor& ids); + + // We need a mutex because the evict and fetch may happen in different thread. + std::mutex mu_; + std::mutex fetch_notifications_mutex_; + std::string table_name_; + c10::intrusive_ptr shards_; + int64_t col_size_; + std::vector os_ids_; + int64_t num_ids_per_chunk_; + IO io_; + std::deque>> + fetch_notifications_; +}; + +struct FetchHandle : public torch::CustomClassHolder { + public: + FetchHandle(int64_t time, c10::intrusive_ptr ps) + : time_(time), ps_(std::move(ps)) {} + void wait() { + if (ps_ != nullptr) + ps_->synchronize_fetch(time_); + } + + private: + int64_t time_; + c10::intrusive_ptr ps_; // not owned +}; + +} // namespace torchrec diff --git a/torchrec/datasets/__init__.py b/torchrec/datasets/__init__.py index 963ea21d6..6c52a92fd 100644 --- a/torchrec/datasets/__init__.py +++ b/torchrec/datasets/__init__.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + """Torchrec Datasets Torchrec contains two popular recys datasets, the `Kaggle/Criteo Display Advertising `_ Dataset diff --git a/torchrec/datasets/criteo.py b/torchrec/datasets/criteo.py index 9472c0e57..3ade6ee29 100644 --- a/torchrec/datasets/criteo.py +++ b/torchrec/datasets/criteo.py @@ -5,7 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import math +# pyre-strict + import os import shutil import time @@ -152,7 +153,7 @@ def criteo_kaggle( """`Kaggle/Criteo Display Advertising `_ Dataset Args: - root (str): local path to train or test dataset file. + path (str): local path to train or test dataset file. row_mapper (Optional[Callable[[List[str]], Any]]): function to apply to each split TSV line. open_kw: options to pass to underlying invocation of iopath.common.file_io.PathManager.open. @@ -182,17 +183,25 @@ def tsv_to_npys( out_dense_file: str, out_sparse_file: str, out_labels_file: str, + dataset_name: str = "criteo_1tb", path_manager_key: str = PATH_MANAGER_KEY, ) -> None: """ Convert one Criteo tsv file to three npy files: one for dense (np.float32), one for sparse (np.int32), and one for labels (np.int32). + The tsv file is expected to be part of the Criteo 1TB Click Logs Dataset ("criteo_1tb") + or the Criteo Kaggle Display Advertising Challenge dataset ("criteo_kaggle"). + + For the "criteo_kaggle" test set, we set the labels to -1 representing filler data, + because label data is not included in the "criteo_kaggle" test set. + Args: in_file (str): Input tsv file path. out_dense_file (str): Output dense npy file path. out_sparse_file (str): Output sparse npy file path. out_labels_file (str): Output labels npy file path. + dataset_name (str): The dataset name. "criteo_1tb" or "criteo_kaggle" is expected. path_manager_key (str): Path manager key used to load from different filesystems. @@ -200,6 +209,20 @@ def tsv_to_npys( None. """ + # Add fake label for criteo_kaggle test set, which does not include label data + def row_mapper_with_fake_label_constant( + row: List[str], + ) -> Tuple[List[int], List[int], int]: + label = -1 + dense = [int(row[i] or "0") for i in range(0, 0 + INT_FEATURE_COUNT)] + sparse = [ + int(row[i] or "0", 16) + for i in range( + 0 + INT_FEATURE_COUNT, 0 + INT_FEATURE_COUNT + CAT_FEATURE_COUNT + ) + ] + return dense, sparse, label + def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]: # Missing values are mapped to zero for both dense and sparse features label = int(row[0] or "0") @@ -213,8 +236,13 @@ def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]: return dense, sparse, label dense, sparse, labels = [], [], [] - for (row_dense, row_sparse, row_label) in CriteoIterDataPipe( - [in_file], row_mapper=row_mapper + for row_dense, row_sparse, row_label in CriteoIterDataPipe( + [in_file], + row_mapper=( + row_mapper + if not (dataset_name == "criteo_kaggle" and "test" in in_file) + else row_mapper_with_fake_label_constant + ), ): dense.append(row_dense) sparse.append(row_sparse) @@ -224,7 +252,7 @@ def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]: # using int64. Numpy will automatically handle dense values >= 2 ** 31. 
dense_np = np.array(dense, dtype=np.int32) del dense - sparse_np = np.array(sparse, dtype=np.int32) + sparse_np = np.array(sparse, dtype=np.int64) del sparse labels_np = np.array(labels, dtype=np.int32) del labels @@ -237,7 +265,7 @@ def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]: labels_np = labels_np.reshape((-1, 1)) path_manager = PathManagerFactory().get(path_manager_key) - for (fname, arr) in [ + for fname, arr in [ (out_dense_file, dense_np), (out_sparse_file, sparse_np), (out_labels_file, labels_np), @@ -267,13 +295,13 @@ def get_shape_from_npy( return shape @staticmethod - def get_file_idx_to_row_range( + def get_file_row_ranges_and_remainder( lengths: List[int], rank: int, world_size: int, start_row: int = 0, last_row: Optional[int] = None, - ) -> Dict[int, Tuple[int, int]]: + ) -> Tuple[Dict[int, Tuple[int, int]], int]: """ Given a rank, world_size, and the lengths (number of rows) for a list of files, return which files and which portions of those files (represented as row ranges @@ -290,9 +318,9 @@ def get_file_idx_to_row_range( world_size (int): world size. Returns: - output (Dict[int, Tuple[int, int]]): Mapping of which files to the range in - those files to be handled by the rank. The keys of this dict are indices - of lengths. + output (Tuple[Dict[int, Tuple[int, int]], int]): First item is a mapping of files + to the range in those files to be handled by the rank. The keys of this dict are indices. + The second item is the remainder of dataset length / world size. """ # All ..._g variables are globals indices (meaning they range from 0 to @@ -302,22 +330,16 @@ def get_file_idx_to_row_range( total_length = sum(lengths) - start_row else: total_length = last_row - start_row + 1 - rows_per_rank = total_length // world_size - remainder = total_length % world_size # Global indices that rank is responsible for. All ranges (left, right) are # inclusive. - if rank < remainder: - rank_left_g = rank * (rows_per_rank + 1) - rank_right_g = (rank + 1) * (rows_per_rank + 1) - 1 - else: - rank_left_g = ( - remainder * (rows_per_rank + 1) + (rank - remainder) * rows_per_rank - ) - rank_right_g = rank_left_g + rows_per_rank - 1 - - rank_left_g += start_row - rank_right_g += start_row + rows_per_rank = total_length // world_size + remainder = total_length % world_size + rows_per_rank = np.array([rows_per_rank for _ in range(world_size)]) + rows_per_rank[:remainder] += 1 + rank_rows_bins_csr = np.cumsum([0] + list(rows_per_rank)) + rank_left_g = rank_rows_bins_csr[rank] + start_row + rank_right_g = rank_rows_bins_csr[rank + 1] - 1 + start_row output = {} @@ -339,7 +361,7 @@ def get_file_idx_to_row_range( overlap_right_l = overlap_right_g - file_left_g output[idx] = (overlap_left_l, overlap_right_l) - return output + return output, remainder @staticmethod def load_npy_range( @@ -411,7 +433,7 @@ def sparse_to_contiguous( that appear less than frequency_threshold amount of times will be remapped to have a value of 1. - Example transformation, frequenchy_threshold of 2: + Example transformation, frequency_threshold of 2: day_0_sparse.npy | col_0 | col_1 | ----------------- @@ -438,8 +460,8 @@ def sparse_to_contiguous( Args: in_files List[str]: Input directory of npy files. - out_dir (str): Output directory of processed npy files. - frequency_threshold: IDs occuring less than this frequency will be remapped to a value of 1. + output_dir (str): Output directory of processed npy files. + frequency_threshold: IDs occurring less than this frequency will be remapped to a value of 1. 
path_manager_key (str): Path manager key used to load from different filesystems. Returns: @@ -647,7 +669,7 @@ def shuffle( curr_first_row = curr_last_row # Directly copy over the last day's files since they will be used for validation and testing. - for (part, input_dir) in [ + for part, input_dir in [ ("sparse", input_dir_sparse), ("dense", input_dir_labels_and_dense), ("labels", input_dir_labels_and_dense), @@ -706,7 +728,10 @@ def __init__( batch_size: int, rank: int, world_size: int, + drop_last: Optional[bool] = False, shuffle_batches: bool = False, + shuffle_training_set: bool = False, + shuffle_training_set_random_seed: int = 0, mmap_mode: bool = False, hashes: Optional[List[int]] = None, path_manager_key: str = PATH_MANAGER_KEY, @@ -718,20 +743,68 @@ def __init__( self.batch_size = batch_size self.rank = rank self.world_size = world_size + self.drop_last = drop_last self.shuffle_batches = shuffle_batches + self.shuffle_training_set = shuffle_training_set + np.random.seed(shuffle_training_set_random_seed) self.mmap_mode = mmap_mode - self.hashes = hashes + self.hashes: np.ndarray = np.array(hashes).reshape((1, CAT_FEATURE_COUNT)) self.path_manager_key = path_manager_key self.path_manager: PathManager = PathManagerFactory().get(path_manager_key) - self._load_data_for_rank() - self.num_rows_per_file: List[int] = [a.shape[0] for a in self.dense_arrs] - self.num_batches: int = math.ceil(sum(self.num_rows_per_file) / batch_size) + if shuffle_training_set and stage == "train": + self._shuffle_and_load_data_for_rank() + self.world_size = 1 + self.rank = 0 + else: + m = "r" if mmap_mode else None + self.dense_arrs: List[np.ndarray] = [ + np.load(f, mmap_mode=m) for f in self.dense_paths + ] + self.sparse_arrs: List[np.ndarray] = [ + np.load(f, mmap_mode=m) for f in self.sparse_paths + ] + self.labels_arrs: List[np.ndarray] = [ + np.load(f, mmap_mode=m) for f in self.labels_paths + ] + len_d0 = len(self.dense_arrs[0]) + second_half_start_index = int(len_d0 // 2 + len_d0 % 2) + if stage == "val": + self.dense_arrs[0] = self.dense_arrs[0][:second_half_start_index, :] + self.sparse_arrs[0] = self.sparse_arrs[0][:second_half_start_index, :] + self.labels_arrs[0] = self.labels_arrs[0][:second_half_start_index, :] + elif stage == "test": + self.dense_arrs[0] = self.dense_arrs[0][second_half_start_index:, :] + self.sparse_arrs[0] = self.sparse_arrs[0][second_half_start_index:, :] + self.labels_arrs[0] = self.labels_arrs[0][second_half_start_index:, :] + # When mmap_mode is enabled, sparse features are hashed when + # samples are batched in def __iter__. 
Otherwise, the dataset has been + # preloaded with sparse features hashed in the preload stage, here: + if not self.mmap_mode and self.hashes is not None: + for sparse_arr in self.sparse_arrs: + sparse_arr %= self.hashes + + self.num_rows_per_file: List[int] = list(map(len, self.dense_arrs)) + total_rows = sum(self.num_rows_per_file) + self.num_full_batches: int = ( + total_rows // batch_size // self.world_size * self.world_size + ) + self.last_batch_sizes: np.ndarray = np.array( + [0 for _ in range(self.world_size)] + ) + remainder = total_rows % (self.world_size * batch_size) + if not self.drop_last and 0 < remainder: + if remainder < self.world_size: + self.num_full_batches -= self.world_size + self.last_batch_sizes += batch_size + else: + self.last_batch_sizes += remainder // self.world_size + self.last_batch_sizes[: remainder % self.world_size] += 1 # These values are the same for the KeyedJaggedTensors in all batches, so they # are computed once here. This avoids extra work from the KeyedJaggedTensor sync # functions. - self._num_ids_in_batch: int = CAT_FEATURE_COUNT * batch_size + self._num_ids_in_batch: int = CAT_FEATURE_COUNT * (batch_size + 1) self.keys: List[str] = DEFAULT_CAT_NAMES self.lengths: torch.Tensor = torch.ones( (self._num_ids_in_batch,), dtype=torch.int32 @@ -739,6 +812,7 @@ def __init__( self.offsets: torch.Tensor = torch.arange( 0, self._num_ids_in_batch + 1, dtype=torch.int32 ) + self._num_ids_in_batch -= CAT_FEATURE_COUNT self.length_per_key: List[int] = CAT_FEATURE_COUNT * [batch_size] self.offset_per_key: List[int] = [ batch_size * i for i in range(CAT_FEATURE_COUNT + 1) @@ -761,7 +835,7 @@ def _load_data_for_rank(self) -> None: dataset_len = samples_in_file - start_row last_row = start_row + dataset_len - 1 - file_idx_to_row_range = BinaryCriteoUtils.get_file_idx_to_row_range( + row_ranges, remainder = BinaryCriteoUtils.get_file_row_ranges_and_remainder( lengths=[ BinaryCriteoUtils.get_shape_from_npy( path, path_manager_key=self.path_manager_key @@ -773,13 +847,13 @@ def _load_data_for_rank(self) -> None: start_row=start_row, last_row=last_row, ) - + self.remainder = remainder self.dense_arrs, self.sparse_arrs, self.labels_arrs = [], [], [] for arrs, paths in zip( [self.dense_arrs, self.sparse_arrs, self.labels_arrs], [self.dense_paths, self.sparse_paths, self.labels_paths], ): - for idx, (range_left, range_right) in file_idx_to_row_range.items(): + for idx, (range_left, range_right) in row_ranges.items(): arrs.append( BinaryCriteoUtils.load_npy_range( paths[idx], @@ -790,14 +864,48 @@ def _load_data_for_rank(self) -> None: ) ) - # When mmap_mode is enabled, the hash is applied in def __iter__, which is - # where samples are batched during training. 
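# Illustrative sketch, separate from the patch above: the preload-time hashing
# is a single broadcast modulo. Because `self.hashes` is reshaped to
# (1, CAT_FEATURE_COUNT), each sparse column is taken modulo its own hash size.
# CAT_FEATURE_COUNT is 3 here only to keep the example small (Criteo uses 26).
import numpy as np

CAT_FEATURE_COUNT = 3
sparse_arr = np.array([[7, 12, 100],
                       [3, 25, 101]], dtype=np.int64)
hashes = np.array([5, 10, 50]).reshape((1, CAT_FEATURE_COUNT))
sparse_arr %= hashes
print(sparse_arr)  # [[2 2 0]
                   #  [3 5 1]]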
- # Otherwise, the ML dataset is preloaded, and the hash is applied here in - # the preload stage, as shown: - if not self.mmap_mode and self.hashes is not None: - hashes_np = np.array(self.hashes).reshape((1, CAT_FEATURE_COUNT)) - for sparse_arr in self.sparse_arrs: - sparse_arr %= hashes_np + def _shuffle_and_load_data_for_rank(self) -> None: + world_size = self.world_size + rank = self.rank + dense_arrs = [np.load(f, mmap_mode="r") for f in self.dense_paths] + sparse_arrs = [np.load(f, mmap_mode="r") for f in self.sparse_paths] + labels_arrs = [np.load(f, mmap_mode="r") for f in self.labels_paths] + num_rows_per_file = list(map(len, dense_arrs)) + total_rows = sum(num_rows_per_file) + permutation_arr = np.random.permutation(total_rows) + self.remainder = total_rows % world_size + rows_per_rank = total_rows // world_size + rows_per_rank = np.array([rows_per_rank for _ in range(world_size)]) + rows_per_rank[: self.remainder] += 1 + rank_rows_bins = np.cumsum(rows_per_rank) + rank_rows_bins_csr = np.cumsum([0] + list(rows_per_rank)) + + rows = rows_per_rank[rank] + d_sample, s_sample, l_sample = ( + dense_arrs[0][0], + sparse_arrs[0][0], + labels_arrs[0][0], + ) + shuffled_dense_arr = np.empty((rows, len(d_sample)), d_sample.dtype) + shuffled_sparse_arr = np.empty((rows, len(s_sample)), s_sample.dtype) + shuffled_labels_arr = np.empty((rows, len(l_sample)), l_sample.dtype) + + day_rows_bins_csr = np.cumsum(np.array([0] + num_rows_per_file)) + for i in range(len(dense_arrs)): + start = day_rows_bins_csr[i] + end = day_rows_bins_csr[i + 1] + indices_to_take = np.where( + rank == np.digitize(permutation_arr[start:end], rank_rows_bins) + )[0] + output_indices = ( + permutation_arr[start + indices_to_take] - rank_rows_bins_csr[rank] + ) + shuffled_dense_arr[output_indices] = dense_arrs[i][indices_to_take] + shuffled_sparse_arr[output_indices] = sparse_arrs[i][indices_to_take] + shuffled_labels_arr[output_indices] = labels_arrs[i][indices_to_take] + self.dense_arrs = [shuffled_dense_arr] + self.sparse_arrs = [shuffled_sparse_arr] + self.labels_arrs = [shuffled_labels_arr] def _np_arrays_to_batch( self, dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray @@ -855,33 +963,47 @@ def append_to_buffer( file_idx = 0 row_idx = 0 batch_idx = 0 - while batch_idx < self.num_batches: - buffer_row_count = 0 if buffer is None else none_throws(buffer)[0].shape[0] - if buffer_row_count == self.batch_size or file_idx == len(self.dense_arrs): - yield self._np_arrays_to_batch(*none_throws(buffer)) + buffer_row_count = 0 + cur_batch_size = ( + self.batch_size if self.num_full_batches > 0 else self.last_batch_sizes[0] + ) + while ( + batch_idx + < self.num_full_batches + (self.last_batch_sizes[0] > 0) * self.world_size + ): + if buffer_row_count == cur_batch_size or file_idx == len(self.dense_arrs): + if batch_idx % self.world_size == self.rank: + yield self._np_arrays_to_batch(*none_throws(buffer)) + buffer = None + buffer_row_count = 0 batch_idx += 1 - buffer = None + if 0 <= batch_idx - self.num_full_batches < self.world_size and ( + self.last_batch_sizes[0] > 0 + ): + cur_batch_size = self.last_batch_sizes[ + batch_idx - self.num_full_batches + ] else: rows_to_get = min( - self.batch_size - buffer_row_count, + cur_batch_size - buffer_row_count, self.num_rows_per_file[file_idx] - row_idx, ) + buffer_row_count += rows_to_get slice_ = slice(row_idx, row_idx + rows_to_get) - dense_inputs = self.dense_arrs[file_idx][slice_, :] - sparse_inputs = self.sparse_arrs[file_idx][slice_, :] - target_labels = 
self.labels_arrs[file_idx][slice_, :] + if batch_idx % self.world_size == self.rank: + dense_inputs = self.dense_arrs[file_idx][slice_, :] + sparse_inputs = self.sparse_arrs[file_idx][slice_, :] + target_labels = self.labels_arrs[file_idx][slice_, :] - if self.mmap_mode and self.hashes is not None: - sparse_inputs = sparse_inputs % np.array(self.hashes).reshape( - (1, CAT_FEATURE_COUNT) - ) + if self.mmap_mode and self.hashes is not None: + sparse_inputs = sparse_inputs % self.hashes - append_to_buffer( - dense_inputs, - sparse_inputs, - target_labels, - ) + append_to_buffer( + dense_inputs, + sparse_inputs, + target_labels, + ) row_idx += rows_to_get if row_idx >= self.num_rows_per_file[file_idx]: @@ -889,4 +1011,4 @@ def append_to_buffer( row_idx = 0 def __len__(self) -> int: - return self.num_batches + return self.num_full_batches // self.world_size + (self.last_batch_sizes[0] > 0) diff --git a/torchrec/datasets/movielens.py b/torchrec/datasets/movielens.py index ce53aa61c..0af5c57da 100644 --- a/torchrec/datasets/movielens.py +++ b/torchrec/datasets/movielens.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import os from typing import Any, Callable, Dict, List, Optional, Union diff --git a/torchrec/datasets/random.py b/torchrec/datasets/random.py index f5743492f..9008622e5 100644 --- a/torchrec/datasets/random.py +++ b/torchrec/datasets/random.py @@ -5,7 +5,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterator, List, Optional +# pyre-strict + +import itertools +import sys +from typing import cast, Iterator, List, Optional import torch from torch.utils.data.dataset import IterableDataset @@ -26,6 +30,8 @@ def __init__( manual_seed: Optional[int] = None, num_generated_batches: int = 10, num_batches: Optional[int] = None, + *, + min_ids_per_features: Optional[List[int]] = None, ) -> None: self.keys = keys @@ -33,6 +39,11 @@ def __init__( self.batch_size = batch_size self.hash_sizes = hash_sizes self.ids_per_features = ids_per_features + self.min_ids_per_features: List[int] = ( + min_ids_per_features + if min_ids_per_features + else [0] * len(ids_per_features) + ) self.num_dense = num_dense self.num_batches = num_batches self.num_generated_batches = num_generated_batches @@ -44,8 +55,8 @@ def __init__( self.generator = None self._generated_batches: List[Batch] = [ - self._generate_batch() - ] * num_generated_batches + self._generate_batch() for _ in range(num_generated_batches) + ] self.batch_index = 0 def __iter__(self) -> "_RandomRecBatch": @@ -70,21 +81,25 @@ def _generate_batch(self) -> Batch: lengths = [] for key_idx, _ in enumerate(self.keys): hash_size = self.hash_sizes[key_idx] - num_ids_in_batch = self.ids_per_features[key_idx] - - values.append( - torch.randint( - high=hash_size, - size=(num_ids_in_batch * self.batch_size,), - generator=self.generator, - ) + min_num_ids = self.min_ids_per_features[key_idx] + max_num_ids = self.ids_per_features[key_idx] + length = torch.randint( + min_num_ids, + max_num_ids + 1, + (self.batch_size,), + dtype=torch.int32, + generator=self.generator, ) - lengths.extend([num_ids_in_batch] * self.batch_size) + value = torch.randint( + 0, hash_size, (cast(int, length.sum()),), generator=self.generator + ) + lengths.append(length) + values.append(value) sparse_features = KeyedJaggedTensor.from_lengths_sync( keys=self.keys, 
values=torch.cat(values), - lengths=torch.tensor(lengths, dtype=torch.int32), + lengths=torch.cat(lengths), ) dense_features = torch.randn( @@ -120,13 +135,15 @@ class RandomRecDataset(IterableDataset[Batch]): modulo this value. hash_sizes (Optional[List[int]]): Max sparse id value per feature in keys. Each sparse ID will be taken modulo the corresponding value from this argument. Note, if this is used, hash_size will be ignored. - ids_per_feature (int): Number of IDs per sparse feature. - ids_per_features (int): Number of IDs per sparse feature in each key. Note, if this is used, ids_per_feature will be ignored. + ids_per_feature (Optional[int]): Number of IDs per sparse feature per sample. + ids_per_features (Optional[List[int]]): Number of IDs per sparse feature per sample in each key. Note, if this is used, ids_per_feature will be ignored. num_dense (int): Number of dense features. manual_seed (int): Seed for deterministic behavior. num_batches: (Optional[int]): Num batches to generate before raising StopIteration num_generated_batches int: Num batches to cache. If num_batches > num_generated batches, then we will cycle to the first generated batch. If this value is negative, batches will be generated on the fly. + min_ids_per_feature (Optional[int]): Minimum number of IDs per features. + min_ids_per_features (Optional[List[int]]): Minimum number of IDs per sparse feature per sample in each key. Note, if this is used, min_ids_per_feature will be ignored. Example:: @@ -144,20 +161,21 @@ def __init__( self, keys: List[str], batch_size: int, - hash_size: Optional[int] = 100, + hash_size: Optional[int] = None, hash_sizes: Optional[List[int]] = None, - ids_per_feature: Optional[int] = 2, + ids_per_feature: Optional[int] = None, ids_per_features: Optional[List[int]] = None, num_dense: int = 50, manual_seed: Optional[int] = None, num_batches: Optional[int] = None, num_generated_batches: int = 10, + min_ids_per_feature: Optional[int] = None, + min_ids_per_features: Optional[List[int]] = None, ) -> None: super().__init__() if hash_sizes is None: - hash_size = hash_size or 100 - hash_sizes = [hash_size] * len(keys) + hash_sizes = [hash_size if hash_size else 100] * len(keys) assert hash_sizes is not None assert len(hash_sizes) == len( @@ -165,10 +183,20 @@ def __init__( ), "length of hash_sizes must be equal to the number of keys" if ids_per_features is None: - ids_per_feature = ids_per_feature or 2 - ids_per_features = [ids_per_feature] * len(keys) + ids_per_features = [ids_per_feature if ids_per_feature else 2] * len(keys) assert ids_per_features is not None + + if min_ids_per_features is None: + min_ids_per_feature = ( + min_ids_per_feature + if min_ids_per_feature is not None + else ids_per_feature + ) + min_ids_per_features = [ + min_ids_per_feature if min_ids_per_feature else 0 + ] * len(keys) + assert len(ids_per_features) == len( keys ), "length of ids_per_features must be equal to the number of keys" @@ -180,9 +208,14 @@ def __init__( ids_per_features=ids_per_features, num_dense=num_dense, manual_seed=manual_seed, - num_batches=num_batches, + num_batches=None, num_generated_batches=num_generated_batches, + min_ids_per_features=min_ids_per_features, ) + self.num_batches: int = cast(int, num_batches if not None else sys.maxsize) def __iter__(self) -> Iterator[Batch]: - return iter(self.batch_generator) + return itertools.islice(iter(self.batch_generator), self.num_batches) + + def __len__(self) -> int: + return self.num_batches diff --git 
a/torchrec/datasets/scripts/contiguous_preproc_criteo.py b/torchrec/datasets/scripts/contiguous_preproc_criteo.py index 473e32d91..59880e227 100644 --- a/torchrec/datasets/scripts/contiguous_preproc_criteo.py +++ b/torchrec/datasets/scripts/contiguous_preproc_criteo.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + # This script preprocesses the sparse feature files (binary npy) to such that # the IDs become contiguous (with frequency thresholding applied). # The results are saved in new binary (npy) files. diff --git a/torchrec/datasets/scripts/npy_preproc_criteo.py b/torchrec/datasets/scripts/npy_preproc_criteo.py index ae0ce9f2b..c03197d7d 100644 --- a/torchrec/datasets/scripts/npy_preproc_criteo.py +++ b/torchrec/datasets/scripts/npy_preproc_criteo.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + # This script preprocesses Criteo dataset tsv files to binary (npy) files. import argparse @@ -23,8 +25,9 @@ def parse_args(argv: List[str]) -> argparse.Namespace: "--input_dir", type=str, required=True, - help="Input directory containing Criteo tsv files. Files in the directory " - "should be named day_{0-23}.", + help="Input directory containing Criteo tsv files." + "For criteo_1tb, files in the directory should be named day_{0-23}." + "For criteo_kaggle, files in the directory should be train.txt & test.txt.", ) parser.add_argument( "--output_dir", @@ -32,6 +35,13 @@ def parse_args(argv: List[str]) -> argparse.Namespace: required=True, help="Output directory to store npy files.", ) + parser.add_argument( + "--dataset_name", + type=str, + choices=["criteo_1tb", "criteo_kaggle"], + default="criteo_1tb", + help="dataset for experiment, current support criteo_1tb, criteo_kaggle", + ) return parser.parse_args(argv) @@ -51,13 +61,21 @@ def main(argv: List[str]) -> None: input_dir = args.input_dir output_dir = args.output_dir - for i in range(24): - in_file_path = os.path.join(input_dir, f"day_{i}") + if args.dataset_name == "criteo_1tb": + in_files_l = [f"day_{i}" for i in range(24)] + out_files_l = in_files_l + else: + # criteo_kaggle code path + in_files_l = ["train.txt", "test.txt"] + out_files_l = ["train", "test"] + + for input, output in zip(in_files_l, out_files_l): + in_file_path = os.path.join(input_dir, input) if not os.path.exists(in_file_path): continue - dense_out_file_path = os.path.join(output_dir, f"day_{i}_dense.npy") - sparse_out_file_path = os.path.join(output_dir, f"day_{i}_sparse.npy") - labels_out_file_path = os.path.join(output_dir, f"day_{i}_labels.npy") + dense_out_file_path = os.path.join(output_dir, output + "_dense.npy") + sparse_out_file_path = os.path.join(output_dir, output + "_sparse.npy") + labels_out_file_path = os.path.join(output_dir, output + "_labels.npy") print( f"Processing {in_file_path}.\nOutput will be saved to\n{dense_out_file_path}" f"\n{sparse_out_file_path}\n{labels_out_file_path}" @@ -67,6 +85,7 @@ def main(argv: List[str]) -> None: dense_out_file_path, sparse_out_file_path, labels_out_file_path, + args.dataset_name, ) print(f"Done processing {in_file_path}.") diff --git a/torchrec/datasets/scripts/tests/test_npy_preproc_criteo.py b/torchrec/datasets/scripts/tests/test_npy_preproc_criteo.py index 16e4f3d98..e7bc0dde4 100644 --- a/torchrec/datasets/scripts/tests/test_npy_preproc_criteo.py +++ 
b/torchrec/datasets/scripts/tests/test_npy_preproc_criteo.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import os import tempfile import unittest diff --git a/torchrec/datasets/test_utils/criteo_test_utils.py b/torchrec/datasets/test_utils/criteo_test_utils.py index 46521d6ed..6f2e3a83d 100644 --- a/torchrec/datasets/test_utils/criteo_test_utils.py +++ b/torchrec/datasets/test_utils/criteo_test_utils.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import contextlib import csv import os diff --git a/torchrec/datasets/tests/test_criteo.py b/torchrec/datasets/tests/test_criteo.py index f35369604..e402bc734 100644 --- a/torchrec/datasets/tests/test_criteo.py +++ b/torchrec/datasets/tests/test_criteo.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import contextlib import math import os @@ -70,7 +72,7 @@ def _validate_dataloader_sample( ) -> None: unbatched_samples = [{} for _ in range(self._sample_len(sample))] for k, batched_values in sample.items(): - for (idx, value) in enumerate(batched_values): + for idx, value in enumerate(batched_values): unbatched_samples[idx][k] = value for sample in unbatched_samples: self._validate_sample(sample, train=train) @@ -142,7 +144,7 @@ def test_tsv_to_npys(self) -> None: self.assertEqual(dense.shape, (num_rows, INT_FEATURE_COUNT)) self.assertEqual(dense.dtype, np.float32) self.assertEqual(sparse.shape, (num_rows, CAT_FEATURE_COUNT)) - self.assertEqual(sparse.dtype, np.int32) + self.assertEqual(sparse.dtype, np.int64) self.assertEqual(labels.shape, (num_rows, 1)) self.assertEqual(labels.dtype, np.int32) @@ -163,15 +165,15 @@ def test_get_shape_from_npy(self) -> None: self.assertEqual(sparse_shape, (num_rows, CAT_FEATURE_COUNT)) self.assertEqual(labels_shape, (num_rows, 1)) - def test_get_file_idx_to_row_range(self) -> None: + def test_get_file_row_ranges_and_remainder(self) -> None: lengths = [14, 17, 20] world_size = 3 expected = [{0: (0, 13), 1: (0, 2)}, {1: (3, 16), 2: (0, 2)}, {2: (3, 19)}] for i in range(world_size): self.assertEqual( - expected[i], - BinaryCriteoUtils.get_file_idx_to_row_range( + (expected[i], 0), + BinaryCriteoUtils.get_file_row_ranges_and_remainder( lengths=lengths, rank=i, world_size=world_size, @@ -379,13 +381,16 @@ def _test_dataset( dataset_len = num_rows // 2 lens = [] - remainder = dataset_len % world_size + remainder = dataset_len % (world_size * batch_size) + total_samples_count = 0 for rank in range(world_size): - incomplete_last_batch_size = ( - dataset_len // world_size % batch_size + int(rank < remainder) - ) - num_samples = dataset_len // world_size + int(rank < remainder) - num_batches = math.ceil(num_samples / batch_size) + end_batch_size = remainder // world_size + num_batches = math.ceil(dataset_len / (world_size * batch_size)) + if 0 < remainder < world_size: + num_batches -= 1 + end_batch_size = batch_size + if rank < (remainder % world_size): + end_batch_size += 1 datapipe = InMemoryBinaryCriteoIterDataPipe( stage=stage, dense_paths=[f[0] for f in files], @@ -400,33 +405,109 @@ def _test_dataset( self.assertEqual(datapipe_len, num_batches) len_ = 0 - samples_count = 0 for batch in datapipe: if stage in ["val", "test"] and len_ == 0 and rank == 0: self.assertEqual( 
batch.dense_features[0, 0].item(), dataset_start, ) - if len_ < num_batches - 1 or incomplete_last_batch_size == 0: + if len_ < num_batches - 1 or end_batch_size == 0: self._validate_batch(batch, batch_size=batch_size) else: - self._validate_batch( - batch, batch_size=incomplete_last_batch_size - ) + self._validate_batch(batch, batch_size=end_batch_size) len_ += 1 - samples_count += batch.dense_features.shape[0] + total_samples_count += batch.dense_features.shape[0] # Check that dataset __len__ matches true length. self.assertEqual(datapipe_len, len_) lens.append(len_) - self.assertEqual(samples_count, num_samples) - # Ensure all ranks return the correct number of batches. - if remainder > 0: - self.assertEqual(len(set(lens[:remainder])), 1) - self.assertEqual(len(set(lens[remainder:])), 1) - else: - self.assertEqual(len(set(lens)), 1) + # Ensure all ranks return the same number of batches. + self.assertEqual(len(set(lens)), 1) + # Ensure the number of samples read match the number of samples in the dataset + self.assertEqual(dataset_len, total_samples_count) + + def _test_in_memory_training_set_shuffle( + self, + rows_per_file: List[int], + batch_size: int, + world_size: int, + random_seed: int = 0, + ) -> None: + with contextlib.ExitStack() as stack: + num_rows_csr = np.cumsum([0] + rows_per_file) + dense, sparse, labels = [], [], [] + for i, _ in enumerate(rows_per_file): + start = num_rows_csr[i] + end = num_rows_csr[i + 1] + dense.append(np.mgrid[start:end, 0:INT_FEATURE_COUNT][0]) + sparse.append(np.mgrid[start:end, 0:CAT_FEATURE_COUNT][0]) + labels.append(np.mgrid[start:end, 0:1][0]) + files = [ + stack.enter_context( + self._create_dataset_npys( + num_rows=num_rows, + dense=dense[i], + sparse=sparse[i], + labels=labels[i], + ) + ) + for i, num_rows in enumerate(rows_per_file) + ] + hashes = [i + 1 for i in range(CAT_FEATURE_COUNT)] + dataset_len = num_rows = sum(rows_per_file) + remainder = dataset_len % world_size + rows_per_rank = dataset_len // world_size + rows_per_rank = np.array([rows_per_rank for _ in range(world_size)]) + rows_per_rank[:remainder] += 1 + lens = [] + total_samples_count = 0 + for rank in range(world_size): + shuffle_dataset_len = rows_per_rank[rank] + num_batches = math.ceil(shuffle_dataset_len / batch_size) + end_batch_size = shuffle_dataset_len % batch_size + + datapipe = InMemoryBinaryCriteoIterDataPipe( + stage="train", + dense_paths=[f[0] for f in files], + sparse_paths=[f[1] for f in files], + labels_paths=[f[2] for f in files], + batch_size=batch_size, + rank=rank, + world_size=world_size, + shuffle_training_set=True, + shuffle_training_set_random_seed=random_seed, + hashes=hashes, + ) + datapipe_len = len(datapipe) + self.assertEqual(datapipe_len, num_batches) + + np.random.seed(random_seed) + permutation_arr = np.random.permutation(num_rows) + src_arr = np.arange(num_rows) + target_shuffled_arr = np.empty(num_rows) + target_shuffled_arr[permutation_arr] = src_arr + target_shuffled_arr_ = np.array( + [b for a, b in sorted(zip(permutation_arr, src_arr))] + ) + np.testing.assert_array_equal(target_shuffled_arr, target_shuffled_arr_) + len_ = 0 + for batch in datapipe: + if len_ < num_batches - 1 or end_batch_size == 0: + self._validate_batch(batch, batch_size=batch_size) + else: + self._validate_batch(batch, batch_size=end_batch_size) + len_ += 1 + total_samples_count += batch.dense_features.shape[0] + + # Check that dataset __len__ matches true length. 
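# Illustrative sketch, separate from the patch above: how
# _shuffle_and_load_data_for_rank assigns shuffled rows to ranks. A global
# permutation gives every source row a destination position, np.digitize maps
# that position to a rank via the cumulative per-rank row counts, and
# subtracting the rank's bin start turns it into a local output index.
# Variable names are illustrative.
import numpy as np

world_size, total_rows = 2, 6
permutation = np.random.permutation(total_rows)
rows_per_rank = np.full(world_size, total_rows // world_size, dtype=np.int64)
rows_per_rank[: total_rows % world_size] += 1
rank_bins = np.cumsum(rows_per_rank)
rank_bins_csr = np.cumsum(np.concatenate(([0], rows_per_rank)))

for rank in range(world_size):
    indices_to_take = np.where(np.digitize(permutation, rank_bins) == rank)[0]
    output_indices = permutation[indices_to_take] - rank_bins_csr[rank]
    # Each rank receives exactly rows_per_rank[rank] rows, locally indexed 0..n-1.
    assert sorted(output_indices) == list(range(rows_per_rank[rank]))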
+ self.assertEqual(datapipe_len, len_) + lens.append(len_) + + # Ensure all ranks return the same number of batches. + self.assertEqual(len(set(lens)), 1) + # Ensure the number of samples read match the number of samples in the dataset + self.assertEqual(dataset_len, total_samples_count) def test_dataset_small_files(self) -> None: self._test_dataset([1] * 20, 4, 2) @@ -434,15 +515,24 @@ def test_dataset_small_files(self) -> None: def test_dataset_random_sized_files(self) -> None: random.seed(0) self._test_dataset([random.randint(1, 100) for _ in range(100)], 16, 3) + # Test case where the global batch size does not evenly divide + # dataset_len but the local batch size does. + self._test_dataset([352], batch_size=32, world_size=8) - def test_dataset_val_and_test_sets(self) -> None: + def test_dataset_train_val_and_test_sets(self) -> None: for stage in ["train", "val", "test"]: # Test cases where batch_size evenly divides dataset_len. - self._test_dataset([100], 1, 2, stage=stage) - self._test_dataset([101], 1, 2, stage=stage) + self._test_dataset([100], batch_size=1, world_size=2, stage=stage) + self._test_dataset([101], batch_size=1, world_size=2, stage=stage) # Test cases where the first and only batch is an incomplete batch. - self._test_dataset([100], 32, 8, stage=stage) - self._test_dataset([101], 32, 8, stage=stage) + self._test_dataset([100], batch_size=32, world_size=8, stage=stage) + self._test_dataset([101], batch_size=32, world_size=8, stage=stage) # Test cases where batches are full size followed by a last batch that is incomplete. - self._test_dataset([10000], 128, 8, stage=stage) - self._test_dataset([10001], 128, 8, stage=stage) + self._test_dataset([10000], batch_size=128, world_size=8, stage=stage) + self._test_dataset([10001], batch_size=128, world_size=8, stage=stage) + + def test_in_memory_training_set_shuffle_driver(self) -> None: + self._test_in_memory_training_set_shuffle([100] * 10, 32, 4, random_seed=0) + self._test_in_memory_training_set_shuffle([100] * 10, 32, 4, random_seed=100) + self._test_in_memory_training_set_shuffle([10000], 128, 8, random_seed=0) + self._test_in_memory_training_set_shuffle([10000], 128, 8, random_seed=100) diff --git a/torchrec/datasets/tests/test_movielens.py b/torchrec/datasets/tests/test_movielens.py index aa121e46c..1c6395f2f 100644 --- a/torchrec/datasets/tests/test_movielens.py +++ b/torchrec/datasets/tests/test_movielens.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import contextlib import csv import os diff --git a/torchrec/datasets/tests/test_random.py b/torchrec/datasets/tests/test_random.py index 62ae7847d..f30d8e0ad 100644 --- a/torchrec/datasets/tests/test_random.py +++ b/torchrec/datasets/tests/test_random.py @@ -5,8 +5,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + +import itertools import unittest +from hypothesis import given, settings, strategies as st + from torchrec.datasets.random import RandomRecDataset @@ -17,6 +22,7 @@ def test_hash_per_feature_ids_per_feature(self) -> None: batch_size=16, hash_sizes=[100, 200], ids_per_features=[100, 200], + min_ids_per_features=[100, 200], num_dense=5, ) @@ -46,6 +52,7 @@ def test_hash_ids_per_feature(self) -> None: batch_size=16, hash_size=100, ids_per_features=[100, 200], + min_ids_per_features=[100, 200], num_dense=5, ) @@ -69,11 +76,29 @@ def test_hash_ids_per_feature(self) -> None: for batch in feat2: self.assertEqual(len(batch), 200) + # pyre-ignore + @given( + batch_size=st.sampled_from([2048, 4096, 8192]), + ) + @settings(max_examples=3, deadline=5000) # expected runtime <=500ms + def test_large_batch_size_deadline(self, batch_size: int) -> None: + dataset = RandomRecDataset( + keys=["feat1", "feat2"], + batch_size=batch_size, + ids_per_features=[10, 20], + hash_size=100, + num_dense=5, + ) + iterator = iter(dataset) + for _ in range(5): + next(iterator) + def test_hash_ids(self) -> None: dataset = RandomRecDataset( keys=["feat1", "feat2"], batch_size=16, hash_size=100, + min_ids_per_feature=50, ids_per_feature=50, num_dense=5, ) @@ -104,6 +129,7 @@ def test_on_fly_batch_generation(self) -> None: batch_size=16, hash_size=100, ids_per_feature=50, + min_ids_per_feature=50, num_dense=5, num_generated_batches=-1, ) @@ -133,3 +159,23 @@ def test_on_fly_batch_generation(self) -> None: self.assertEqual(len(feat2), 16) for batch in feat2: self.assertEqual(len(batch), 50) + + # We want RandomRecDataset to support len() and + # itertools.chain() so the random dataloader can + # run the same code as real dataset dataloaders + # and substitute when wanted without issue. + def test_len_and_itertools_chain(self) -> None: + dataset = RandomRecDataset( + keys=["feat1", "feat2"], + batch_size=16, + hash_size=100, + ids_per_feature=50, + num_dense=5, + num_generated_batches=-1, + num_batches=5, + ) + self.assertEqual(len(dataset), 5) + it = itertools.chain(iter(dataset), iter(dataset)) + for _ in range(10): + next(it) + self.assertRaises(StopIteration, lambda: next(it)) diff --git a/torchrec/datasets/tests/test_utils.py b/torchrec/datasets/tests/test_utils.py index 2311e1473..028b20cb4 100644 --- a/torchrec/datasets/tests/test_utils.py +++ b/torchrec/datasets/tests/test_utils.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import random import unittest from typing import Any, Iterator, List, Tuple diff --git a/torchrec/datasets/utils.py b/torchrec/datasets/utils.py index 333fc49ed..f446d25b8 100644 --- a/torchrec/datasets/utils.py +++ b/torchrec/datasets/utils.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import csv import math import random @@ -39,11 +41,9 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "Batch": labels=self.labels.to(device=device, non_blocking=non_blocking), ) - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. + def record_stream(self, stream: torch.Stream) -> None: self.dense_features.record_stream(stream) self.sparse_features.record_stream(stream) - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. 
self.labels.record_stream(stream) def pin_memory(self) -> "Batch": @@ -79,9 +79,7 @@ def train_filter( decimal_places: int, idx: int, ) -> bool: - return (key_fn(idx) % 10**decimal_places) < round( - train_perc * 10**decimal_places - ) + return (key_fn(idx) % 10**decimal_places) < round(train_perc * 10**decimal_places) def val_filter( diff --git a/torchrec/distributed/__init__.py b/torchrec/distributed/__init__.py index c4fe27849..f514a7156 100644 --- a/torchrec/distributed/__init__.py +++ b/torchrec/distributed/__init__.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + """Torchrec Distributed Torchrec distributed provides the necessary modules and operations to enable model parallelism. @@ -24,7 +26,7 @@ * support for various compute kernels, which are optimized for compute device (CPU/GPU) and may include batching together embedding tables and/or optimizer fusion. - + * pipelined training through `TrainPipelineSparseDist` that overlaps dataloading device transfer (copy to GPU), inter*device communications (input_dist), and computation (forward, backward) for increased performance. @@ -35,6 +37,9 @@ from torchrec.distributed.comm import get_local_rank, get_local_size # noqa from torchrec.distributed.model_parallel import DistributedModelParallel # noqa from torchrec.distributed.train_pipeline import ( # noqa + DataLoadingThread, + EvalPipelineSparseDist, + PrefetchTrainPipelineSparseDist, TrainPipeline, TrainPipelineBase, TrainPipelineSparseDist, diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index e1e6f5e83..1aff0ecf6 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -5,32 +5,63 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import abc import copy +import inspect import itertools +import logging +import tempfile from dataclasses import dataclass -from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union +from typing import ( + Any, + cast, + Dict, + Generic, + Iterator, + List, + Optional, + Tuple, + TypeVar, + Union, +) import torch import torch.distributed as dist -from fbgemm_gpu.split_table_batched_embeddings_ops import ( +from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( + IntNBitTableBatchedEmbeddingBagsCodegen, +) +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( ComputeDevice, DenseTableBatchedEmbeddingBagsCodegen, EmbeddingLocation, - IntNBitTableBatchedEmbeddingBagsCodegen, PoolingMode, + SparseType, SplitTableBatchedEmbeddingBagsCodegen, ) +from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags +from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import ( + PartiallyMaterializedTensor, +) from torch import nn +from torch.distributed._tensor import DTensor, Replicate, Shard as DTensorShard +from torchrec.distributed.comm import get_local_rank, get_node_group_size +from torchrec.distributed.composable.table_batched_embedding_slice import ( + TableBatchedEmbeddingSlice, +) from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict from torchrec.distributed.embedding_types import ( compute_kernel_to_embedding_location, + DTensorMetadata, GroupedEmbeddingConfig, ) +from torchrec.distributed.shards_wrapper import LocalShardsWrapper from torchrec.distributed.types import ( Shard, ShardedTensor, ShardedTensorMetadata, + ShardingType, ShardMetadata, TensorProperties, ) @@ -39,9 +70,142 @@ data_type_to_sparse_type, pooling_type_to_pooling_mode, ) -from torchrec.optim.fused import FusedOptimizer, FusedOptimizerModule +from torchrec.optim.fused import ( + EmptyFusedOptimizer, + FusedOptimizer, + FusedOptimizerModule, +) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +logger: logging.Logger = logging.getLogger(__name__) + + +def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: + """ + Construct SSD TBE params dict from config and fused params dict. + """ + fused_params = config.fused_params or {} + + ssd_tbe_params: Dict[str, Any] = {} + + # drop the non-ssd tbe fused params + ssd_tbe_signature = inspect.signature( + SSDTableBatchedEmbeddingBags.__init__ + ).parameters.keys() + invalid_keys: List[str] = [] + + for key, value in fused_params.items(): + if key not in ssd_tbe_signature: + invalid_keys.append(key) + else: + ssd_tbe_params[key] = value + if len(invalid_keys) > 0: + logger.warning( + f"Dropping {invalid_keys} since they are not valid SSD TBE params." + ) + + # populate number cache sets, aka number of rows of the cache space + if "cache_sets" not in ssd_tbe_params: + cache_load_factor = fused_params.get("cache_load_factor") + if cache_load_factor: + cache_load_factor = fused_params.get("cache_load_factor") + logger.info( + f"Using cache load factor from fused params dict: {cache_load_factor}" + ) + else: + cache_load_factor = 0.2 + + local_rows_sum: int = sum(table.local_rows for table in config.embedding_tables) + ssd_tbe_params["cache_sets"] = max( + int(cache_load_factor * local_rows_sum / ASSOC), 1 + ) + + # populate init min and max + if ( + "ssd_uniform_init_lower" not in ssd_tbe_params + or "ssd_uniform_init_upper" not in ssd_tbe_params + ): + # Right now we do not support a per table init max and min. 
To use + # per table init max and min, either we allow it in SSD TBE, or we + # create one SSD TBE per table. + # TODO: Solve the init problem + mins = [table.get_weight_init_min() for table in config.embedding_tables] + maxs = [table.get_weight_init_max() for table in config.embedding_tables] + ssd_tbe_params["ssd_uniform_init_lower"] = sum(mins) / len( + config.embedding_tables + ) + ssd_tbe_params["ssd_uniform_init_upper"] = sum(maxs) / len( + config.embedding_tables + ) + + if "ssd_storage_directory" not in ssd_tbe_params: + ssd_tbe_params["ssd_storage_directory"] = tempfile.mkdtemp() + else: + if "@local_rank" in ssd_tbe_params["ssd_storage_directory"]: + # assume we have initialized a process group already + ssd_tbe_params["ssd_storage_directory"] = ssd_tbe_params[ + "ssd_storage_directory" + ].replace("@local_rank", str(get_local_rank())) + + if "weights_precision" not in ssd_tbe_params: + weights_precision = data_type_to_sparse_type(config.data_type) + ssd_tbe_params["weights_precision"] = weights_precision + + if "max_l1_cache_size" in fused_params: + l1_cache_size = fused_params.get("max_l1_cache_size") * 1024 * 1024 + max_dim: int = max(table.local_cols for table in config.embedding_tables) + weight_precision_bytes = ssd_tbe_params["weights_precision"].bit_rate() / 8 + max_cache_sets = ( + l1_cache_size / ASSOC / weight_precision_bytes / max_dim + ) # 100MB + + if ssd_tbe_params["cache_sets"] > int(max_cache_sets): + logger.warning( + f"cache_sets {ssd_tbe_params['cache_sets']} is larger than max_cache_sets {max_cache_sets} calculated " + "by max_l1_cache_size, cap at max_cache_sets instead" + ) + ssd_tbe_params["cache_sets"] = int(max_cache_sets) + + return ssd_tbe_params + + +class KeyValueEmbeddingFusedOptimizer(FusedOptimizer): + def __init__( + self, + config: GroupedEmbeddingConfig, + emb_module: SSDTableBatchedEmbeddingBags, + pg: Optional[dist.ProcessGroup] = None, + ) -> None: + """ + Fused optimizer for SSD TBE. Right now it only supports tuning learning + rate. + """ + self._emb_module: SSDTableBatchedEmbeddingBags = emb_module + self._pg = pg + + # TODO: support optimizer states checkpointing once FBGEMM support + # split_optimizer_states API + + # pyre-ignore [33] + state: Dict[Any, Any] = {} + param_group: Dict[str, Any] = { + "params": [], + "lr": emb_module.get_learning_rate(), + } + + params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {} + + super().__init__(params, state, [param_group]) + + def zero_grad(self, set_to_none: bool = False) -> None: + # pyre-ignore [16] + self._emb_module.set_learning_rate(self.param_groups[0]["lr"]) + + # pyre-ignore [2] + def step(self, closure: Any = None) -> None: + # pyre-ignore [16] + self._emb_module.set_learning_rate(self.param_groups[0]["lr"]) + class EmbeddingFusedOptimizer(FusedOptimizer): def __init__( # noqa C901 @@ -49,7 +213,18 @@ def __init__( # noqa C901 config: GroupedEmbeddingConfig, emb_module: SplitTableBatchedEmbeddingBagsCodegen, pg: Optional[dist.ProcessGroup] = None, + create_for_table: Optional[str] = None, + param_weight_for_table: Optional[nn.Parameter] = None, + embedding_weights_by_table: Optional[List[torch.Tensor]] = None, + all_optimizer_states: Optional[List[Dict[str, torch.Tensor]]] = None, ) -> None: + """ + Implementation of a FusedOptimizer. Designed as a base class Embedding kernels + + create_for_table is an optional flag, which if passed in only creates the optimizer for a single table. 
+ This optimizer shares data with the broader optimizer (one per embedding kernel) + and is used to share step and LR changes + """ self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = emb_module self._pg = pg @@ -58,130 +233,256 @@ class ShardParams: optimizer_states: List[Optional[Tuple[torch.Tensor]]] local_metadata: List[ShardMetadata] embedding_weights: List[torch.Tensor] + dtensor_metadata: List[DTensorMetadata] - def to_rowwise_sharded_metadata( - local_metadata: ShardMetadata, - global_metadata: ShardedTensorMetadata, - sharding_dim: int, + def get_optimizer_single_value_shard_metadata_and_global_metadata( + table_global_metadata: ShardedTensorMetadata, optimizer_state: torch.Tensor, - ) -> Tuple[ShardMetadata, ShardedTensorMetadata]: - rw_shards: List[ShardMetadata] = [] - rw_local_shard: ShardMetadata = local_metadata - shards_metadata = global_metadata.shards_metadata - # column-wise sharding - # sort the metadata based on column offset and - # we construct the momentum tensor in row-wise sharded way - if sharding_dim == 1: - shards_metadata = sorted( - shards_metadata, key=lambda shard: shard.shard_offsets[1] + ) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]: + table_global_shards_metadata: List[ShardMetadata] = ( + table_global_metadata.shards_metadata + ) + + table_shard_metadata_to_optimizer_shard_metadata = {} + for offset, table_shard_metadata in enumerate(table_global_shards_metadata): + table_shard_metadata_to_optimizer_shard_metadata[ + table_shard_metadata + ] = ShardMetadata( + shard_sizes=[1], # single value optimizer state + shard_offsets=[offset], # offset increases by 1 for each shard + placement=table_shard_metadata.placement, ) - for idx, shard in enumerate(shards_metadata): - offset = shard.shard_offsets[0] - # for column-wise sharding, we still create row-wise sharded metadata for optimizer - # manually create a row-wise offset + tensor_properties = TensorProperties( + dtype=optimizer_state.dtype, + layout=optimizer_state.layout, + requires_grad=False, + ) + single_value_optimizer_st_metadata = ShardedTensorMetadata( + shards_metadata=list( + table_shard_metadata_to_optimizer_shard_metadata.values() + ), + size=torch.Size([len(table_global_shards_metadata)]), + tensor_properties=tensor_properties, + ) + + return ( + table_shard_metadata_to_optimizer_shard_metadata, + single_value_optimizer_st_metadata, + ) + + def get_optimizer_rowwise_shard_metadata_and_global_metadata( + table_global_metadata: ShardedTensorMetadata, + optimizer_state: torch.Tensor, + sharding_dim: int, + is_grid_sharded: bool = False, + ) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]: + table_global_shards_metadata: List[ShardMetadata] = ( + table_global_metadata.shards_metadata + ) + + if sharding_dim == 1: + # column-wise sharding + # sort the metadata based on column offset and + # we construct the momentum tensor in row-wise sharded way + table_global_shards_metadata = sorted( + table_global_shards_metadata, + key=lambda shard: shard.shard_offsets[1], + ) - if sharding_dim == 1: - offset = idx * shard.shard_sizes[0] - rw_shard = ShardMetadata( - shard_sizes=[shard.shard_sizes[0]], + table_shard_metadata_to_optimizer_shard_metadata = {} + rolling_offset = 0 + for idx, table_shard_metadata in enumerate(table_global_shards_metadata): + offset = table_shard_metadata.shard_offsets[0] + + if is_grid_sharded: + # we use a rolling offset to calculate the current offset for shard to account for uneven row wise case for our shards + offset = 
rolling_offset + rolling_offset += table_shard_metadata.shard_sizes[0] + elif sharding_dim == 1: + # for column-wise sharding, we still create row-wise sharded metadata for optimizer + # manually create a row-wise offset + offset = idx * table_shard_metadata.shard_sizes[0] + + table_shard_metadata_to_optimizer_shard_metadata[ + table_shard_metadata + ] = ShardMetadata( + shard_sizes=[table_shard_metadata.shard_sizes[0]], shard_offsets=[offset], - placement=shard.placement, + placement=table_shard_metadata.placement, ) - if local_metadata == shard: - rw_local_shard = rw_shard + tensor_properties = TensorProperties( + dtype=optimizer_state.dtype, + layout=optimizer_state.layout, + requires_grad=False, + ) + len_rw_shards = ( + len(table_shard_metadata_to_optimizer_shard_metadata) + if sharding_dim == 1 and not is_grid_sharded + else 1 + ) + # for grid sharding, the row dimension is replicated CW shard times + grid_shard_nodes = ( + len(table_global_shards_metadata) // get_node_group_size() + if is_grid_sharded + else 1 + ) + rowwise_optimizer_st_metadata = ShardedTensorMetadata( + shards_metadata=list( + table_shard_metadata_to_optimizer_shard_metadata.values() + ), + size=torch.Size( + [table_global_metadata.size[0] * len_rw_shards * grid_shard_nodes] + ), + tensor_properties=tensor_properties, + ) - rw_shards.append(rw_shard) + return ( + table_shard_metadata_to_optimizer_shard_metadata, + rowwise_optimizer_st_metadata, + ) + + def get_optimizer_pointwise_shard_metadata_and_global_metadata( + table_global_metadata: ShardedTensorMetadata, + optimizer_state: torch.Tensor, + ) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]: + table_global_shards_metadata: List[ShardMetadata] = ( + table_global_metadata.shards_metadata + ) + + table_shard_metadata_to_optimizer_shard_metadata = {} + for table_shard_metadata in table_global_shards_metadata: + table_shard_metadata_to_optimizer_shard_metadata[ + table_shard_metadata + ] = ShardMetadata( + shard_sizes=table_shard_metadata.shard_sizes, + shard_offsets=table_shard_metadata.shard_offsets, + placement=table_shard_metadata.placement, + ) tensor_properties = TensorProperties( dtype=optimizer_state.dtype, - layout=global_metadata.tensor_properties.layout, + layout=optimizer_state.layout, requires_grad=False, - memory_format=global_metadata.tensor_properties.memory_format, - pin_memory=global_metadata.tensor_properties.pin_memory, ) - len_rw_shards = len(shards_metadata) if sharding_dim == 1 else 1 - rw_metadata = ShardedTensorMetadata( - shards_metadata=rw_shards, - size=torch.Size([global_metadata.size[0] * len_rw_shards]), + pointwise_optimizer_st_metadata = ShardedTensorMetadata( + shards_metadata=list( + table_shard_metadata_to_optimizer_shard_metadata.values() + ), + size=table_global_metadata.size, tensor_properties=tensor_properties, ) - return rw_local_shard, rw_metadata + + return ( + table_shard_metadata_to_optimizer_shard_metadata, + pointwise_optimizer_st_metadata, + ) # pyre-ignore [33] state: Dict[Any, Any] = {} param_group: Dict[str, Any] = { "params": [], - "lr": emb_module.optimizer_args.learning_rate, + "lr": emb_module.get_learning_rate(), } params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {} # Fused optimizers use buffers (they don't use autograd) and we want to make sure # that state_dict look identical to no-fused version. 
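# Illustrative arithmetic, separate from the patch above: the row-wise
# optimizer-state metadata built for a column-wise sharded table. If a
# 100-row table is split into 3 column shards, each shard's momentum becomes
# a 100-row row-wise shard at offset idx * 100, and the reconstructed
# optimizer-state tensor spans 100 * 3 = 300 rows.
shard_rows, num_cw_shards = 100, 3
momentum_offsets = [idx * shard_rows for idx in range(num_cw_shards)]
momentum_global_rows = shard_rows * num_cw_shards
print(momentum_offsets, momentum_global_rows)  # [0, 100, 200] 300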
- table_to_shard_params = {} + table_to_shard_params: Dict[str, ShardParams] = {} - split_embedding_weights = emb_module.split_embedding_weights() - split_optimizer_states = emb_module.split_optimizer_states() + embedding_weights_by_table = ( + embedding_weights_by_table or emb_module.split_embedding_weights() + ) - for table_config, optimizer_states, weight in itertools.zip_longest( + all_optimizer_states = all_optimizer_states or emb_module.get_optimizer_state() + optimizer_states_keys_by_table: Dict[str, List[torch.Tensor]] = {} + for ( + table_config, + optimizer_states, + weight, + ) in itertools.zip_longest( config.embedding_tables, - split_optimizer_states, - split_embedding_weights, + all_optimizer_states, + embedding_weights_by_table, ): + # When EmbeddingFusedOptimizer is created for composability, only create state + if create_for_table is not None and create_for_table != table_config.name: + continue if table_config.name not in table_to_shard_params: table_to_shard_params[table_config.name] = ShardParams( - optimizer_states=[], local_metadata=[], embedding_weights=[] + optimizer_states=[], + local_metadata=[], + embedding_weights=[], + dtensor_metadata=[], ) - + optimizer_state_values = None if optimizer_states: - for optimizer_state in optimizer_states: - assert table_config.local_rows == optimizer_state.size(0) - + optimizer_state_values = tuple(optimizer_states.values()) + for optimizer_state_value in optimizer_state_values: + assert ( + table_config.local_rows == optimizer_state_value.size(0) + or optimizer_state_value.nelement() == 1 # single value state + ) + optimizer_states_keys_by_table[table_config.name] = list( + optimizer_states.keys() + ) local_metadata = table_config.local_metadata table_to_shard_params[table_config.name].optimizer_states.append( - optimizer_states + optimizer_state_values ) table_to_shard_params[table_config.name].local_metadata.append( local_metadata ) + table_to_shard_params[table_config.name].dtensor_metadata.append( + table_config.dtensor_metadata + ) table_to_shard_params[table_config.name].embedding_weights.append(weight) seen_tables = set() for table_config in config.embedding_tables: + if create_for_table is not None and create_for_table != table_config.name: + continue if table_config.name in seen_tables: continue seen_tables.add(table_config.name) - table_config_global_metadata: Optional[ - ShardedTensorMetadata - ] = copy.deepcopy(table_config.global_metadata) + table_config_global_metadata: Optional[ShardedTensorMetadata] = ( + copy.deepcopy(table_config.global_metadata) + ) shard_params: ShardParams = table_to_shard_params[table_config.name] assert table_config_global_metadata is not None - local_weight_shards = [] - for local_weight, local_metadata in zip( - shard_params.embedding_weights, shard_params.local_metadata - ): - local_weight_shards.append(Shard(local_weight, local_metadata)) - table_config_global_metadata.tensor_properties.dtype = ( - local_weight.dtype - ) - table_config_global_metadata.tensor_properties.requires_grad = ( - local_weight.requires_grad + if create_for_table is None: + local_weight_shards = [] + for local_weight, local_metadata in zip( + shard_params.embedding_weights, shard_params.local_metadata + ): + local_weight_shards.append(Shard(local_weight, local_metadata)) + table_config_global_metadata.tensor_properties.dtype = ( + local_weight.dtype + ) + table_config_global_metadata.tensor_properties.requires_grad = ( + local_weight.requires_grad + ) + # TODO share this logic to create the same 
TableBatchedEmbeddingSlice in FusedModules below + weight = ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=local_weight_shards, + sharded_tensor_metadata=table_config_global_metadata, + process_group=self._pg, ) - - weight = ShardedTensor._init_from_local_shards_and_global_metadata( - local_shards=local_weight_shards, - sharded_tensor_metadata=table_config_global_metadata, - process_group=self._pg, - ) + param_key = table_config.name + ".weight" + else: + assert ( + param_weight_for_table is not None + ), "param_weight_for_table cannot be None when using create_for_table" + weight = param_weight_for_table + param_key = "" state[weight] = {} param_group["params"].append(weight) - param_key = table_config.name + ".weight" params[param_key] = weight # Setting optimizer states @@ -189,54 +490,128 @@ def to_rowwise_sharded_metadata( 1 if table_config.local_cols != table_config.embedding_dim else 0 ) + is_grid_sharded: bool = ( + True + if table_config.local_cols != table_config.embedding_dim + and table_config.local_rows != table_config.num_embeddings + else False + ) + if all( - [opt_state is not None for opt_state in shard_params.optimizer_states] + opt_state is not None for opt_state in shard_params.optimizer_states ): # pyre-ignore - def get_momentum(momentum_idx: int) -> ShardedTensor: + def get_sharded_optim_state( + momentum_idx: int, state_key: str + ) -> Union[ShardedTensor, DTensor]: assert momentum_idx > 0 momentum_local_shards: List[Shard] = [] + optimizer_sharded_tensor_metadata: ShardedTensorMetadata + + # pyre-ignore [16] + optim_state = shard_params.optimizer_states[0][momentum_idx - 1] + if ( + optim_state.nelement() == 1 and state_key != "momentum1" + ): # special handling for backward compatibility, momentum1 is rowwise state for rowwise_adagrad + # single value state: one value per table + ( + table_shard_metadata_to_optimizer_shard_metadata, + optimizer_sharded_tensor_metadata, + ) = get_optimizer_single_value_shard_metadata_and_global_metadata( + table_config.global_metadata, + optim_state, + ) + elif optim_state.dim() == 1: + # rowwise state: param.shape[0] == state.shape[0], state.shape[1] == 1 + ( + table_shard_metadata_to_optimizer_shard_metadata, + optimizer_sharded_tensor_metadata, + ) = get_optimizer_rowwise_shard_metadata_and_global_metadata( + table_config.global_metadata, + optim_state, + sharding_dim, + is_grid_sharded, + ) + else: + # pointwise state: param.shape == state.shape + ( + table_shard_metadata_to_optimizer_shard_metadata, + optimizer_sharded_tensor_metadata, + ) = get_optimizer_pointwise_shard_metadata_and_global_metadata( + table_config.global_metadata, + optim_state, + ) - sharded_tensor_metadata = table_config.global_metadata - for (optimizer_state, shard_param_local_metadata) in zip( + for optimizer_state, table_shard_local_metadata in zip( shard_params.optimizer_states, shard_params.local_metadata ): - - local_metadata = table_config.local_metadata - - if optimizer_state[momentum_idx - 1].dim() == 1: - ( - local_metadata, - sharded_tensor_metadata, - ) = to_rowwise_sharded_metadata( - shard_param_local_metadata, - table_config.global_metadata, - sharding_dim, + local_optimizer_shard_metadata = ( + table_shard_metadata_to_optimizer_shard_metadata[ + table_shard_local_metadata + ] + ) + momentum_local_shards.append( + Shard( optimizer_state[momentum_idx - 1], + local_optimizer_shard_metadata, ) - - assert local_metadata is not None - assert sharded_tensor_metadata is not None - momentum_local_shards.append( - 
Shard(optimizer_state[momentum_idx - 1], local_metadata) ) - return ShardedTensor._init_from_local_shards_and_global_metadata( - local_shards=momentum_local_shards, - sharded_tensor_metadata=sharded_tensor_metadata, - process_group=self._pg, - ) + # Convert optimizer state to DTensor if enabled + if table_config.dtensor_metadata: + # if rowwise state we do Shard(0), regardless of how the table is sharded + if optim_state.dim() == 1: + stride = (1,) + placements = ( + (Replicate(), DTensorShard(0)) + if table_config.dtensor_metadata.mesh.ndim == 2 + else (DTensorShard(0),) + ) + else: + stride = table_config.dtensor_metadata.stride + placements = table_config.dtensor_metadata.placements + + return DTensor.from_local( + local_tensor=LocalShardsWrapper( + local_shards=[x.tensor for x in momentum_local_shards], + local_offsets=[ # pyre-ignore[6] + x.metadata.shard_offsets + for x in momentum_local_shards + ], + ), + device_mesh=table_config.dtensor_metadata.mesh, + placements=placements, + shape=optimizer_sharded_tensor_metadata.size, + stride=stride, + run_check=False, + ) + else: + # TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata. + return ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=momentum_local_shards, + sharded_tensor_metadata=optimizer_sharded_tensor_metadata, + process_group=self._pg, + ) - if all( - # pyre-ignore - [len(opt_state) >= 1 for opt_state in shard_params.optimizer_states] - ): - state[weight][f"{table_config.name}.momentum1"] = get_momentum(1) - if all( + num_states: int = min( # pyre-ignore - [len(opt_state) >= 2 for opt_state in shard_params.optimizer_states] - ): - state[weight][f"{table_config.name}.momentum2"] = get_momentum(2) + [len(opt_state) for opt_state in shard_params.optimizer_states] + ) + optimizer_state_keys = [] + if num_states > 0: + optimizer_state_keys = optimizer_states_keys_by_table[ + table_config.name + ] + for cur_state_idx in range(0, num_states): + if cur_state_idx == 0: + # for backward compatibility + cur_state_key = "momentum1" + else: + cur_state_key = optimizer_state_keys[cur_state_idx] + + state[weight][f"{table_config.name}.{cur_state_key}"] = ( + get_sharded_optim_state(cur_state_idx + 1, cur_state_key) + ) super().__init__(params, state, [param_group]) @@ -249,8 +624,131 @@ def step(self, closure: Any = None) -> None: # pyre-ignore [16] self._emb_module.set_learning_rate(self.param_groups[0]["lr"]) + def set_optimizer_step(self, step: int) -> None: + self._emb_module.set_optimizer_step(step) + + def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None: + self._emb_module.update_hyper_parameters(params_dict) + + +def _gen_named_parameters_by_table_ssd( + emb_module: SSDTableBatchedEmbeddingBags, + table_name_to_count: Dict[str, int], + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, +) -> Iterator[Tuple[str, nn.Parameter]]: + """ + Return an empty tensor to indicate that the table is on remote device. 
+ """ + for table in config.embedding_tables: + table_name = table.name + # placeholder + weight: nn.Parameter = nn.Parameter(torch.empty(0)) + # pyre-ignore + weight._in_backward_optimizers = [EmptyFusedOptimizer()] + yield (table_name, weight) + + +def _gen_named_parameters_by_table_ssd_pmt( + emb_module: SSDTableBatchedEmbeddingBags, + table_name_to_count: Dict[str, int], + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, +) -> Iterator[Tuple[str, nn.Parameter]]: + """ + Return an iterator over module parameters that are embedding tables, yielding both the table + name as well as the parameter itself. The embedding table is in the form of + PartiallyMaterializedTensor to support windowed access. + """ + pmts = emb_module.split_embedding_weights() + for table_config, pmt in zip(config.embedding_tables, pmts): + table_name = table_config.name + emb_table = pmt + weight: nn.Parameter = nn.Parameter(emb_table) + # pyre-ignore + weight._in_backward_optimizers = [EmptyFusedOptimizer()] + yield (table_name, weight) + + +def _gen_named_parameters_by_table_fused( + emb_module: SplitTableBatchedEmbeddingBagsCodegen, + table_name_to_count: Dict[str, int], + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, +) -> Iterator[Tuple[str, TableBatchedEmbeddingSlice]]: + # TODO: move logic to FBGEMM to avoid accessing fbgemm internals + # Cache embedding_weights_by_table + embedding_weights_by_table = emb_module.split_embedding_weights() + # Cache all_optimizer_states + all_optimizer_states = emb_module.get_optimizer_state() + for t_idx, (rows, dim, location, _) in enumerate(emb_module.embedding_specs): + table_name = config.embedding_tables[t_idx].name + if table_name not in table_name_to_count: + continue + table_count = table_name_to_count.pop(table_name) + if emb_module.weights_precision == SparseType.INT8: + dim += emb_module.int8_emb_row_dim_offset + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... + offset = emb_module.weights_physical_offsets[t_idx] + weights: torch.Tensor + if location == EmbeddingLocation.DEVICE.value: + # pyre-fixme[9]: weights has type `Tensor`; used as `Union[Module, Tensor]`. + weights = emb_module.weights_dev + elif location == EmbeddingLocation.HOST.value: + # pyre-fixme[9]: weights has type `Tensor`; used as `Union[Module, Tensor]`. + weights = emb_module.weights_host + else: + # pyre-fixme[9]: weights has type `Tensor`; used as `Union[Module, Tensor]`. 
+ weights = emb_module.weights_uvm + weight = TableBatchedEmbeddingSlice( + data=weights, + start_offset=offset, + end_offset=offset + table_count * rows * dim, + num_embeddings=-1, + embedding_dim=dim, + ) + # this reuses logic in EmbeddingFusedOptimizer but is per table + # pyre-ignore + weight._in_backward_optimizers = [ + EmbeddingFusedOptimizer( + config=config, + emb_module=emb_module, + pg=pg, + create_for_table=table_name, + param_weight_for_table=weight, + embedding_weights_by_table=embedding_weights_by_table, + all_optimizer_states=all_optimizer_states, + ) + ] + yield (table_name, weight) + + +def _gen_named_parameters_by_table_dense( + emb_module: DenseTableBatchedEmbeddingBagsCodegen, + table_name_to_count: Dict[str, int], + config: GroupedEmbeddingConfig, +) -> Iterator[Tuple[str, TableBatchedEmbeddingSlice]]: + # TODO: move logic to FBGEMM to avoid accessing fbgemm internals + for t_idx, (rows, dim) in enumerate(emb_module.embedding_specs): + table_name = config.embedding_tables[t_idx].name + if table_name not in table_name_to_count: + continue + table_count = table_name_to_count.pop(table_name) + offset = emb_module.weights_physical_offsets[t_idx] + weight = TableBatchedEmbeddingSlice( + data=emb_module.weights, + start_offset=offset, + end_offset=offset + table_count * rows * dim, + num_embeddings=-1, + embedding_dim=dim, + ) + yield (table_name, weight) + + +SplitWeightType = TypeVar("SplitWeightType") -class BaseBatchedEmbedding(BaseEmbedding): + +class BaseBatchedEmbedding(BaseEmbedding, Generic[SplitWeightType]): def __init__( self, config: GroupedEmbeddingConfig, @@ -266,29 +764,50 @@ def __init__( self._weight_init_mins: List[float] = [] self._weight_init_maxs: List[float] = [] self._num_embeddings: List[int] = [] + self._embedding_dims: List[int] = [] self._local_cols: List[int] = [] + self._row_offset: List[int] = [] + self._col_offset: List[int] = [] self._feature_table_map: List[int] = [] - - for idx, config in enumerate(self._config.embedding_tables): - self._local_rows.append(config.local_rows) - self._weight_init_mins.append(config.get_weight_init_min()) - self._weight_init_maxs.append(config.get_weight_init_max()) - self._num_embeddings.append(config.num_embeddings) - self._local_cols.append(config.local_cols) - self._feature_table_map.extend([idx] * config.num_features()) + self.table_name_to_count: Dict[str, int] = {} + self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {} + + for idx, table_config in enumerate(self._config.embedding_tables): + self._local_rows.append(table_config.local_rows) + self._weight_init_mins.append(table_config.get_weight_init_min()) + self._weight_init_maxs.append(table_config.get_weight_init_max()) + self._num_embeddings.append(table_config.num_embeddings) + self._embedding_dims.append(table_config.embedding_dim) + self._row_offset.append( + table_config.local_metadata.shard_offsets[0] + if table_config.local_metadata + and len(table_config.local_metadata.shard_offsets) > 0 + else 0 + ) + self._col_offset.append( + table_config.local_metadata.shard_offsets[1] + if table_config.local_metadata + and len(table_config.local_metadata.shard_offsets) > 1 + else 0 + ) + self._local_cols.append(table_config.local_cols) + self._feature_table_map.extend([idx] * table_config.num_features()) + if table_config.name not in self.table_name_to_count: + self.table_name_to_count[table_config.name] = 0 + self.table_name_to_count[table_config.name] += 1 def init_parameters(self) -> None: # initialize embedding weights assert 
len(self._num_embeddings) == len(self.split_embedding_weights()) - for (rows, emb_dim, weight_init_min, weight_init_max, param) in zip( + for rows, emb_dim, weight_init_min, weight_init_max, param in zip( self._local_rows, self._local_cols, self._weight_init_mins, self._weight_init_maxs, self.split_embedding_weights(), ): - assert param.shape == (rows, emb_dim) - param.data.uniform_( + assert param.shape == (rows, emb_dim) # pyre-ignore[16] + param.data.uniform_( # pyre-ignore[16] weight_init_min, weight_init_max, ) @@ -309,13 +828,14 @@ def state_dict( self.flush() return get_state_dict( self._config.embedding_tables, + # pyre-ignore self.split_embedding_weights(), self._pg, destination, prefix, ) - def split_embedding_weights(self) -> List[torch.Tensor]: + def split_embedding_weights(self) -> List[SplitWeightType]: return self.emb_module.split_embedding_weights() @property @@ -326,8 +846,7 @@ def emb_module( DenseTableBatchedEmbeddingBagsCodegen, SplitTableBatchedEmbeddingBagsCodegen, IntNBitTableBatchedEmbeddingBagsCodegen, - ]: - ... + ]: ... @property def config(self) -> GroupedEmbeddingConfig: @@ -336,6 +855,9 @@ def config(self) -> GroupedEmbeddingConfig: def flush(self) -> None: pass + def purge(self) -> None: + pass + def named_split_embedding_weights( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[Tuple[str, torch.Tensor]]: @@ -349,8 +871,179 @@ def named_split_embedding_weights( key = append_prefix(prefix, f"{config.name}.weight") yield key, param + def named_parameters_by_table( + self, + ) -> Iterator[Tuple[str, TableBatchedEmbeddingSlice]]: + """ + Like named_parameters(), but yields table_name and embedding_weights which are wrapped in TableBatchedEmbeddingSlice. + For a single table with multiple shards (i.e CW) these are combined into one table/weight. + Used in composability. + """ + for name, param in self._param_per_table.items(): + yield name, param + -class BatchedFusedEmbedding(BaseBatchedEmbedding, FusedOptimizerModule): +class KeyValueEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule): + def __init__( + self, + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__(config, pg, device) + + assert ( + len(config.embedding_tables) > 0 + ), "Expected to see at least one table in SSD TBE, but found 0." + assert ( + len({table.embedding_dim for table in config.embedding_tables}) == 1 + ), "Currently we expect all tables in SSD TBE to have the same embedding dimension." 
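# Illustrative sketch only, not part of the patch above: one way a caller might consume
# the named_parameters_by_table() API introduced in BaseBatchedEmbedding. `kernel` is a
# hypothetical variable standing in for any BaseBatchedEmbedding / BaseBatchedEmbeddingBag
# instance; the helper name is invented for this example.
def collect_table_param_shapes(kernel) -> dict:
    # Each yielded param is a TableBatchedEmbeddingSlice: a view into the fused weight
    # tensor covering one logical table (column-wise shards combined into a single entry).
    return {
        name: tuple(param.shape)
        for name, param in kernel.named_parameters_by_table()
    }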
+ + ssd_tbe_params = _populate_ssd_tbe_params(config) + compute_kernel = config.embedding_tables[0].compute_kernel + embedding_location = compute_kernel_to_embedding_location(compute_kernel) + + self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags( + embedding_specs=list(zip(self._local_rows, self._local_cols)), + feature_table_map=self._feature_table_map, + ssd_cache_location=embedding_location, + pooling_mode=PoolingMode.NONE, + **ssd_tbe_params, + ).to(device) + + self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer( + config, + self._emb_module, + pg, + ) + self._param_per_table: Dict[str, nn.Parameter] = dict( + _gen_named_parameters_by_table_ssd_pmt( + emb_module=self._emb_module, + table_name_to_count=self.table_name_to_count.copy(), + config=self._config, + pg=pg, + ) + ) + self.init_parameters() + + def init_parameters(self) -> None: + """ + An advantage of SSD TBE is that we don't need to init weights. Hence skipping. + """ + pass + + @property + def emb_module( + self, + ) -> SSDTableBatchedEmbeddingBags: + return self._emb_module + + @property + def fused_optimizer(self) -> FusedOptimizer: + """ + SSD Embedding fuses backward with backward. + """ + return self._optim + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + no_snapshot: bool = True, + ) -> Dict[str, Any]: + """ + Args: + no_snapshot (bool): the tensors in the returned dict are + PartiallyMaterializedTensors. this argument controls wether the + PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the + PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the + PartiallyMaterializedTensor has a RocksDB snapshot handle + """ + # in the case no_snapshot=False, a flush is required. we rely on the flush operation in + # ShardedEmbeddingBagCollection._pre_state_dict_hook() + + emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot) + emb_table_config_copy = copy.deepcopy(self._config.embedding_tables) + for emb_table in emb_table_config_copy: + emb_table.local_metadata.placement._device = torch.device("cpu") + ret = get_state_dict( + emb_table_config_copy, + emb_tables, + self._pg, + destination, + prefix, + ) + return ret + + def named_parameters( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + """ + Only allowed ways to get state_dict. 
+ """ + for name, tensor in self.named_split_embedding_weights( + prefix, recurse, remove_duplicate + ): + # hack before we support optimizer on sharded parameter level + # can delete after PEA deprecation + # pyre-ignore [6] + param = nn.Parameter(tensor) + # pyre-ignore + param._in_backward_optimizers = [EmptyFusedOptimizer()] + yield name, param + + # pyre-ignore [15] + def named_split_embedding_weights( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]: + assert ( + remove_duplicate + ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights" + for config, tensor in zip( + self._config.embedding_tables, + self.split_embedding_weights(), + ): + key = append_prefix(prefix, f"{config.name}.weight") + yield key, tensor + + def get_named_split_embedding_weights_snapshot( + self, prefix: str = "" + ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]: + """ + Return an iterator over embedding tables, yielding both the table name as well as the embedding + table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid + RocksDB snapshot to support windowed access. + """ + for config, tensor in zip( + self._config.embedding_tables, + self.split_embedding_weights(no_snapshot=False), + ): + key = append_prefix(prefix, f"{config.name}") + yield key, tensor + + def flush(self) -> None: + """ + Flush the embeddings in cache back to SSD. Should be pretty expensive. + """ + self.emb_module.flush() + + def purge(self) -> None: + """ + Reset the cache space. This is needed when we load state dict. + """ + # TODO: move the following to SSD TBE. + self.emb_module.lxu_cache_weights.zero_() + self.emb_module.lxu_cache_state.fill_(-1) + + # pyre-ignore [15] + def split_embedding_weights( + self, no_snapshot: bool = True + ) -> List[PartiallyMaterializedTensor]: + return self.emb_module.split_embedding_weights(no_snapshot) + + +class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule): def __init__( self, config: GroupedEmbeddingConfig, @@ -367,6 +1060,11 @@ def __init__( managed.append( compute_kernel_to_embedding_location(table.compute_kernel) ) + elif device is not None and device.type == "mtia": + compute_devices.append(ComputeDevice.MTIA) + # Set EmbeddingLocation.HOST to make embedding op in FBGEMM choose CPU path. + # But the tensor will still be created on MTIA with device type "mtia". + managed.append(EmbeddingLocation.HOST) else: compute_devices.append(ComputeDevice.CPU) managed.append(EmbeddingLocation.HOST) @@ -386,6 +1084,15 @@ def __init__( pooling_mode=PoolingMode.NONE, weights_precision=weights_precision, device=device, + table_names=[t.name for t in config.embedding_tables], + embedding_shard_info=list( + zip( + self._num_embeddings, + self._embedding_dims, + self._row_offset, + self._col_offset, + ) + ), **fused_params, ) ) @@ -394,6 +1101,14 @@ def __init__( self._emb_module, pg, ) + self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict( + _gen_named_parameters_by_table_fused( + emb_module=self._emb_module, + table_name_to_count=self.table_name_to_count.copy(), + config=self._config, + pg=pg, + ) + ) self.init_parameters() @property @@ -413,18 +1128,30 @@ def named_buffers( By convention, fused parameters are designated as buffers because they no longer have gradients available to external optimizers. 
""" - return self.named_split_embedding_weights(prefix, recurse, remove_duplicate) + # TODO can delete this override once SEA is removed + yield from () def named_parameters( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[Tuple[str, nn.Parameter]]: - yield from () + for name, tensor in self.named_split_embedding_weights( + prefix, recurse, remove_duplicate + ): + # hack before we support optimizer on sharded parameter level + # can delete after SEA deprecation + param = nn.Parameter(tensor) + # pyre-ignore + param._in_backward_optimizers = [EmptyFusedOptimizer()] + yield name, param def flush(self) -> None: self._emb_module.flush() + def purge(self) -> None: + self._emb_module.reset_cache_states() + -class BatchedDenseEmbedding(BaseBatchedEmbedding): +class BatchedDenseEmbedding(BaseBatchedEmbedding[torch.Tensor]): def __init__( self, config: GroupedEmbeddingConfig, @@ -434,18 +1161,29 @@ def __init__( super().__init__(config, pg, device) weights_precision = data_type_to_sparse_type(config.data_type) + fused_params = config.fused_params or {} + output_dtype = fused_params.get("output_dtype", SparseType.FP32) + use_cpu: bool = ( + device is None + or device.type == "cpu" + or (not (torch.cuda.is_available() or torch.mtia.is_available())) + ) self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = ( DenseTableBatchedEmbeddingBagsCodegen( list(zip(self._local_rows, self._local_cols)), feature_table_map=self._feature_table_map, pooling_mode=PoolingMode.NONE, - use_cpu=device is None - or device.type == "cpu" - or not torch.cuda.is_available(), + use_cpu=use_cpu, weights_precision=weights_precision, + output_dtype=output_dtype, + use_mtia=device is not None and device.type == "mtia", + ) + ) + self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict( + _gen_named_parameters_by_table_dense( + self._emb_module, self.table_name_to_count.copy(), self._config ) ) - self.init_parameters() @property @@ -470,49 +1208,73 @@ def named_parameters( ) -class BaseBatchedEmbeddingBag(BaseEmbedding): +class BaseBatchedEmbeddingBag(BaseEmbedding, Generic[SplitWeightType]): def __init__( self, config: GroupedEmbeddingConfig, pg: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, + sharding_type: Optional[ShardingType] = None, ) -> None: super().__init__() torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}") self._config = config self._pg = pg - self._pooling: PoolingMode = pooling_type_to_pooling_mode(config.pooling) + self._pooling: PoolingMode = pooling_type_to_pooling_mode( + config.pooling, sharding_type # pyre-ignore[6] + ) self._local_rows: List[int] = [] self._weight_init_mins: List[float] = [] self._weight_init_maxs: List[float] = [] self._num_embeddings: List[int] = [] + self._embedding_dims: List[int] = [] self._local_cols: List[int] = [] + self._row_offset: List[int] = [] + self._col_offset: List[int] = [] self._feature_table_map: List[int] = [] self._emb_names: List[str] = [] self._lengths_per_emb: List[int] = [] - - for idx, config in enumerate(self._config.embedding_tables): - self._local_rows.append(config.local_rows) - self._weight_init_mins.append(config.get_weight_init_min()) - self._weight_init_maxs.append(config.get_weight_init_max()) - self._num_embeddings.append(config.num_embeddings) - self._local_cols.append(config.local_cols) - self._feature_table_map.extend([idx] * config.num_features()) + self.table_name_to_count: Dict[str, int] = {} + self._param_per_table: Dict[str, 
TableBatchedEmbeddingSlice] = {} + + for idx, table_config in enumerate(self._config.embedding_tables): + self._local_rows.append(table_config.local_rows) + self._weight_init_mins.append(table_config.get_weight_init_min()) + self._weight_init_maxs.append(table_config.get_weight_init_max()) + self._num_embeddings.append(table_config.num_embeddings) + self._embedding_dims.append(table_config.embedding_dim) + self._row_offset.append( + table_config.local_metadata.shard_offsets[0] + if table_config.local_metadata + and len(table_config.local_metadata.shard_offsets) > 0 + else 0 + ) + self._col_offset.append( + table_config.local_metadata.shard_offsets[1] + if table_config.local_metadata + and len(table_config.local_metadata.shard_offsets) > 1 + else 0 + ) + self._local_cols.append(table_config.local_cols) + self._feature_table_map.extend([idx] * table_config.num_features()) + if table_config.name not in self.table_name_to_count: + self.table_name_to_count[table_config.name] = 0 + self.table_name_to_count[table_config.name] += 1 def init_parameters(self) -> None: # initialize embedding weights assert len(self._num_embeddings) == len(self.split_embedding_weights()) - for (rows, emb_dim, weight_init_min, weight_init_max, param) in zip( + for rows, emb_dim, weight_init_min, weight_init_max, param in zip( self._local_rows, self._local_cols, self._weight_init_mins, self._weight_init_maxs, self.split_embedding_weights(), ): - assert param.shape == (rows, emb_dim) - param.data.uniform_( + assert param.shape == (rows, emb_dim) # pyre-ignore[16] + param.data.uniform_( # pyre-ignore[16] weight_init_min, weight_init_max, ) @@ -521,11 +1283,26 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: weights = features.weights_or_none() if weights is not None and not torch.is_floating_point(weights): weights = None - return self.emb_module( - indices=features.values().long(), - offsets=features.offsets().long(), - per_sample_weights=weights, - ) + if features.variable_stride_per_key() and isinstance( + self.emb_module, + ( + SplitTableBatchedEmbeddingBagsCodegen, + DenseTableBatchedEmbeddingBagsCodegen, + SSDTableBatchedEmbeddingBags, + ), + ): + return self.emb_module( + indices=features.values().long(), + offsets=features.offsets().long(), + per_sample_weights=weights, + batch_size_per_feature_per_rank=features.stride_per_key_per_rank(), + ) + else: + return self.emb_module( + indices=features.values().long(), + offsets=features.offsets().long(), + per_sample_weights=weights, + ) # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently. def state_dict( @@ -537,13 +1314,14 @@ def state_dict( self.flush() return get_state_dict( self._config.embedding_tables, + # pyre-ignore self.split_embedding_weights(), self._pg, destination, prefix, ) - def split_embedding_weights(self) -> List[torch.Tensor]: + def split_embedding_weights(self) -> List[SplitWeightType]: return self.emb_module.split_embedding_weights() @property @@ -554,8 +1332,7 @@ def emb_module( DenseTableBatchedEmbeddingBagsCodegen, SplitTableBatchedEmbeddingBagsCodegen, IntNBitTableBatchedEmbeddingBagsCodegen, - ]: - ... + ]: ... 
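# Illustrative sketch only, not part of the patch: the forward() branch above passes
# batch_size_per_feature_per_rank to the TBE kernel when the input KJT carries variable
# batch sizes per feature (VBE). The tensors below are made-up example data; depending on
# the torchrec version, _variable_stride_per_key may need to be set explicitly, as the
# benchmark added later in this diff does.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([1, 2, 1, 1]),
    stride_per_key_per_rank=[[2], [2]],  # per key, one batch size per rank
)
if features.variable_stride_per_key():
    # VBE path: the kernel needs the per-feature, per-rank batch sizes
    vbe_kwargs = {"batch_size_per_feature_per_rank": features.stride_per_key_per_rank()}
else:
    vbe_kwargs = {}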
@property def config(self) -> GroupedEmbeddingConfig: @@ -564,28 +1341,211 @@ def config(self) -> GroupedEmbeddingConfig: def flush(self) -> None: pass + def purge(self) -> None: + pass + def named_split_embedding_weights( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[Tuple[str, torch.Tensor]]: assert ( remove_duplicate ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights" - for config, param in zip( + for config, tensor in zip( self._config.embedding_tables, self.emb_module.split_embedding_weights(), ): key = append_prefix(prefix, f"{config.name}.weight") - yield key, param + yield key, tensor + def named_parameters_by_table( + self, + ) -> Iterator[Tuple[str, TableBatchedEmbeddingSlice]]: + """ + Like named_parameters(), but yields table_name and embedding_weights which are wrapped in TableBatchedEmbeddingSlice. + For a single table with multiple shards (i.e CW) these are combined into one table/weight. + Used in composability. + """ + for name, param in self._param_per_table.items(): + yield name, param -class BatchedFusedEmbeddingBag(BaseBatchedEmbeddingBag, FusedOptimizerModule): + +class KeyValueEmbeddingBag(BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule): def __init__( self, config: GroupedEmbeddingConfig, pg: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, + sharding_type: Optional[ShardingType] = None, ) -> None: - super().__init__(config, pg, device) + super().__init__(config, pg, device, sharding_type) + + assert ( + len(config.embedding_tables) > 0 + ), "Expected to see at least one table in SSD TBE, but found 0." + assert ( + len({table.embedding_dim for table in config.embedding_tables}) == 1 + ), "Currently we expect all tables in SSD TBE to have the same embedding dimension." + + ssd_tbe_params = _populate_ssd_tbe_params(config) + compute_kernel = config.embedding_tables[0].compute_kernel + embedding_location = compute_kernel_to_embedding_location(compute_kernel) + + self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags( + embedding_specs=list(zip(self._local_rows, self._local_cols)), + feature_table_map=self._feature_table_map, + ssd_cache_location=embedding_location, + pooling_mode=self._pooling, + **ssd_tbe_params, + ).to(device) + + logger.info( + f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}" + ) + + self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer( + config, + self._emb_module, + pg, + ) + self._param_per_table: Dict[str, nn.Parameter] = dict( + _gen_named_parameters_by_table_ssd_pmt( + emb_module=self._emb_module, + table_name_to_count=self.table_name_to_count.copy(), + config=self._config, + pg=pg, + ) + ) + self.init_parameters() + + def init_parameters(self) -> None: + """ + An advantage of SSD TBE is that we don't need to init weights. Hence + skipping. + """ + pass + + @property + def emb_module( + self, + ) -> SSDTableBatchedEmbeddingBags: + return self._emb_module + + @property + def fused_optimizer(self) -> FusedOptimizer: + """ + SSD Embedding fuses backward with backward. + """ + return self._optim + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + no_snapshot: bool = True, + ) -> Dict[str, Any]: + """ + Args: + no_snapshot (bool): the tensors in the returned dict are + PartiallyMaterializedTensors. 
this argument controls wether the + PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the + PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the + PartiallyMaterializedTensor has a RocksDB snapshot handle + """ + # in the case no_snapshot=False, a flush is required. we rely on the flush operation in + # ShardedEmbeddingBagCollection._pre_state_dict_hook() + + emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot) + emb_table_config_copy = copy.deepcopy(self._config.embedding_tables) + for emb_table in emb_table_config_copy: + emb_table.local_metadata.placement._device = torch.device("cpu") + ret = get_state_dict( + emb_table_config_copy, + emb_tables, + self._pg, + destination, + prefix, + ) + return ret + + def named_parameters( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + """ + Only allowed ways to get state_dict. + """ + for name, tensor in self.named_split_embedding_weights( + prefix, recurse, remove_duplicate + ): + # hack before we support optimizer on sharded parameter level + # can delete after PEA deprecation + # pyre-ignore [6] + param = nn.Parameter(tensor) + # pyre-ignore + param._in_backward_optimizers = [EmptyFusedOptimizer()] + yield name, param + + # pyre-ignore [15] + def named_split_embedding_weights( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]: + assert ( + remove_duplicate + ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights" + for config, tensor in zip( + self._config.embedding_tables, + self.split_embedding_weights(), + ): + key = append_prefix(prefix, f"{config.name}.weight") + yield key, tensor + + def get_named_split_embedding_weights_snapshot( + self, prefix: str = "" + ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]: + """ + Return an iterator over embedding tables, yielding both the table name as well as the embedding + table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid + RocksDB snapshot to support windowed access. + """ + for config, tensor in zip( + self._config.embedding_tables, + self.split_embedding_weights(no_snapshot=False), + ): + key = append_prefix(prefix, f"{config.name}") + yield key, tensor + + def flush(self) -> None: + """ + Flush the embeddings in cache back to SSD. Should be pretty expensive. + """ + self.emb_module.flush() + + def purge(self) -> None: + """ + Reset the cache space. This is needed when we load state dict. + """ + # TODO: move the following to SSD TBE. 
+ self.emb_module.lxu_cache_weights.zero_() + self.emb_module.lxu_cache_state.fill_(-1) + + # pyre-ignore [15] + def split_embedding_weights( + self, no_snapshot: bool = True + ) -> List[PartiallyMaterializedTensor]: + return self.emb_module.split_embedding_weights(no_snapshot) + + +class BatchedFusedEmbeddingBag( + BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule +): + def __init__( + self, + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + sharding_type: Optional[ShardingType] = None, + ) -> None: + super().__init__(config, pg, device, sharding_type) managed: List[EmbeddingLocation] = [] compute_devices: List[ComputeDevice] = [] @@ -599,6 +1559,11 @@ def __init__( managed.append( compute_kernel_to_embedding_location(table.compute_kernel) ) + elif device is not None and device.type == "mtia": + compute_devices.append(ComputeDevice.MTIA) + # Set EmbeddingLocation.HOST to make embedding op in FBGEMM choose CPU path. + # But the tensor will still be created on MTIA with device type "mtia". + managed.append(EmbeddingLocation.HOST) else: compute_devices.append(ComputeDevice.CPU) managed.append(EmbeddingLocation.HOST) @@ -607,7 +1572,6 @@ def __init__( fused_params = config.fused_params or {} if "cache_precision" not in fused_params: fused_params["cache_precision"] = weights_precision - self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = ( SplitTableBatchedEmbeddingBagsCodegen( embedding_specs=list( @@ -617,6 +1581,15 @@ def __init__( pooling_mode=self._pooling, weights_precision=weights_precision, device=device, + table_names=[t.name for t in config.embedding_tables], + embedding_shard_info=list( + zip( + self._num_embeddings, + self._embedding_dims, + self._row_offset, + self._col_offset, + ) + ), **fused_params, ) ) @@ -625,7 +1598,14 @@ def __init__( self._emb_module, pg, ) - + self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict( + _gen_named_parameters_by_table_fused( + emb_module=self._emb_module, + table_name_to_count=self.table_name_to_count.copy(), + config=self._config, + pg=pg, + ) + ) self.init_parameters() @property @@ -645,39 +1625,63 @@ def named_buffers( By convention, fused parameters are designated as buffers because they no longer have gradients available to external optimizers. 
""" - return self.named_split_embedding_weights(prefix, recurse, remove_duplicate) + # TODO can delete this override once SEA is removed + yield from () def named_parameters( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[Tuple[str, nn.Parameter]]: - yield from () + for name, tensor in self.named_split_embedding_weights( + prefix, recurse, remove_duplicate + ): + # hack before we support optimizer on sharded parameter level + # can delete after PEA deprecation + param = nn.Parameter(tensor) + # pyre-ignore + param._in_backward_optimizers = [EmptyFusedOptimizer()] + yield name, param def flush(self) -> None: self._emb_module.flush() + def purge(self) -> None: + self._emb_module.reset_cache_states() + -class BatchedDenseEmbeddingBag(BaseBatchedEmbeddingBag): +class BatchedDenseEmbeddingBag(BaseBatchedEmbeddingBag[torch.Tensor]): def __init__( self, config: GroupedEmbeddingConfig, pg: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, + sharding_type: Optional[ShardingType] = None, ) -> None: - super().__init__(config, pg, device) + super().__init__(config, pg, device, sharding_type) weights_precision = data_type_to_sparse_type(config.data_type) + fused_params = config.fused_params or {} + output_dtype = fused_params.get("output_dtype", SparseType.FP32) + use_cpu: bool = ( + device is None + or device.type == "cpu" + or (not (torch.cuda.is_available() or torch.mtia.is_available())) + ) self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = ( DenseTableBatchedEmbeddingBagsCodegen( list(zip(self._local_rows, self._local_cols)), feature_table_map=self._feature_table_map, pooling_mode=self._pooling, - use_cpu=device is None - or device.type == "cpu" - or not torch.cuda.is_available(), + use_cpu=use_cpu, weights_precision=weights_precision, + output_dtype=output_dtype, + use_mtia=device is not None and device.type == "mtia", + ) + ) + self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict( + _gen_named_parameters_by_table_dense( + self._emb_module, self.table_name_to_count.copy(), self._config ) ) - self.init_parameters() @property diff --git a/examples/datasets/tests/__init__.py b/torchrec/distributed/benchmark/__init__.py similarity index 100% rename from examples/datasets/tests/__init__.py rename to torchrec/distributed/benchmark/__init__.py diff --git a/torchrec/distributed/benchmark/benchmark_inference.py b/torchrec/distributed/benchmark/benchmark_inference.py new file mode 100644 index 000000000..09e9ba10b --- /dev/null +++ b/torchrec/distributed/benchmark/benchmark_inference.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +#!/usr/bin/env python3 + +import argparse +import logging +import os +import time +from functools import partial +from typing import List, Tuple + +import torch + +from torchrec.distributed.benchmark.benchmark_utils import ( + benchmark_module, + BenchmarkResult, + CompileMode, + DLRM_NUM_EMBEDDINGS_PER_FEATURE, + EMBEDDING_DIM, + get_tables, + init_argparse_and_args, + write_report, +) +from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType +from torchrec.distributed.test_utils.infer_utils import ( + TestQuantEBCSharder, + TestQuantECSharder, +) +from torchrec.quant.embedding_modules import ( + EmbeddingBagCollection as QuantEmbeddingBagCollection, + EmbeddingCollection as QuantEmbeddingCollection, +) + +logger: logging.Logger = logging.getLogger() + + +BENCH_SHARDING_TYPES: List[ShardingType] = [ + ShardingType.TABLE_WISE, + ShardingType.ROW_WISE, + # ShardingType.COLUMN_WISE, + # TODO: CW with FXJIT takes long time while profiling, doesn't cause an issue with no profiling in automatic benchmark +] + +BENCH_COMPILE_MODES: List[CompileMode] = [ + CompileMode.EAGER, + CompileMode.FX_SCRIPT, +] + + +TABLE_SIZES: List[Tuple[int, int]] = [ + (num_embeddings, EMBEDDING_DIM) + for num_embeddings in DLRM_NUM_EMBEDDINGS_PER_FEATURE +] + +IGNORE_ARGNAME = ["output_dir", "embedding_config_json", "max_num_embeddings"] + + +def benchmark_qec(args: argparse.Namespace, output_dir: str) -> List[BenchmarkResult]: + tables = get_tables(TABLE_SIZES, is_pooled=False) + sharder = TestQuantECSharder( + sharding_type="", + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in tables], + ) + + module = QuantEmbeddingCollection( + # pyre-ignore [6] + tables=tables, + device=torch.device("cpu"), + quant_state_dict_split_scale_bias=True, + ) + + args_kwargs = { + argname: getattr(args, argname) + for argname in dir(args) + # Don't include output_dir since output_dir was modified + if not argname.startswith("_") and argname not in IGNORE_ARGNAME + } + + return benchmark_module( + module=module, + sharder=sharder, + sharding_types=BENCH_SHARDING_TYPES, + compile_modes=BENCH_COMPILE_MODES, + tables=tables, + output_dir=output_dir, + **args_kwargs, + ) + + +def benchmark_qebc(args: argparse.Namespace, output_dir: str) -> List[BenchmarkResult]: + tables = get_tables(TABLE_SIZES) + sharder = TestQuantEBCSharder( + sharding_type="", + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in tables], + ) + + module = QuantEmbeddingBagCollection( + # pyre-ignore [6] + tables=tables, + is_weighted=False, + device=torch.device("cpu"), + quant_state_dict_split_scale_bias=True, + ) + + args_kwargs = { + argname: getattr(args, argname) + for argname in dir(args) + # Don't include output_dir since output_dir was modified + if not argname.startswith("_") and argname not in IGNORE_ARGNAME + } + + return benchmark_module( + module=module, + sharder=sharder, + sharding_types=BENCH_SHARDING_TYPES, + compile_modes=BENCH_COMPILE_MODES, + tables=tables, + output_dir=output_dir, + **args_kwargs, + ) + + +def benchmark_qec_unsharded( + args: argparse.Namespace, output_dir: str +) -> List[BenchmarkResult]: + tables = get_tables(TABLE_SIZES, is_pooled=False) + sharder = TestQuantECSharder( + sharding_type="", + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in tables], + ) + + module = QuantEmbeddingCollection( + # pyre-ignore [6] + tables=tables, + device=torch.device("cpu"), 
+ quant_state_dict_split_scale_bias=True, + ) + + args_kwargs = { + argname: getattr(args, argname) + for argname in dir(args) + # Don't include output_dir since output_dir was modified + if not argname.startswith("_") and argname not in IGNORE_ARGNAME + } + + return benchmark_module( + module=module, + sharder=sharder, + sharding_types=[], + compile_modes=BENCH_COMPILE_MODES, + tables=tables, + output_dir=output_dir, + benchmark_unsharded=True, # benchmark unsharded module + **args_kwargs, + ) + + +def benchmark_qebc_unsharded( + args: argparse.Namespace, output_dir: str +) -> List[BenchmarkResult]: + tables = get_tables(TABLE_SIZES) + sharder = TestQuantEBCSharder( + sharding_type="", + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in tables], + ) + + module = QuantEmbeddingBagCollection( + # pyre-ignore [6] + tables=tables, + is_weighted=False, + device=torch.device("cpu"), + quant_state_dict_split_scale_bias=True, + ) + + args_kwargs = { + argname: getattr(args, argname) + for argname in dir(args) + # Don't include output_dir since output_dir was modified + if not argname.startswith("_") and argname not in IGNORE_ARGNAME + } + + return benchmark_module( + module=module, + sharder=sharder, + sharding_types=[], + compile_modes=BENCH_COMPILE_MODES, + tables=tables, + output_dir=output_dir, + benchmark_unsharded=True, # benchmark unsharded module + **args_kwargs, + ) + + +def main() -> None: + args: argparse.Namespace = init_argparse_and_args() + + num_requests = args.bench_iters * args.batch_size * args.num_benchmarks + datetime_sfx: str = time.strftime("%Y%m%dT%H%M%S") + + output_dir = args.output_dir + if not os.path.exists(output_dir): + # Create output directory if not exist + os.mkdir(output_dir) + + benchmark_results_per_module = [] + write_report_funcs_per_module = [] + + module_names = [ + "QuantEmbeddingBagCollection", + "QuantEmbeddingCollection", + ] + + # Only do unsharded QEBC/QEC benchmark when using CPU device + if args.device_type == "cpu": + module_names.append("unshardedQuantEmbeddingBagCollection") + module_names.append("unshardedQuantEmbeddingCollection") + + for module_name in module_names: + output_dir = args.output_dir + f"/run_{datetime_sfx}" + if module_name == "QuantEmbeddingBagCollection": + output_dir += "_qebc" + benchmark_func = benchmark_qebc + elif module_name == "QuantEmbeddingCollection": + output_dir += "_qec" + benchmark_func = benchmark_qec + elif module_name == "unshardedQuantEmbeddingBagCollection": + output_dir += "_uqebc" + benchmark_func = benchmark_qebc_unsharded + else: + output_dir += "_uqec" + benchmark_func = benchmark_qec_unsharded + + if not os.path.exists(output_dir): + # Place all outputs under the datetime folder + os.mkdir(output_dir) + + tables_info = "\nTABLE SIZES QUANT:" + for i, (num, dim) in enumerate(TABLE_SIZES): + mb = int(float(num * dim) / 1024 / 1024) + tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] u8: {mb:6}Mb" + + report: str = ( + f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n" + ) + report += f"Module: {module_name}\n" + report += tables_info + report += "\n" + + report += f"num_requests:{num_requests:8}\n" + report_file: str = f"{output_dir}/run.report" + + # Save results to output them once benchmarking is all done + benchmark_results_per_module.append(benchmark_func(args, output_dir)) + write_report_funcs_per_module.append( + partial( + write_report, + report_file=report_file, + report_str=report, + num_requests=num_requests, + 
) + ) + + for i, write_report_func in enumerate(write_report_funcs_per_module): + write_report_func(benchmark_results_per_module[i]) + + +def invoke_main() -> None: + logging.basicConfig() + logging.getLogger().setLevel(logging.DEBUG) + + main() + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/torchrec/distributed/benchmark/benchmark_set_sharding_context_post_a2a.py b/torchrec/distributed/benchmark/benchmark_set_sharding_context_post_a2a.py new file mode 100644 index 000000000..dc393bcc0 --- /dev/null +++ b/torchrec/distributed/benchmark/benchmark_set_sharding_context_post_a2a.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +from typing import Any, List + +import click +import torch +from torchrec.distributed.benchmark.benchmark_utils import benchmark_func +from torchrec.distributed.embedding import EmbeddingCollectionContext +from torchrec.distributed.embedding_sharding import _set_sharding_context_post_a2a +from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +def _set_sharding_context_post_a2a_previous( + kjts: List[KeyedJaggedTensor], + ctx: EmbeddingCollectionContext, +) -> None: + for kjt, sharding_context in zip(kjts, getattr(ctx, "sharding_contexts", [])): + if ( + hasattr(sharding_context, "batch_size_per_rank_per_feature") + and kjt.variable_stride_per_key() + and kjt.stride_per_key_per_rank() + ): + sharding_context.batch_size_per_rank_per_feature = [ + [ + kjt.stride_per_key_per_rank()[i][j] + for i in range(len(kjt.stride_per_key_per_rank())) + ] + for j in range(len(kjt.stride_per_key_per_rank()[0])) + ] + + +# buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_set_sharding_context_post_a2a -- --num_list=0 --num_keys=0 | grep set_sharding_context_post_a2a + + +@click.command() +@click.option("--num_list", default=100) +@click.option("--num_keys", default=100) +def main( + num_list: int, + num_keys: int, +) -> None: + if num_list == 0 and num_keys == 0: + for num_list in [100, 1000, 10000]: + for num_keys in [10, 100]: + op_bench(num_list, num_keys, _set_sharding_context_post_a2a_previous) + op_bench(num_list, num_keys, _set_sharding_context_post_a2a) + else: + op_bench(num_list, num_keys, _set_sharding_context_post_a2a_previous) + op_bench(num_list, num_keys, _set_sharding_context_post_a2a) + + +def op_bench( + num_list: int, + num_keys: int, + func_to_benchmark: Any, # pyre-ignore[2] +) -> None: + kjts = [ + KeyedJaggedTensor( + keys=["dummy_id"] * num_keys, + values=torch.IntTensor([1] * num_keys), + lengths=torch.IntTensor([1] * num_keys), + stride_per_key_per_rank=[[1]] * num_keys, + ) + for _ in range(num_list) + ] + for kjt in kjts: + kjt._variable_stride_per_key = True + ctx = EmbeddingCollectionContext( + sharding_contexts=[ + SequenceShardingContext(batch_size_per_rank_per_feature=[]) + for _ in range(num_list) + ] + ) + + bench_inputs = [] + + result = benchmark_func( + name=f"{func_to_benchmark.__name__}-{num_list}-{num_keys}", + bench_inputs=bench_inputs, + prof_inputs=bench_inputs, + num_benchmarks=10, + num_profiles=2, + profile_dir=".", + world_size=1, + func_to_benchmark=func_to_benchmark, + benchmark_func_kwargs={"kjts": kjts, "ctx": ctx}, + rank=0, + 
pre_gpu_load=0, + device_type="cpu", + ) + print(result) + + +if __name__ == "__main__": + main() diff --git a/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py b/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py new file mode 100644 index 000000000..8af1f9a46 --- /dev/null +++ b/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +from typing import Dict, List + +import click + +import torch + +from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType + +from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( + BoundsCheckMode, + EmbeddingLocation, + PoolingMode, +) +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + ComputeDevice, + SplitTableBatchedEmbeddingBagsCodegen, +) +from torchrec.distributed.benchmark.benchmark_utils import benchmark_func +from torchrec.distributed.test_utils.test_model import ModelInput +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +@click.command() +@click.option("--num-embeddings", default=100_000) +@click.option("--embedding-dim", default=128) +@click.option("--num-tables", default=4) +@click.option("--batch-size", default=262144) +@click.option("--bag-size", default=10) +def main( + num_embeddings: int, + embedding_dim: int, + num_tables: int, + batch_size: int, + bag_size: int, +) -> None: + if embedding_dim == 0: + for embedding_dim in [4, 4, 8, 16, 32, 64, 128, 256, 512, 1024]: + op_bench(num_embeddings, embedding_dim, num_tables, batch_size, bag_size) + else: + op_bench(num_embeddings, embedding_dim, num_tables, batch_size, bag_size) + + +def op_bench( + num_embeddings: int, + embedding_dim: int, + num_tables: int, + batch_size: int, + bag_size: int, +) -> None: + emb = SplitTableBatchedEmbeddingBagsCodegen( + [ + ( + num_embeddings, + embedding_dim, + EmbeddingLocation.DEVICE, + ( + ComputeDevice.CUDA + if torch.cuda.is_available() + else ComputeDevice.CPU + ), + ) + ] + * num_tables, + optimizer=OptimType.EXACT_ADAGRAD, + learning_rate=0.1, + eps=0.1, + weights_precision=SparseType.FP32, + stochastic_rounding=False, + output_dtype=SparseType.FP32, + pooling_mode=PoolingMode.SUM, + bounds_check_mode=BoundsCheckMode.NONE, + ) + + def _func_to_benchmark( + kjts: List[Dict[str, KeyedJaggedTensor]], + model: torch.nn.Module, + ) -> torch.Tensor: + kjt = kjts[0]["feature"] + return model.forward(kjt.values(), kjt.offsets()) + + # breakpoint() # import fbvscode; fbvscode.set_trace() + tables = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + name="table_0", + feature_names=["feature_0"], + ) + ] + inputs = ModelInput.generate( + tables=tables, + weighted_tables=[], + batch_size=batch_size, + world_size=1, + num_float_features=0, + pooling_avg=10, + device=torch.device("cuda"), + )[0].idlist_features + + result = benchmark_func( + name=f"SplitTableBatchedEmbeddingBagsCodegen-{num_embeddings}-{embedding_dim}-{num_tables}-{batch_size}-{bag_size}", + bench_inputs=[{"feature": inputs}], + prof_inputs=[{"feature": inputs}], + num_benchmarks=10, + num_profiles=10, + profile_dir=".", + world_size=1, + 
func_to_benchmark=_func_to_benchmark, + benchmark_func_kwargs={"model": emb}, + rank=0, + pre_gpu_load=3, + ) + + print(result) + + +if __name__ == "__main__": + main() diff --git a/torchrec/distributed/benchmark/benchmark_train.py b/torchrec/distributed/benchmark/benchmark_train.py new file mode 100644 index 000000000..15ea780f2 --- /dev/null +++ b/torchrec/distributed/benchmark/benchmark_train.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +import argparse +import logging +import os +import time +from functools import partial +from typing import List, Optional, Tuple + +import torch + +from torchrec.distributed.benchmark.benchmark_utils import ( + benchmark_module, + BenchmarkResult, + CompileMode, + get_tables, + init_argparse_and_args, + set_embedding_config, + write_report, +) +from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType +from torchrec.distributed.test_utils.test_model import TestEBCSharder +from torchrec.distributed.types import DataType +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +logger: logging.Logger = logging.getLogger() + + +BENCH_SHARDING_TYPES: List[ShardingType] = [ + ShardingType.TABLE_WISE, + ShardingType.ROW_WISE, + ShardingType.COLUMN_WISE, +] + +BENCH_COMPILE_MODES: List[CompileMode] = [ + CompileMode.EAGER, + # CompileMode.FX_SCRIPT, +] + + +def training_func_to_benchmark( + model: torch.nn.Module, + bench_inputs: List[KeyedJaggedTensor], + optimizer: Optional[torch.optim.Optimizer], +) -> None: + for bench_input in bench_inputs: + pooled_embeddings = model(bench_input) + vals = [] + for _name, param in pooled_embeddings.to_dict().items(): + vals.append(param) + torch.cat(vals, dim=1).sum().backward() + if optimizer: + optimizer.step() + optimizer.zero_grad() + + +def benchmark_ebc( + tables: List[Tuple[int, int]], + args: argparse.Namespace, + output_dir: str, + pooling_configs: Optional[List[int]] = None, + variable_batch_embeddings: bool = False, +) -> List[BenchmarkResult]: + table_configs = get_tables(tables, data_type=DataType.FP32) + sharder = TestEBCSharder( + sharding_type="", # sharding_type gets populated during benchmarking + kernel_type=EmbeddingComputeKernel.FUSED.value, + ) + + # we initialize the embedding tables using CUDA, because when the table is large, + # CPU initialization will be prohibitively long. We then copy the module back + # to CPU because this module will be sent over subprocesses via multiprocessing, + # and we don't want to create an extra CUDA context on GPU0 for each subprocess. 
+ # we also need to release the memory in the parent process (empty_cache) + module = EmbeddingBagCollection( + # pyre-ignore [6] + tables=table_configs, + is_weighted=False, + device=torch.device("cuda"), + ).cpu() + + torch.cuda.empty_cache() + + IGNORE_ARGNAME = ["output_dir", "embedding_config_json", "max_num_embeddings"] + optimizer = torch.optim.SGD(module.parameters(), lr=0.02) + args_kwargs = { + argname: getattr(args, argname) + for argname in dir(args) + # Don't include output_dir since output_dir was modified + if not argname.startswith("_") and argname not in IGNORE_ARGNAME + } + + if pooling_configs: + args_kwargs["pooling_configs"] = pooling_configs + + args_kwargs["variable_batch_embeddings"] = variable_batch_embeddings + + return benchmark_module( + module=module, + sharder=sharder, + sharding_types=BENCH_SHARDING_TYPES, + compile_modes=BENCH_COMPILE_MODES, + tables=table_configs, + output_dir=output_dir, + func_to_benchmark=training_func_to_benchmark, + benchmark_func_kwargs={"optimizer": optimizer}, + **args_kwargs, + ) + + +def main() -> None: + args: argparse.Namespace = init_argparse_and_args() + + num_requests = args.bench_iters * args.batch_size * args.num_benchmarks + datetime_sfx: str = time.strftime("%Y%m%dT%H%M%S") + + output_dir = args.output_dir + if not os.path.exists(output_dir): + # Create output directory if not exist + os.mkdir(output_dir) + + benchmark_results_per_module = [] + write_report_funcs_per_module = [] + shrunk_table_sizes = [] + + embedding_configs, pooling_configs = set_embedding_config( + args.embedding_config_json + ) + for config in embedding_configs: + shrunk_table_sizes.append((min(args.max_num_embeddings, config[0]), config[1])) + + for module_name in ["EmbeddingBagCollection"]: + output_dir = args.output_dir + f"/run_{datetime_sfx}" + output_dir += "_ebc" + benchmark_func = benchmark_ebc + + if not os.path.exists(output_dir): + # Place all outputs under the datetime folder + os.mkdir(output_dir) + + tables_info = "\nTABLE SIZES:" + for i, (num, dim) in enumerate(shrunk_table_sizes): + # FP32 is 4 bytes + mb = int(float(num * dim) / 1024 / 1024) * 4 + tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] {mb:6}Mb" + + ### Benchmark no VBE + report: str = ( + f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n" + ) + report += f"Module: {module_name}\n" + report += tables_info + report += "\n" + + report += f"num_requests:{num_requests:8}\n" + report_file: str = f"{output_dir}/run.report" + + # Save results to output them once benchmarking is all done + benchmark_results_per_module.append( + benchmark_func(shrunk_table_sizes, args, output_dir, pooling_configs) + ) + write_report_funcs_per_module.append( + partial( + write_report, + report_file=report_file, + report_str=report, + num_requests=num_requests, + ) + ) + + ### Benchmark with VBE + report: str = ( + f"REPORT BENCHMARK (VBE) {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n" + ) + report += f"Module: {module_name} (VBE)\n" + report += tables_info + report += "\n" + report_file = f"{output_dir}/run_vbe.report" + + benchmark_results_per_module.append( + benchmark_func(shrunk_table_sizes, args, output_dir, pooling_configs, True) + ) + write_report_funcs_per_module.append( + partial( + write_report, + report_file=report_file, + report_str=report, + num_requests=num_requests, + ) + ) + + for i, write_report_func in enumerate(write_report_funcs_per_module): + write_report_func(benchmark_results_per_module[i]) + + +def 
invoke_main() -> None: + logging.basicConfig() + logging.getLogger().setLevel(logging.DEBUG) + + main() + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/torchrec/distributed/benchmark/benchmark_train_sparsenn.py b/torchrec/distributed/benchmark/benchmark_train_sparsenn.py new file mode 100644 index 000000000..c7d8e76e4 --- /dev/null +++ b/torchrec/distributed/benchmark/benchmark_train_sparsenn.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +import copy + +from dataclasses import dataclass +from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union + +import click + +import torch +import torch.distributed as dist +from fbgemm_gpu.split_embedding_configs import EmbOptimType +from torch import nn, optim +from torch.optim import Optimizer +from torchrec.distributed import DistributedModelParallel +from torchrec.distributed.benchmark.benchmark_utils import benchmark_func, cmd_conf +from torchrec.distributed.embedding_types import EmbeddingComputeKernel + +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + run_multi_process_func, +) +from torchrec.distributed.test_utils.test_input import ( + ModelInput, + TdModelInput, + TestSparseNNInputConfig, +) +from torchrec.distributed.test_utils.test_model import ( + TestEBCSharder, + TestOverArchLarge, + TestSparseNN, +) +from torchrec.distributed.train_pipeline import ( + TrainPipeline, + TrainPipelineBase, + TrainPipelineSparseDist, +) +from torchrec.distributed.train_pipeline.train_pipelines import ( + PrefetchTrainPipelineSparseDist, + TrainPipelineSemiSync, +) +from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig + + +@dataclass +class RunOptions: + world_size: int = 2 + num_batches: int = 10 + sharding_type: ShardingType = ShardingType.TABLE_WISE + input_type: str = "kjt" + profile: str = "" + + +@dataclass +class EmbeddingTablesConfig: + num_unweighted_features: int = 100 + num_weighted_features: int = 100 + embedding_feature_dim: int = 512 + + def generate_tables( + self, + ) -> Tuple[ + List[EmbeddingBagConfig], + List[EmbeddingBagConfig], + ]: + tables = [ + EmbeddingBagConfig( + num_embeddings=max(i + 1, 100) * 1000, + embedding_dim=self.embedding_feature_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(self.num_unweighted_features) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=max(i + 1, 100) * 1000, + embedding_dim=self.embedding_feature_dim, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(self.num_weighted_features) + ] + return tables, weighted_tables + + +@dataclass +class PipelineConfig: + pipeline: str = "base" + + def generate_pipeline( + self, model: nn.Module, opt: torch.optim.Optimizer, device: torch.device + ) -> Union[TrainPipelineBase, TrainPipelineSparseDist]: + _pipeline_cls: Dict[ + str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]] + ] = { + "base": TrainPipelineBase, + "sparse": TrainPipelineSparseDist, + "semi": TrainPipelineSemiSync, + "prefetch": PrefetchTrainPipelineSparseDist, + } + + if self.pipeline == "semi": + return TrainPipelineSemiSync( + model=model, optimizer=opt, 
device=device, start_batch=0 + ) + elif self.pipeline in _pipeline_cls: + Pipeline = _pipeline_cls[self.pipeline] + return Pipeline(model=model, optimizer=opt, device=device) + else: + raise RuntimeError(f"unknown pipeline option {self.pipeline}") + + +@click.command() +@cmd_conf(RunOptions, EmbeddingTablesConfig, TestSparseNNInputConfig, PipelineConfig) +def main( + run_option: RunOptions, + table_config: EmbeddingTablesConfig, + input_config: TestSparseNNInputConfig, + pipeline_config: PipelineConfig, +) -> None: + # sparse table config is available on each trainer + tables, weighted_tables = table_config.generate_tables() + + # launch trainers + run_multi_process_func( + func=runner, + world_size=run_option.world_size, + tables=tables, + weighted_tables=weighted_tables, + run_option=run_option, + input_config=input_config, + pipeline_config=pipeline_config, + ) + + +def _generate_data( + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + input_config: TestSparseNNInputConfig, + num_batches: int, +) -> List[ModelInput]: + return [ + input_config.generate_model_input( + tables=tables, + weighted_tables=weighted_tables, + ) + for _ in range(num_batches) + ] + + +def _generate_model( + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + dense_device: torch.device, +) -> nn.Module: + return TestSparseNN( + tables=tables, + weighted_tables=weighted_tables, + dense_device=dense_device, + sparse_device=torch.device("meta"), + over_arch_clazz=TestOverArchLarge, + ) + + +def _generate_sharded_model_and_optimizer( + model: nn.Module, + sharding_type: str, + kernel_type: str, + pg: dist.ProcessGroup, + device: torch.device, + fused_params: Optional[Dict[str, Any]] = None, +) -> Tuple[nn.Module, Optimizer]: + sharder = TestEBCSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + fused_params=fused_params, + ) + sharded_model = DistributedModelParallel( + module=copy.deepcopy(model), + env=ShardingEnv.from_process_group(pg), + init_data_parallel=True, + device=device, + sharders=[ + cast( + ModuleSharder[nn.Module], + sharder, + ) + ], + ).to(device) + optimizer = optim.SGD( + [ + param + for name, param in sharded_model.named_parameters() + if "sparse" not in name + ], + lr=0.1, + ) + return sharded_model, optimizer + + +def runner( + rank: int, + world_size: int, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + run_option: RunOptions, + input_config: TestSparseNNInputConfig, + pipeline_config: PipelineConfig, +) -> None: + + torch.autograd.set_detect_anomaly(True) + with MultiProcessContext( + rank=rank, + world_size=world_size, + backend="nccl", + use_deterministic_algorithms=False, + ) as ctx: + unsharded_model = _generate_model( + tables=tables, + weighted_tables=weighted_tables, + dense_device=ctx.device, + ) + + sharded_model, optimizer = _generate_sharded_model_and_optimizer( + model=unsharded_model, + sharding_type=run_option.sharding_type.value, + kernel_type=EmbeddingComputeKernel.FUSED.value, + # pyre-ignore + pg=ctx.pg, + device=ctx.device, + fused_params={ + "optimizer": EmbOptimType.EXACT_ADAGRAD, + "learning_rate": 0.1, + }, + ) + bench_inputs = _generate_data( + tables=tables, + weighted_tables=weighted_tables, + input_config=input_config, + num_batches=run_option.num_batches, + ) + pipeline = pipeline_config.generate_pipeline( + sharded_model, optimizer, ctx.device + ) + pipeline.progress(iter(bench_inputs)) + + def _func_to_benchmark( + bench_inputs: List[ModelInput], + 
model: nn.Module, + pipeline: TrainPipeline, + ) -> None: + dataloader = iter(bench_inputs) + while True: + try: + pipeline.progress(dataloader) + except StopIteration: + break + + result = benchmark_func( + name=type(pipeline).__name__, + bench_inputs=bench_inputs, # pyre-ignore + prof_inputs=bench_inputs, # pyre-ignore + num_benchmarks=5, + num_profiles=2, + profile_dir=run_option.profile, + world_size=run_option.world_size, + func_to_benchmark=_func_to_benchmark, + benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline}, + rank=rank, + ) + if rank == 0: + print(result) + + +if __name__ == "__main__": + main() diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py new file mode 100644 index 000000000..c407c640f --- /dev/null +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -0,0 +1,1227 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +# pyre-ignore-all-errors[16] + +#!/usr/bin/env python3 + +import argparse +import contextlib +import copy +import gc +import json +import logging +import os +import time +import timeit +from dataclasses import dataclass, fields, is_dataclass +from enum import Enum +from typing import ( + Any, + Callable, + ContextManager, + Dict, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) + +import click + +import torch +from torch import multiprocessing as mp +from torch.autograd.profiler import record_function +from torchrec.distributed import DistributedModelParallel +from torchrec.distributed.embedding_types import ShardingType +from torchrec.distributed.global_settings import set_propogate_device + +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology +from torchrec.distributed.planner.enumerators import EmbeddingEnumerator +from torchrec.distributed.planner.shard_estimators import ( + EmbeddingPerfEstimator, + EmbeddingStorageEstimator, +) +from torchrec.distributed.shard import _shard_modules +from torchrec.distributed.test_utils.multi_process import MultiProcessContext +from torchrec.distributed.test_utils.test_model import ModelInput + +from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv +from torchrec.fx import symbolic_trace +from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig +from torchrec.quant.embedding_modules import ( + EmbeddingBagCollection as QuantEmbeddingBagCollection, + EmbeddingCollection as QuantEmbeddingCollection, +) +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor +from torchrec.test_utils import get_free_port + +logger: logging.Logger = logging.getLogger() + +# Reference: https://github.com/facebookresearch/dlrm/blob/main/torchrec_dlrm/README.MD +DLRM_NUM_EMBEDDINGS_PER_FEATURE = [ + 4833188, + 36746, + 17245, + 7413, + 20243, + 3, + 7114, + 1441, + 62, + 29275261, + 1572176, + 345138, + 10, + 2209, + 11267, + 128, + 4, + 974, + 14, + 48937457, + 11316796, + 40094537, + 452104, + 12606, + 104, + 35, +] + +EMBEDDING_DIM: int = 128 + + +class CompileMode(Enum): + EAGER = "eager" + FX_SCRIPT = "fx_script" + + +@dataclass +class MemoryStats: + rank: int + malloc_retries: int + max_mem_allocated_mbs: int + max_mem_reserved_mbs: int + + @classmethod + def for_device(cls, rank: int) -> "MemoryStats": + stats = 
torch.cuda.memory_stats(rank) + alloc_retries = stats.get("num_alloc_retries", 0) + max_allocated = stats.get("allocated_bytes.all.peak", 0) + max_reserved = stats.get("reserved_bytes.all.peak", 0) + return cls( + rank, + alloc_retries, + max_allocated // 1024 // 1024, + max_reserved // 1024 // 1024, + ) + + def __str__(self) -> str: + return f"Rank {self.rank}: retries={self.malloc_retries}, allocated={self.max_mem_allocated_mbs:7}mb, reserved={self.max_mem_reserved_mbs:7}mb" + + +@dataclass +class BenchmarkResult: + "Class for holding results of benchmark runs" + short_name: str + elapsed_time: torch.Tensor # milliseconds + mem_stats: List[MemoryStats] # memory stats per rank + rank: int = -1 + + def __str__(self) -> str: + runtime = f"Runtime (P90): {self.runtime_percentile(90):.2f} ms" + if len(self.mem_stats) == 0: + return f"{self.short_name: <{35}} | {runtime}" + mem_alloc = ( + f"Peak Memory alloc (P90): {self.max_mem_alloc_percentile(90)/1000:.2f} GB" + ) + mem_reserved = f"Peak Memory reserved (P90): {self.max_mem_reserved_percentile(90)/1000:.2f} GB" + malloc_retries = f"Malloc retries (P50/P90/P100): {self.mem_retries(50) } / {self.mem_retries(90)} / {self.mem_retries(100)}" + return f"{self.short_name: <{35}} | {malloc_retries} | {runtime} | {mem_alloc} | {mem_reserved}" + + def runtime_percentile( + self, percentile: int = 50, interpolation: str = "nearest" + ) -> torch.Tensor: + return torch.quantile( + self.elapsed_time, + percentile / 100.0, + interpolation=interpolation, + ) + + def max_mem_alloc_percentile( + self, percentile: int = 50, interpolation: str = "nearest" + ) -> torch.Tensor: + return self._mem_percentile( + lambda m: m.max_mem_allocated_mbs, percentile, interpolation + ) + + def max_mem_reserved_percentile( + self, percentile: int = 50, interpolation: str = "nearest" + ) -> torch.Tensor: + return self._mem_percentile( + lambda m: m.max_mem_reserved_mbs, percentile, interpolation + ) + + def mem_retries( + self, percentile: int = 50, interpolation: str = "nearest" + ) -> torch.Tensor: + return self._mem_percentile( + lambda m: m.malloc_retries, percentile, interpolation + ) + + def _mem_percentile( + self, + mem_selector: Callable[[MemoryStats], int], + percentile: int = 50, + interpolation: str = "nearest", + ) -> torch.Tensor: + mem_data = torch.tensor( + [mem_selector(mem_stat) for mem_stat in self.mem_stats], dtype=torch.float + ) + return torch.quantile(mem_data, percentile / 100.0, interpolation=interpolation) + + +class ECWrapper(torch.nn.Module): + """ + Wrapper Module for benchmarking EC Modules + + Args: + module: module to benchmark + + Call Args: + input: KeyedJaggedTensor KJT input to module + + Returns: + output: KT output from module + + Example: + e1_config = EmbeddingConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + e2_config = EmbeddingConfig( + name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"] + ) + + ec = EmbeddingCollection(tables=[e1_config, e2_config]) + + features = KeyedJaggedTensor( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + ec.qconfig = torch.quantization.QConfig( + activation=torch.quantization.PlaceholderObserver.with_args( + dtype=torch.qint8 + ), + weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8), + ) + + qec = QuantEmbeddingCollection.from_float(ecc) + + wrapped_module = ECWrapper(qec) + quantized_embeddings = wrapped_module(features) + """ + + def __init__(self, module: 
torch.nn.Module) -> None: + super().__init__() + self._module = module + + def forward(self, input: KeyedJaggedTensor) -> Dict[str, JaggedTensor]: + """ + Args: + input (KeyedJaggedTensor): KJT of form [F X B X L]. + + Returns: + Dict[str, JaggedTensor] + """ + return self._module.forward(input) + + +class EBCWrapper(torch.nn.Module): + """ + Wrapper Module for benchmarking Modules + + Args: + module: module to benchmark + + Call Args: + input: KeyedJaggedTensor KJT input to module + + Returns: + output: KT output from module + + Example: + table_0 = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + table_1 = EmbeddingBagConfig( + name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] + ) + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + + features = KeyedJaggedTensor( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + ebc.qconfig = torch.quantization.QConfig( + activation=torch.quantization.PlaceholderObserver.with_args( + dtype=torch.qint8 + ), + weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8), + ) + + qebc = QuantEmbeddingBagCollection.from_float(ebc) + + wrapped_module = EBCWrapper(qebc) + quantized_embeddings = wrapped_module(features) + """ + + def __init__(self, module: torch.nn.Module) -> None: + super().__init__() + self._module = module + + def forward(self, input: KeyedJaggedTensor) -> KeyedTensor: + """ + Args: + input (KeyedJaggedTensor): KJT of form [F X B X L]. + + Returns: + KeyedTensor + """ + return self._module.forward(input) + + +T = TypeVar("T", bound=torch.nn.Module) + + +def default_func_to_benchmark( + model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor] +) -> None: + with torch.inference_mode(): + for bench_input in bench_inputs: + model(bench_input) + + +def get_tables( + table_sizes: List[Tuple[int, int]], + is_pooled: bool = True, + data_type: DataType = DataType.INT8, +) -> Union[List[EmbeddingBagConfig], List[EmbeddingConfig]]: + if is_pooled: + tables: List[EmbeddingBagConfig] = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + data_type=data_type, + ) + for i, (num_embeddings, embedding_dim) in enumerate(table_sizes) + ] + else: + tables: List[EmbeddingConfig] = [ + EmbeddingConfig( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + data_type=data_type, + ) + for i, (num_embeddings, embedding_dim) in enumerate(table_sizes) + ] + + return tables + + +def get_inputs( + tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]], + batch_size: int, + world_size: int, + num_inputs: int, + train: bool, + pooling_configs: Optional[List[int]] = None, + variable_batch_embeddings: bool = False, +) -> List[List[KeyedJaggedTensor]]: + inputs_batch: List[List[KeyedJaggedTensor]] = [] + + if variable_batch_embeddings and not train: + raise RuntimeError("Variable batch size is only supported in training mode") + + for _ in range(num_inputs): + if variable_batch_embeddings: + _, model_input_by_rank = ModelInput.generate_variable_batch_input( + average_batch_size=batch_size, + world_size=world_size, + num_float_features=0, + tables=tables, + ) + else: + _, model_input_by_rank = ModelInput.generate( + batch_size=batch_size, + world_size=world_size, + num_float_features=0, + tables=tables, + 
weighted_tables=[], + tables_pooling=pooling_configs, + indices_dtype=torch.int32, + lengths_dtype=torch.int32, + ) + + if train: + sparse_features_by_rank = [ + model_input.idlist_features + for model_input in model_input_by_rank + if isinstance(model_input.idlist_features, KeyedJaggedTensor) + ] + inputs_batch.append(sparse_features_by_rank) + else: + sparse_features = model_input_by_rank[0].idlist_features + assert isinstance(sparse_features, KeyedJaggedTensor) + inputs_batch.append([sparse_features]) + + # Transpose if train, as inputs_by_rank is currently in [B X R] format + inputs_by_rank = [ + [sparse_features for sparse_features in sparse_features_rank] + for sparse_features_rank in zip(*inputs_batch) + ] + + return inputs_by_rank + + +def write_report( + benchmark_results: List[BenchmarkResult], + report_file: str, + report_str: str, + num_requests: int, +) -> None: + for benchmark_res in benchmark_results: + avg_dur_s = benchmark_res.elapsed_time.mean().item() * 1e-3 # time in seconds + std_dur_s = benchmark_res.elapsed_time.std().item() * 1e-3 # time in seconds + + qps = int(num_requests / avg_dur_s) + + mem_str = "" + for memory_stats in benchmark_res.mem_stats: + mem_str += f"{memory_stats}\n" + + report_str += f"{benchmark_res.short_name:40} Avg QPS:{qps:10} Avg Duration: {int(1000*avg_dur_s):5}" + report_str += f"ms Standard Dev Duration: {(1000*std_dur_s):.2f}ms\n" + report_str += f"\tMemory Allocated Per Rank:\n\t{mem_str}\n" + + with open(report_file, "w") as f: + f.write(report_str) + + logger.info(f"Report written to {report_file}:\n{report_str}") + + +def set_embedding_config( + embedding_config_json: str, +) -> Tuple[List[Tuple[int, int]], List[int]]: + """ + the config file should follow this pattern: {feature: {num_embeddings: int, embedding_dim: int}} + """ + embedding_configs = [] + pooling_configs = [] + has_pooling_config = False + try: + if os.path.exists(embedding_config_json): + with open(embedding_config_json, "r") as f: + embedding_config_json = json.load(f) + + for _, config in embedding_config_json.items(): + embedding_configs.append( + (config["num_embeddings"], config["embedding_dim"]) + ) + if "pooling_factor" in config: + pooling_configs.append(config["pooling_factor"]) + has_pooling_config = True + else: + if has_pooling_config: + raise RuntimeError( + "We cannot handle some features have pooling factor and others don't." 
+ ) + else: + raise RuntimeError( + f"Could not find embedding config json at path {embedding_config_json}" + ) + except BaseException as e: + logger.warning( + f"Failed to load embedding config because {e}, fallback to DLRM config" + ) + embedding_configs = [ + (num_embeddings, EMBEDDING_DIM) + for num_embeddings in DLRM_NUM_EMBEDDINGS_PER_FEATURE + ] + + return embedding_configs, pooling_configs + + +# pyre-ignore [24] +def cmd_conf(*configs: Any) -> Callable: + support_classes: List[Any] = [int, str, bool, float, Enum] # pyre-ignore[33] + + # pyre-ignore [24] + def wrapper(func: Callable) -> Callable: + for config in configs: + assert is_dataclass(config), f"{config} should be a dataclass" + + # pyre-ignore + def rtf(**kwargs): + loglevel = logging._nameToLevel[kwargs["loglevel"].upper()] + logger.setLevel(logging.INFO) + input_configs = [] + for config in configs: + params = {} + for field in fields(config): + params[field.name] = kwargs.get(field.name, field.default) + conf = config(**params) + logger.info(conf) + input_configs.append(conf) + logger.setLevel(loglevel) + return func(*input_configs) + + names: Set[str] = set() + for config in configs: + for field in fields(config): + if not isinstance(field.default, tuple(support_classes)): + continue + if field.name not in names: + names.add(field.name) + else: + logger.warn(f"WARNING: duplicate argument {field.name}") + continue + rtf = click.option( + f"--{field.name}", type=field.type, default=field.default + )(rtf) + return click.option( + "--loglevel", + type=click.Choice(list(logging._nameToLevel.keys()), case_sensitive=False), + default=logging._levelToName[logger.level], + )(rtf) + + return wrapper + + +def init_argparse_and_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + + parser.add_argument("--warmup_iters", type=int, default=20) + parser.add_argument("--bench_iters", type=int, default=500) + parser.add_argument("--prof_iters", type=int, default=20) + parser.add_argument("--batch_size", type=int, default=2048) + parser.add_argument("--world_size", type=int, default=2) + parser.add_argument("--max_num_embeddings", type=int, default=1000000) + parser.add_argument("--output_dir", type=str, default="/var/tmp/torchrec-bench") + parser.add_argument("--num_benchmarks", type=int, default=5) + parser.add_argument("--embedding_config_json", type=str, default="") + parser.add_argument("--device_type", type=str, default="cuda") + + args = parser.parse_args() + + return args + + +def transform_module( + module: torch.nn.Module, + device: torch.device, + inputs: List[KeyedJaggedTensor], + sharder: ModuleSharder[T], + sharding_type: ShardingType, + compile_mode: CompileMode, + world_size: int, + batch_size: int, + # pyre-fixme[24]: Generic type `ContextManager` expects 1 type parameter. 
+ ctx: ContextManager, + benchmark_unsharded_module: bool = False, +) -> torch.nn.Module: + def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module: + eager_module(inputs[0]) + graph_module = symbolic_trace( + eager_module, leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"] + ) + scripted_module = torch.jit.script(graph_module) + return scripted_module + + set_propogate_device(True) + + sharded_module = None + + if not benchmark_unsharded_module: + topology: Topology = Topology(world_size=world_size, compute_device=device.type) + planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + # Don't want to modify the module outright + # Since module is on cpu, won't cause cuda oom. + copied_module = copy.deepcopy(module) + # pyre-ignore [6] + plan = planner.plan(copied_module, [sharder]) + + if isinstance(ctx, MultiProcessContext): + sharded_module = DistributedModelParallel( + copied_module, + # pyre-ignore[6] + env=ShardingEnv.from_process_group(ctx.pg), + plan=plan, + # pyre-ignore[6] + sharders=[sharder], + device=ctx.device, + ) + else: + env = ShardingEnv.from_local(world_size=topology.world_size, rank=0) + + sharded_module = _shard_modules( + module=copied_module, + # pyre-fixme[6]: For 2nd argument expected + # `Optional[List[ModuleSharder[Module]]]` but got + # `List[ModuleSharder[Variable[T (bound to Module)]]]`. + sharders=[sharder], + device=device, + plan=plan, + env=env, + ) + + if compile_mode == CompileMode.FX_SCRIPT: + return fx_script_module( + # pyre-fixme[6]: For 1st argument expected `Module` but got + # `Optional[Module]`. + sharded_module + if not benchmark_unsharded_module + else module + ) + else: + # pyre-fixme[7]: Expected `Module` but got `Optional[Module]`. 
+ return sharded_module if not benchmark_unsharded_module else module + + +def benchmark( + name: str, + model: torch.nn.Module, + warmup_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]], + bench_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]], + prof_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]], + world_size: int, + output_dir: str, + num_benchmarks: int, + # pyre-ignore[2] + func_to_benchmark: Any, + benchmark_func_kwargs: Optional[Dict[str, Any]], + rank: int, + enable_logging: bool = True, + device_type: str = "cuda", + benchmark_unsharded_module: bool = False, +) -> BenchmarkResult: + memory_stats: List[MemoryStats] = [] + if enable_logging: + logger.info(f" BENCHMARK_MODEL[{name}]:\n{model}") + + for _input in warmup_inputs: + model(_input) + + if device_type == "cuda": + if rank == -1: + # Reset memory for measurement, no process per rank so do all + for di in range(world_size): + torch.cuda.reset_peak_memory_stats(di) + else: + torch.cuda.reset_peak_memory_stats(rank) + + start = [] + end = [] + if device_type == "cuda": + # Measure time taken for batches in bench_inputs + start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)] + end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)] + + if benchmark_func_kwargs is None: + # Need this to unwrap + benchmark_func_kwargs = {} + + times = [] + if device_type == "cuda": + for i in range(num_benchmarks): + start[i].record() + func_to_benchmark(model, bench_inputs, **benchmark_func_kwargs) + end[i].record() + elif device_type == "cpu": + times = timeit.repeat( + lambda: func_to_benchmark(model, bench_inputs, **benchmark_func_kwargs), + number=1, + repeat=num_benchmarks, + ) + + if device_type == "cuda": + if rank == -1: + for di in range(world_size): + torch.cuda.synchronize(di) + else: + torch.cuda.synchronize(rank) + + # TODO: First Benchmark Run for Eager Mode produces outlier + # Start counting after first as workaround for standard deviation + if device_type == "cuda": + elapsed_time = torch.tensor( + [si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])] + ) + else: + elapsed_time = torch.tensor(times) * 1e3 + + if device_type == "cuda": + if rank == -1: + # Add up all memory allocated in inference mode + for di in range(world_size): + memory_stats.append(MemoryStats.for_device(di)) + else: + # Only add up memory allocated for current rank in training mode + memory_stats.append(MemoryStats.for_device(rank)) + + if output_dir != "": + # Only do profiling if output_dir is set + + # pyre-ignore[2] + def trace_handler(prof) -> None: + total_average = prof.profiler.total_average() + logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_average}") + dir_path: str = output_dir + + # only 1 rank should output in pg case, rank = 0 + if rank > 0: + return + + trace_file: str = f"{dir_path}/trace-{name}.json" + stacks_cpu_file = f"{dir_path}/stacks-cpu-{name}.stacks" + stacks_cuda_file = f"{dir_path}/stacks-cuda-{name}.stacks" + logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}") + + prof.export_chrome_trace(trace_file) + prof.export_stacks(stacks_cpu_file, "self_cpu_time_total") + prof.export_stacks(stacks_cuda_file, "self_cuda_time_total") + + # - git clone https://github.com/brendangregg/FlameGraph + # - cd FlameGraph + # - ./flamegraph.pl --title "CPU time" --countname "us." 
profiler.stacks > perf_viz.svg + + if device_type == "cuda": + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + profile_memory=True, + with_stack=True, + with_flops=True, + with_modules=True, + on_trace_ready=trace_handler, + ) as p: + for _input in prof_inputs: + with record_function("## forward ##"): + model(_input) + p.step() + + if rank == -1: + for di in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{di}")) + else: + torch.cuda.synchronize() + + return BenchmarkResult( + short_name=name, + elapsed_time=elapsed_time, + mem_stats=memory_stats, + rank=rank, + ) + + +def benchmark_func( + name: str, + bench_inputs: List[Dict[str, Any]], + prof_inputs: List[Dict[str, Any]], + world_size: int, + profile_dir: str, + num_benchmarks: int, + num_profiles: int, + # pyre-ignore[2] + func_to_benchmark: Any, + benchmark_func_kwargs: Optional[Dict[str, Any]], + rank: int, + device_type: str = "cuda", + pre_gpu_load: int = 0, +) -> BenchmarkResult: + memory_stats: List[MemoryStats] = [] + if device_type == "cuda": + if rank == -1: + # Reset memory for measurement, no process per rank so do all + for di in range(world_size): + torch.cuda.reset_peak_memory_stats(di) + torch.cuda.reset_accumulated_memory_stats(di) + else: + torch.cuda.reset_peak_memory_stats(rank) + torch.cuda.reset_accumulated_memory_stats(rank) + + start = [] + end = [] + if device_type == "cuda": + # Measure time taken for batches in bench_inputs + start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)] + end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)] + + if benchmark_func_kwargs is None: + # Need this to unwrap + benchmark_func_kwargs = {} + + times = [] + if device_type == "cuda": + a = torch.rand(16384, 16384, device="cuda") + for _ in range(pre_gpu_load): + a = a * torch.rand(16384, 16384, device="cuda") + for i in range(num_benchmarks): + start[i].record() + func_to_benchmark(bench_inputs, **benchmark_func_kwargs) + end[i].record() + elif device_type == "cpu": + if bench_inputs is None or len(bench_inputs) == 0: + times = timeit.repeat( + lambda: func_to_benchmark(**benchmark_func_kwargs), + number=1, + repeat=num_benchmarks, + ) + else: + times = timeit.repeat( + lambda: func_to_benchmark(bench_inputs, **benchmark_func_kwargs), + number=1, + repeat=num_benchmarks, + ) + + if device_type == "cuda": + if rank == -1: + for di in range(world_size): + torch.cuda.synchronize(di) + else: + torch.cuda.synchronize(rank) + + # TODO: First Benchmark Run for Eager Mode produces outlier + # Start counting after first as workaround for standard deviation + if device_type == "cuda": + elapsed_time = torch.tensor( + [si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])] + ) + else: + elapsed_time = torch.tensor(times) * 1e3 + + if device_type == "cuda": + if rank == -1: + # Add up all memory allocated in inference mode + for di in range(world_size): + memory_stats.append(MemoryStats.for_device(di)) + else: + # Only add up memory allocated for current rank in training mode + memory_stats.append(MemoryStats.for_device(rank)) + + if profile_dir != "": + # Only do profiling if output_dir is set + + # pyre-ignore[2] + def trace_handler(prof) -> None: + total_average = prof.profiler.total_average() + logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_average}") + dir_path: str = profile_dir + if rank == 0: + trace_file: str = f"{dir_path}/trace-{name}.json" + else: 
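+                # The per-rank filename is assigned but unused: non-zero ranks
+                # return before exporting, so only rank 0 writes a chrome trace.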
+ trace_file: str = f"{dir_path}/trace-{name}-{rank}.json" + return # only 1 rank should output in pg case, rank = 0 + logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}") + prof.export_chrome_trace(trace_file) + + if device_type == "cuda": + a = torch.rand(16384, 16384, device="cuda") + for _ in range(pre_gpu_load): + a = a * torch.rand(16384, 16384, device="cuda") + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + profile_memory=True, + with_flops=True, + with_modules=True, + with_stack=False, # usually we don't want to show the entire stack in the trace + on_trace_ready=trace_handler, + ) as p: + for i in range(num_profiles): + with record_function(f"## profile {i} ##"): + func_to_benchmark(prof_inputs, **benchmark_func_kwargs) + p.step() + + if rank == -1: + for di in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{di}")) + else: + torch.cuda.synchronize() + + return BenchmarkResult( + short_name=name, + elapsed_time=elapsed_time, + mem_stats=memory_stats, + rank=rank, + ) + + +def benchmark_type_name(compile_mode: CompileMode, sharding_type: ShardingType) -> str: + if sharding_type == ShardingType.TABLE_WISE: + name = "tw-sharded" + elif sharding_type == ShardingType.ROW_WISE: + name = "rw-sharded" + elif sharding_type == ShardingType.COLUMN_WISE: + name = "cw-sharded" + else: + raise Exception(f"Unknown sharding type {sharding_type}") + + if compile_mode == CompileMode.EAGER: + name += "-eager" + elif compile_mode == CompileMode.FX_SCRIPT: + name += "-fxjit" + + return name + + +def init_module_and_run_benchmark( + module: torch.nn.Module, + sharder: ModuleSharder[T], + device: torch.device, + sharding_type: ShardingType, + compile_mode: CompileMode, + world_size: int, + batch_size: int, + warmup_inputs: List[List[KeyedJaggedTensor]], + bench_inputs: List[List[KeyedJaggedTensor]], + prof_inputs: List[List[KeyedJaggedTensor]], + tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]], + output_dir: str, + num_benchmarks: int, + # pyre-ignore[2] + func_to_benchmark: Any, + benchmark_func_kwargs: Optional[Dict[str, Any]], + rank: int = -1, + queue: Optional[mp.Queue] = None, + pooling_configs: Optional[List[int]] = None, + benchmark_unsharded_module: bool = False, +) -> BenchmarkResult: + """ + There are a couple of caveats here as to why the module has to be initialized + here: + 1. Device. To accurately track memory usage, when sharding modules the initial + placement of the module should be on CPU. This is to avoid double counting + memory allocations and also to prevent CUDA OOMs. + 2. Garbage Collector. Since torch.fx.GraphModule has circular references, + garbage collection us funky and can lead to ooms. 
Since this frame is + called by the loop through compile modes and sharding types, returning the + benchmark result will mean that the reference to module is lost instead of + existing in the loop + """ + + if rank >= 0: + warmup_inputs_cuda = [ + warmup_input.to(torch.device(f"{device.type}:{rank}")) + for warmup_input in warmup_inputs[rank] + ] + bench_inputs_cuda = [ + bench_input.to(torch.device(f"{device.type}:{rank}")) + for bench_input in bench_inputs[rank] + ] + prof_inputs_cuda = [ + prof_input.to(torch.device(f"{device.type}:{rank}")) + for prof_input in prof_inputs[rank] + ] + else: + warmup_inputs_cuda = [ + warmup_input.to(torch.device(f"{device.type}:0")) + for warmup_input in warmup_inputs[0] + ] + bench_inputs_cuda = [ + bench_input.to(torch.device(f"{device.type}:0")) + for bench_input in bench_inputs[0] + ] + prof_inputs_cuda = [ + prof_input.to(torch.device(f"{device.type}:0")) + for prof_input in prof_inputs[0] + ] + + with ( + MultiProcessContext(rank, world_size, "nccl", None) + if rank != -1 + else contextlib.nullcontext() + ) as ctx: + module = transform_module( + module=module, + device=device, + inputs=warmup_inputs_cuda, + sharder=sharder, + sharding_type=sharding_type, + compile_mode=compile_mode, + world_size=world_size, + batch_size=batch_size, + # pyre-ignore[6] + ctx=ctx, + benchmark_unsharded_module=benchmark_unsharded_module, + ) + + if benchmark_unsharded_module: + name = "unsharded" + compile_mode.name + else: + name = benchmark_type_name(compile_mode, sharding_type) + + res = benchmark( + name, + module, + warmup_inputs_cuda, + bench_inputs_cuda, + prof_inputs_cuda, + world_size=world_size, + output_dir=output_dir, + num_benchmarks=num_benchmarks, + func_to_benchmark=func_to_benchmark, + benchmark_func_kwargs=benchmark_func_kwargs, + rank=rank, + device_type=device.type, + benchmark_unsharded_module=benchmark_unsharded_module, + ) + + if queue is not None: + queue.put(res) + + while not queue.empty(): + time.sleep(1) + + return res + + +def multi_process_benchmark( + callable: Callable[ + ..., + None, + ], + # pyre-ignore + **kwargs, +) -> BenchmarkResult: + + def setUp() -> None: + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + + assert "world_size" in kwargs + world_size = kwargs["world_size"] + + setUp() + benchmark_res_per_rank = [] + # kineto has a known problem with fork-server: it'll hang + # when dumping the trace. 
Workaround with spawn + ctx = mp.get_context("spawn") + qq = ctx.SimpleQueue() + processes = [] + + for rank in range(world_size): + kwargs["rank"] = rank + kwargs["world_size"] = world_size + kwargs["queue"] = qq + p = ctx.Process( + target=callable, + kwargs=kwargs, + ) + p.start() + processes.append(p) + + for _ in range(world_size): + res = qq.get() + + benchmark_res_per_rank.append(res) + assert len(res.mem_stats) == 1 + + for p in processes: + p.join() + assert 0 == p.exitcode + + total_benchmark_res = BenchmarkResult( + benchmark_res_per_rank[0].short_name, + benchmark_res_per_rank[0].elapsed_time, + [MemoryStats(rank, 0, 0, 0) for rank in range(world_size)], + 0, + ) + + for res in benchmark_res_per_rank: + # Each rank's BenchmarkResult contains 1 memory measurement + total_benchmark_res.mem_stats[res.rank] = res.mem_stats[0] + + return total_benchmark_res + + +def benchmark_module( + module: torch.nn.Module, + sharder: ModuleSharder[T], + sharding_types: List[ShardingType], + compile_modes: List[CompileMode], + tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]], + warmup_iters: int = 20, + bench_iters: int = 500, + prof_iters: int = 20, + batch_size: int = 2048, + world_size: int = 2, + num_benchmarks: int = 5, + output_dir: str = "", + benchmark_unsharded: bool = False, + func_to_benchmark: Callable[..., None] = default_func_to_benchmark, + benchmark_func_kwargs: Optional[Dict[str, Any]] = None, + pooling_configs: Optional[List[int]] = None, + variable_batch_embeddings: bool = False, + device_type: str = "cuda", +) -> List[BenchmarkResult]: + """ + Args: + eager_module: Eager mode module to be benchmarked + sharding_types: Sharding types to be benchmarked + compile_modes: Compilation modes to be benchmarked + warmup_iters: Number of iterations to run before profiling + bench_iters: Number of iterations to run during profiling + prof_iters: Number of iterations to run after profiling + batch_size: Batch size used in the model + world_size: World size used in the + num_benchmarks: How many times to run over benchmark inputs for statistics + output_dir: Directory to output profiler outputs (traces, stacks) + pooling_configs: The pooling factor for the tables. 
+ (Optional; if not set, we'll use 10 as default) + func_to_benchmark: Custom function to benchmark, check out default_func_to_benchmark for default + benchmark_func_kwargs: Custom keyword arguments to pass to func_to_benchmark + + Returns: + A list of BenchmarkResults + """ + + # logging.info(f"###### Benchmarking Module: {eager_module} ######\n") + logging.info(f"Warmup iterations: {warmup_iters}") + logging.info(f"Benchmark iterations: {bench_iters}") + logging.info(f"Profile iterations: {prof_iters}") + logging.info(f"Batch Size: {batch_size}") + logging.info(f"World Size: {world_size}") + logging.info(f"Number of Benchmarks: {num_benchmarks}") + logging.info(f"Output Directory: {output_dir}") + + assert ( + num_benchmarks > 2 + ), "num_benchmarks needs to be greater than 2 for statistical analysis" + if isinstance(module, QuantEmbeddingBagCollection) or isinstance( + module, QuantEmbeddingCollection + ): + train = False + else: + train = True + + benchmark_results: List[BenchmarkResult] = [] + + if isinstance(tables[0], EmbeddingBagConfig): + wrapped_module = EBCWrapper(module) + else: + wrapped_module = ECWrapper(module) + + num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters + inputs = get_inputs( + tables, + batch_size, + world_size, + num_inputs_to_gen, + train, + pooling_configs, + variable_batch_embeddings, + ) + + warmup_inputs = [rank_inputs[:warmup_iters] for rank_inputs in inputs] + bench_inputs = [ + rank_inputs[warmup_iters : (warmup_iters + bench_iters)] + for rank_inputs in inputs + ] + prof_inputs = [rank_inputs[-prof_iters:] for rank_inputs in inputs] + + for sharding_type in sharding_types if not benchmark_unsharded else ["Unsharded"]: + for compile_mode in compile_modes: + if not benchmark_unsharded: + # Test sharders should have a singular sharding_type + sharder._sharding_type = sharding_type.value + # pyre-ignore [6] + benchmark_type = benchmark_type_name(compile_mode, sharding_type) + else: + benchmark_type = "unsharded" + compile_mode.name + + logging.info( + f"\n\n###### Running Benchmark Type: {benchmark_type} ######\n" + ) + + if train: + res = multi_process_benchmark( + # pyre-ignore[6] + callable=init_module_and_run_benchmark, + module=wrapped_module, + sharder=sharder, + device=torch.device(device_type), + sharding_type=sharding_type, + compile_mode=compile_mode, + world_size=world_size, + batch_size=batch_size, + warmup_inputs=warmup_inputs, + bench_inputs=bench_inputs, + prof_inputs=prof_inputs, + tables=tables, + num_benchmarks=num_benchmarks, + output_dir=output_dir, + func_to_benchmark=func_to_benchmark, + benchmark_func_kwargs=benchmark_func_kwargs, + pooling_configs=pooling_configs, + ) + else: + res = init_module_and_run_benchmark( + module=wrapped_module, + sharder=sharder, + device=torch.device(device_type), + # pyre-ignore + sharding_type=sharding_type, + compile_mode=compile_mode, + world_size=world_size, + batch_size=batch_size, + warmup_inputs=warmup_inputs, + bench_inputs=bench_inputs, + prof_inputs=prof_inputs, + tables=tables, + num_benchmarks=num_benchmarks, + output_dir=output_dir, + func_to_benchmark=func_to_benchmark, + benchmark_func_kwargs=benchmark_func_kwargs, + pooling_configs=pooling_configs, + benchmark_unsharded_module=benchmark_unsharded, + ) + + gc.collect() + + benchmark_results.append(res) + + return benchmark_results diff --git a/torchrec/distributed/collective_utils.py b/torchrec/distributed/collective_utils.py index a16351d49..d3cad0fbf 100644 --- a/torchrec/distributed/collective_utils.py +++ 
b/torchrec/distributed/collective_utils.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + """ This file contains utilities for constructing collective based control flows. """ diff --git a/torchrec/distributed/comm.py b/torchrec/distributed/comm.py index 8082b2d38..61ffc4f48 100644 --- a/torchrec/distributed/comm.py +++ b/torchrec/distributed/comm.py @@ -5,12 +5,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import logging import os from typing import List, Optional, Tuple import torch import torch.distributed as dist +from torchrec.distributed.types import ShardingEnv2D logger: logging.Logger = logging.getLogger(__name__) @@ -18,6 +21,11 @@ _INTRA_PG: Optional[dist.ProcessGroup] = None _CROSS_PG: Optional[dist.ProcessGroup] = None +# For 2D parallel +_INTRA_PG_2D: Optional[dist.ProcessGroup] = None +_CROSS_PG_2D: Optional[dist.ProcessGroup] = None +_NODE_GROUP_SIZE_2D: Optional[int] = None + def _env2int(env_list: List[str], default: int = -1) -> int: for e in env_list: @@ -52,6 +60,15 @@ def get_local_size(world_size: Optional[int] = None) -> int: return local_size +def get_node_group_size(world_size: Optional[int] = None) -> int: + """ + Get the local world size accounting for 2D environment, if not set, we fallback to global environment + """ + if _NODE_GROUP_SIZE_2D is None: + return get_local_size(world_size) + return _NODE_GROUP_SIZE_2D + + def get_local_rank(world_size: Optional[int] = None, rank: Optional[int] = None) -> int: """ Gets the local rank of the local processes (see https://pytorch.org/docs/stable/elastic/run.html) @@ -100,7 +117,7 @@ def get_num_groups(world_size: Optional[int] = None) -> int: def intra_and_cross_node_pg( device: Optional[torch.device] = None, - backend: str = "nccl", + backend: Optional[str] = None, ) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]: """ Creates sub process groups (intra and cross node) @@ -117,14 +134,8 @@ def intra_and_cross_node_pg( local_size = get_local_size(my_size) my_group_rank = get_group_rank(my_size, my_rank) group_count = get_num_groups(my_size) - my_backend = dist.get_backend() - - if my_backend != backend: - logger.warn( - f"global PG is initialized with backend {my_backend}, while trying to perform intra_and_cross_node_pg with backend {backend}, " - f"use the global backend {my_backend} to proceed" - ) - backend = my_backend + if backend is None: + backend = dist.get_backend() logger.info( f"[{my_rank}] my_local_rank = {my_local_rank}, local_size = {local_size}," @@ -139,8 +150,8 @@ def intra_and_cross_node_pg( "[Connection] intra_group: [%d] -> [%s]" % (my_rank, peers) ) _INTRA_PG = curr_intra_group_pg - - dist.barrier() + assert _INTRA_PG is not None + dist.barrier() if _CROSS_PG is None: for l_rank in range(local_size): @@ -151,7 +162,115 @@ def intra_and_cross_node_pg( "[Connection] cross_group: [%d] -> [%s]" % (my_rank, peers) ) _CROSS_PG = curr_cross_group_pg - - dist.barrier() + assert _CROSS_PG is not None + dist.barrier() return _INTRA_PG, _CROSS_PG + + +def intra_and_cross_node_pg_2D( + env: ShardingEnv2D, + device: Optional[torch.device] = None, +) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]: + """ + Creates sub process groups (intra and cross node) under 2D parallelism scheme + The concept of "intra" and "cross" node is lost under a 2D parallelism 
scheme + due to the ranks that exist under a sharding group do not have gurantee of the typical + node topology. And as such there are no guarantees of "intra" group exploiting intra node bandwidth. + + NOTE: + These process groups are created for sharding schemes (ie: GRID) that were designed to exploit + intra node bandwidth for optimized comms. There will be future work to redesign the comms for GRID + sharding to be optimized under a 2D setup. + + Example:: + Here is what "intra" and "cross" groups look like in a 2D environment, + Sharding Groups: + Group 0: [0, 2, 4, 6] + Group 1: [1, 3, 5, 7] + devices_per_node = 2: + "intra" groups for each sharding group, + Group 0: [0, 2], [4, 6] + Group 1: [1, 3], [5, 7] + "cross" groups for each sharding group, + Group 0: [0, 4], [2, 6] + Group 1: [1, 5], [3, 7] + + We can see as this scales to real world topologies how the "intra" and "cross" node ideas in a traditional + sense are not applicable here. + """ + if device is not None and device.type == "meta": + return None, None + + global _INTRA_PG_2D + global _CROSS_PG_2D + global _NODE_GROUP_SIZE_2D + + backend = dist.get_backend(env.sharding_pg) + my_rank = dist.get_rank() + + sharding_group_size = dist.get_world_size( + env.sharding_pg + ) # Local replica group world size + world_size = dist.get_world_size() # Global world size + step = world_size // sharding_group_size + devices_per_node = ( + env.node_group_size if env.node_group_size else get_local_size(world_size) + ) + _NODE_GROUP_SIZE_2D = devices_per_node + + assert ( + sharding_group_size % devices_per_node == 0 + ), f"node group size is not divisible by sharding group size, {devices_per_node=}, {sharding_group_size=}" + intra_pg_groups: List[List[List[int]]] = [[] for _ in range(step)] + + if _INTRA_PG_2D is None: + for group_rank in range(step): + if env.use_inter_host_allreduce: + # for inter host all reduce, we change the sharding group calculation to be continuous + ranks = group_rank * sharding_group_size + sharding_pg_peers = list(range(ranks, ranks + sharding_group_size)) + else: + sharding_pg_peers = [ + step * r + group_rank for r in range(sharding_group_size) + ] + for group in range(len(sharding_pg_peers) // devices_per_node): + intra_pg_peers = sharding_pg_peers[ + group * devices_per_node : (group + 1) * devices_per_node + ] + intra_pg_groups[group_rank].append(intra_pg_peers) + curr_intra_pg = dist.new_group( + backend=backend, + ranks=intra_pg_peers, + group_desc="sharding_intra_pg", + ) + if my_rank in intra_pg_peers: + logger.warning( + f"[Connection] 2D rank {my_rank} -> intra_pg_peers {intra_pg_peers}" + ) + _INTRA_PG_2D = curr_intra_pg + assert _INTRA_PG_2D is not None, "INTRA_PG_2D is not initialized!" + dist.barrier() + + if _CROSS_PG_2D is None: + for group_rank in range(step): + intra_pg_group = intra_pg_groups[group_rank] + for cross_group_rank in range(devices_per_node): + cross_pg_peers = [ + intra_pg_group[j][cross_group_rank] + for j in range(len(intra_pg_group)) + ] + curr_cross_pg = dist.new_group( + backend=backend, + ranks=cross_pg_peers, + group_desc="sharding_cross_pg", + ) + if my_rank in cross_pg_peers: + logger.warning( + f"[Connection] 2D rank {my_rank} -> cross_pg_peers {cross_pg_peers}" + ) + _CROSS_PG_2D = curr_cross_pg + assert _CROSS_PG_2D is not None, "CROSS_PG_2D is not initialized!" 
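+        # Barrier so every rank finishes creating its cross-node group before
+        # any rank starts using the returned process groups.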
+ dist.barrier() + + return _INTRA_PG_2D, _CROSS_PG_2D diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py index 31859e36d..4d950c7e9 100644 --- a/torchrec/distributed/comm_ops.py +++ b/torchrec/distributed/comm_ops.py @@ -5,17 +5,22 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any, List, Optional, Tuple, TypeVar import torch import torch.distributed as dist +import torch.distributed._functional_collectives from torch import Tensor from torch.autograd import Function from torch.autograd.profiler import record_function from torchrec.distributed.types import Awaitable, NoWait, QuantizedCommCodecs from torchrec.distributed.utils import none_throws +from torchrec.pt2.checks import is_torchdynamo_compiling try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @@ -26,7 +31,7 @@ # OSS try: - import fbgemm_gpu # @manual # noqa + pass except ImportError: pass @@ -35,6 +40,7 @@ # TODO: T96382816, NE Parity Backward compatibility GRADIENT_DIVISION: bool = True +USE_SYNC_COLLECTIVES: bool = False def set_gradient_division(val: bool) -> None: @@ -42,6 +48,35 @@ def set_gradient_division(val: bool) -> None: GRADIENT_DIVISION = val +def get_gradient_division() -> bool: + global GRADIENT_DIVISION + return GRADIENT_DIVISION + + +def set_use_sync_collectives(val: bool) -> None: + if val and torch._running_with_deploy(): + raise RuntimeError( + "TorchRec sync_collectives are not supported in torch.deploy." + ) + + global USE_SYNC_COLLECTIVES + USE_SYNC_COLLECTIVES = val + + +def get_use_sync_collectives() -> bool: + global USE_SYNC_COLLECTIVES + return USE_SYNC_COLLECTIVES or is_torchdynamo_compiling() + + +@contextmanager +# pyre-ignore +def torchrec_use_sync_collectives(): + original_use_sync_collectives: bool = get_use_sync_collectives() + set_use_sync_collectives(True) + yield + set_use_sync_collectives(original_use_sync_collectives) + + """ Some commonly used notations for comm ops: B - batch size @@ -72,10 +107,13 @@ def __init__(self, pg: dist.ProcessGroup, device: torch.device) -> None: # This dummy tensor is used to build the autograd graph between # CommOp-Req and CommOp-Await. The actual forward tensors, and backwards gradient tensors # are stored in self.tensor - self.dummy_tensor: torch.Tensor = torch.empty( - 1, - requires_grad=True, - device=device, + # torch.zeros is a call_function, not placeholder, hence fx.trace incompatible. + self.dummy_tensor: torch.Tensor = torch.zeros_like( + torch.empty( + 1, + requires_grad=True, + device=device, + ) ) def _wait_impl(self) -> W: @@ -84,6 +122,8 @@ def _wait_impl(self) -> W: """ ret = self.wait_function.apply(self.pg, self, self.dummy_tensor) + if isinstance(ret, torch.Tensor) and ret.device.type == "cuda": + ret.record_stream(torch.get_device_module(ret.device).current_stream()) self.req = None self.tensor = None return ret @@ -105,7 +145,7 @@ class All2AllPooledInfo(object): cumsum_dim_sum_per_rank_tensor (Optional[Tensor]): cumulative sum of `dim_sum_per_rank`, this is only used by the fast kernel of `_recat_pooled_embedding_grad_out`. - B_local (int): local batch size before scattering. + codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. 
""" batch_size_per_rank: List[int] @@ -115,6 +155,31 @@ class All2AllPooledInfo(object): codecs: Optional[QuantizedCommCodecs] = None +@dataclass +class VariableBatchAll2AllPooledInfo(object): + """ + The data class that collects the attributes when calling the + `variable_batch_alltoall_pooled` operation. + + Attributes: + batch_size_per_rank_per_feature (List[List[int]]): batch size per rank per + feature. + batch_size_per_feature_pre_a2a (List[int]): local batch size before scattering. + emb_dim_per_rank_per_feature (List[List[int]]): embedding dimension per rank + per feature + codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. + input_splits (Optional[List[int]]): input splits of tensor all to all. + output_splits (Optional[List[int]]): output splits of tensor all to all. + """ + + batch_size_per_rank_per_feature: List[List[int]] + batch_size_per_feature_pre_a2a: List[int] + emb_dim_per_rank_per_feature: List[List[int]] + codecs: Optional[QuantizedCommCodecs] = None + input_splits: Optional[List[int]] = None + output_splits: Optional[List[int]] = None + + @dataclass class All2AllSequenceInfo(object): """ @@ -129,7 +194,8 @@ class All2AllSequenceInfo(object): backward_recat_tensor (Tensor): recat tensor for backward. input_splits (List[int]): input splits. output_splits (List[int]): output splits. - variable_batch_size (bool): whether variable batch size is enabled + variable_batch_size (bool): whether variable batch size is enabled. + codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. permuted_lengths_after_sparse_data_all2all (Optional[Tensor]): lengths of sparse features before AlltoAll. """ @@ -194,8 +260,8 @@ class ReduceScatterInfo(object): @dataclass class ReduceScatterBaseInfo(object): """ - The data class that collects the attributes when calling the `reduce_scatter_base_pooled` - operation. + The data class that collects the attributes when calling the + `reduce_scatter_base_pooled` operation. Attributes: input_sizes (torch.Size): the sizes of the input flatten tensor. @@ -208,8 +274,8 @@ class ReduceScatterBaseInfo(object): @dataclass class AllGatherBaseInfo(object): """ - The data class that collects the attributes when calling the `all_gatther_base_pooled` - operation. + The data class that collects the attributes when calling the + `all_gatther_base_pooled` operation. Attributes: input_size (int): the size of the input tensor. @@ -226,14 +292,16 @@ class ReduceScatterVInfo(object): operation. Attributes: - input_sizes (List[torch.Size]): the sizes of the input tensors. This remembers the + input_sizes (List[List[int]]): the sizes of the input tensors. This saves the sizes of the input tensors when running the backward pass and producing the gradient. - input_splits (List[int]): the splits of the input tensors along dim0. - total_input_size: (List[int]): total input size + input_splits (List[int]): the splits of the input tensors along dim 0. + equal_splits (bool): ... + total_input_size: (List[int]): total input size. + codecs (Optional[QuantizedCommCodecs]): ... """ - input_sizes: List[torch.Size] + input_sizes: List[List[int]] input_splits: List[int] equal_splits: bool total_input_size: List[int] @@ -243,7 +311,7 @@ class ReduceScatterVInfo(object): @dataclass class All2AllDenseInfo(object): """ - The data class that collects the attributes when calling the alltoall_dense + The data class that collects the attributes when calling the `alltoall_dense` operation. 
""" @@ -256,12 +324,17 @@ class All2AllDenseInfo(object): def _get_split_lengths_by_len( world_size: int, my_rank: int, n: int ) -> Tuple[int, List[int]]: - k, m = divmod(n, world_size) + k = n // world_size + m = n % world_size + splits = [] if m == 0: - splits = [k] * world_size + for _ in range(world_size): + splits.append(k) + my_len = k else: - splits = [(k + 1) if i < m else k for i in range(world_size)] + for i in range(world_size): + splits.append((k + 1) if i < m else k) my_len = splits[my_rank] return (my_len, splits) @@ -283,9 +356,8 @@ def alltoall_pooled( Args: a2a_pooled_embs_tensor (Tensor): input pooled embeddings. Must be pooled - together before passing into this function. Its shape is B x D_local_sum, - where D_local_sum is the dimension sum of all the local - embedding tables. + together before passing into this function. Its shape is `B x D_local_sum`, + where `D_local_sum` is the dimension sum of all the local embedding tables. batch_size_per_rank (List[int]): batch size in each rank. dim_sum_per_rank (List[int]): number of features (sum of dimensions) of the embedding in each rank. @@ -295,12 +367,12 @@ def alltoall_pooled( cumsum_dim_sum_per_rank_tensor (Optional[Tensor]): cumulative sum of `dim_sum_per_rank`, this is only used by the fast kernel of `_recat_pooled_embedding_grad_out`. - group (Optional[dist.ProcessGroup]): The process group to work on. If None, the + group (Optional[dist.ProcessGroup]): the process group to work on. If None, the default process group will be used. - codecs: Optional[QuantizedCommCodecs]: Quantized communication codecs + codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. Returns: - Awaitable[List[Tensor]]: async work handle (`Awaitable`), which can be `wait()` later to get the resulting tensor. + Awaitable[Tensor]: async work handle (`Awaitable`), which can be `wait()` later to get the resulting tensor. .. warning:: `alltoall_pooled` is experimental and subject to change. @@ -309,10 +381,9 @@ def alltoall_pooled( if group is None: group = dist.distributed_c10d._get_default_group() - if dist.get_world_size(group) <= 1: + if group.size() <= 1: return NoWait(a2a_pooled_embs_tensor) - myreq = Request(group, device=a2a_pooled_embs_tensor.device) a2ai = All2AllPooledInfo( batch_size_per_rank=batch_size_per_rank, dim_sum_per_rank=dim_sum_per_rank, @@ -320,11 +391,208 @@ def alltoall_pooled( cumsum_dim_sum_per_rank_tensor=cumsum_dim_sum_per_rank_tensor, codecs=codecs, ) - # pyre-fixme[16]: `All2All_Pooled_Req` has no attribute `apply`. 
+ + if get_use_sync_collectives(): + return NoWait(all2all_pooled_sync(group, a2ai, a2a_pooled_embs_tensor)) + + myreq = Request(group, device=a2a_pooled_embs_tensor.device) All2All_Pooled_Req.apply(group, myreq, a2ai, a2a_pooled_embs_tensor) return myreq +def pg_name(pg: dist.ProcessGroup) -> str: + return dist._functional_collectives._resolve_group_name(pg, "") + + +def all2all_pooled_sync( + pg: dist.ProcessGroup, a2ai: All2AllPooledInfo, input_embeddings: Tensor +) -> Tensor: + my_rank = pg.rank() + + (B_global, D_local_sum) = input_embeddings.shape + + dim_sum_per_rank = a2ai.dim_sum_per_rank + batch_size_per_rank = a2ai.batch_size_per_rank + B_local = batch_size_per_rank[my_rank] + + assert B_global == sum(batch_size_per_rank) + + sharded_input_embeddings = input_embeddings.view(-1) + + if a2ai.codecs is not None: + codecs = none_throws(a2ai.codecs) + qcomm_ctx = codecs.forward.create_context() + sharded_input_embeddings = codecs.forward.encode( + sharded_input_embeddings, + qcomm_ctx, + ) + output_split_sizes = [ + codecs.forward.calc_quantized_size( + B_local * D_rank_sum, + qcomm_ctx, + ) + for D_rank_sum in dim_sum_per_rank + ] + input_split_sizes = [ + codecs.forward.calc_quantized_size( + D_local_sum * B_rank, + qcomm_ctx, + ) + for B_rank in batch_size_per_rank + ] + else: + output_split_sizes = [B_local * D_rank_sum for D_rank_sum in dim_sum_per_rank] + input_split_sizes = [D_local_sum * B_rank for B_rank in batch_size_per_rank] + qcomm_ctx = None + + with record_function("## alltoall_fwd_single ##"): + sharded_output_embeddings = AllToAllSingle.apply( + sharded_input_embeddings, + output_split_sizes, + input_split_sizes, + pg_name(pg), + pg.size(), + get_gradient_division(), + ) + + if a2ai.codecs is not None: + codecs = none_throws(a2ai.codecs) + sharded_output_embeddings = codecs.forward.decode( + sharded_output_embeddings, + qcomm_ctx, + ) + + if is_torchdynamo_compiling(): + # Default implementation fails on backward with unbacked symints in full Ads model compilation. + # This is workaround to do desired split-cat under tracing in custom_op. 
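+        # _split_1d_cat_2d mirrors the eager path below: split per rank, view
+        # each chunk as (B_local, -1), and concatenate along dim 1, as a single
+        # custom op.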
+ # TODO: Remove when pt2 symbolic shapes default split - cat can be compiled forward and backward with unbacked symints + return torch.ops.torchrec._split_1d_cat_2d( + sharded_output_embeddings, B_local, dim_sum_per_rank + ) + + outputs_by_rank = sharded_output_embeddings.split(output_split_sizes) + return torch.cat([output.view(B_local, -1) for output in outputs_by_rank], dim=1) + + +def variable_batch_alltoall_pooled( + a2a_pooled_embs_tensor: Tensor, + batch_size_per_rank_per_feature: List[List[int]], + batch_size_per_feature_pre_a2a: List[int], + emb_dim_per_rank_per_feature: List[List[int]], + group: Optional[dist.ProcessGroup] = None, + codecs: Optional[QuantizedCommCodecs] = None, +) -> Awaitable[Tensor]: + + if group is None: + group = dist.distributed_c10d._get_default_group() + + if dist.get_world_size(group) <= 1: + return NoWait(a2a_pooled_embs_tensor) + + a2ai = VariableBatchAll2AllPooledInfo( + batch_size_per_rank_per_feature=batch_size_per_rank_per_feature, + batch_size_per_feature_pre_a2a=batch_size_per_feature_pre_a2a, + emb_dim_per_rank_per_feature=emb_dim_per_rank_per_feature, + codecs=codecs, + ) + + if get_use_sync_collectives(): + return NoWait( + variable_batch_all2all_pooled_sync(group, a2ai, a2a_pooled_embs_tensor) + ) + + myreq = Request(group, device=a2a_pooled_embs_tensor.device) + Variable_Batch_All2All_Pooled_Req.apply(group, myreq, a2ai, a2a_pooled_embs_tensor) + return myreq + + +def variable_batch_all2all_pooled_sync( + pg: dist.ProcessGroup, + a2ai: VariableBatchAll2AllPooledInfo, + input_embeddings: Tensor, +) -> Tensor: + my_rank = pg.rank() + + # get input splits + world_size = dist.get_world_size(pg) + input_split_sizes = [0 for _ in range(world_size)] + if a2ai.batch_size_per_rank_per_feature: + for i in range(world_size): + curr_size = 0 + for batch_size, emb_dim in zip( + a2ai.batch_size_per_rank_per_feature[i], + a2ai.emb_dim_per_rank_per_feature[my_rank], + ): + curr_size += batch_size * emb_dim + input_split_sizes[i] = curr_size + a2ai.input_splits = input_split_sizes + + # get output splits + output_split_sizes = [0 for _ in range(world_size)] + ind = 0 + for i in range(world_size): + curr_size = 0 + for emb_dim in a2ai.emb_dim_per_rank_per_feature[i]: + curr_size += a2ai.batch_size_per_feature_pre_a2a[ind] * emb_dim + ind += 1 + output_split_sizes[i] = curr_size + a2ai.output_splits = output_split_sizes + + sharded_input_embeddings = input_embeddings.view(-1) + qcomm_ctx = None + + if a2ai.codecs is not None: + codecs = none_throws(a2ai.codecs) + qcomm_ctx = codecs.forward.create_context() + sharded_input_embeddings = codecs.forward.encode( + sharded_input_embeddings, + qcomm_ctx, + ) + output_split_sizes = [ + codecs.forward.calc_quantized_size( + split, + qcomm_ctx, + ) + for split in output_split_sizes + ] + input_split_sizes = [ + codecs.forward.calc_quantized_size( + split, + qcomm_ctx, + ) + for split in input_split_sizes + ] + + with record_function("## alltoall_fwd_single ##"): + if pg._get_backend_name() == "custom": + sharded_output_embeddings = torch.empty( + sum(output_split_sizes), + device=sharded_input_embeddings.device, + dtype=sharded_input_embeddings.dtype, + ) + s0 = sharded_output_embeddings.size(0) + # Bad assumption that our rank GE than other + torch._check(s0 <= sharded_input_embeddings.size(0)) + sharded_output_embeddings.copy_(sharded_input_embeddings[:s0]) + else: + sharded_output_embeddings = AllToAllSingle.apply( + sharded_input_embeddings, + output_split_sizes, + input_split_sizes, + pg_name(pg), + pg.size(), + 
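+                # global gradient-division flag, see set_gradient_division above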
get_gradient_division(), + ) + + if a2ai.codecs is not None: + codecs = none_throws(a2ai.codecs) + sharded_output_embeddings = codecs.forward.decode( + sharded_output_embeddings, + qcomm_ctx, + ) + return sharded_output_embeddings + + def alltoall_sequence( # (T, B, L_i * D) flattened a2a_sequence_embs_tensor: Tensor, @@ -353,12 +621,12 @@ def alltoall_sequence( backward_recat_tensor (Tensor): recat tensor for backward. lengths_after_sparse_data_all2all (Tensor): lengths of sparse features after AlltoAll. - input_splits (Tensor): input splits. - output_splits (Tensor): output splits. - variable_batch_size (bool): whether varibale batch size is enabled - group (Optional[dist.ProcessGroup]): The process group to work on. If None, the + input_splits (List[int]): input splits. + output_splits (List[int]): output splits. + variable_batch_size (bool): whether variable batch size is enabled. + group (Optional[dist.ProcessGroup]): the process group to work on. If None, the default process group will be used. - codecs: Optional[QuantizedCommCodecs]: Quantized communication codecs + codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. Returns: Awaitable[List[Tensor]]: async work handle (`Awaitable`), which can be `wait()` later to get the resulting tensor. @@ -373,7 +641,6 @@ def alltoall_sequence( if dist.get_world_size(group) <= 1: return NoWait(a2a_sequence_embs_tensor) - myreq = Request(group, device=a2a_sequence_embs_tensor.device) a2ai = All2AllSequenceInfo( embedding_dim=a2a_sequence_embs_tensor.shape[1], lengths_after_sparse_data_all2all=lengths_after_sparse_data_all2all, @@ -386,72 +653,93 @@ def alltoall_sequence( ) # sequence of embeddings, bags are definitely non-uniform - # pyre-fixme[16]: `All2All_Seq_Req` has no attribute `apply`. + if get_use_sync_collectives(): + return NoWait(all2all_sequence_sync(group, a2ai, a2a_sequence_embs_tensor)) + + myreq = Request(group, device=a2a_sequence_embs_tensor.device) All2All_Seq_Req.apply(group, myreq, a2ai, a2a_sequence_embs_tensor) return myreq -def alltoallv( - inputs: List[Tensor], - out_split: Optional[List[int]] = None, - per_rank_split_lengths: Optional[List[int]] = None, - group: Optional[dist.ProcessGroup] = None, - codecs: Optional[QuantizedCommCodecs] = None, -) -> Awaitable[List[Tensor]]: - """ - Performs `alltoallv` operation for a list of input embeddings. Each process scatters - the list to all processes in the group. - - Args: - input (List[Tensor]): list of tensors to scatter, one per rank. The tensors in - the list usually have different lengths. - out_split (Optional[List[int]]): output split sizes (or dim_sum_per_rank), if - not specified, we will use `per_rank_split_lengths` to construct a output - split with the assumption that all the embs have the same dimension. - per_rank_split_lengths (Optional[List[int]]): split lengths per rank. If not - specified, the `out_split` must be specified. - group (Optional[dist.ProcessGroup]): The process group to work on. If None, the - default process group will be used. - - Returns: - Awaitable[List[Tensor]]: async work handle (`Awaitable`), which can be `wait()` later to get the resulting list of tensors. - - .. warning:: - `alltoallv` is experimental and subject to change. 
- """ - - if group is None: - group = dist.distributed_c10d._get_default_group() - - world_size = dist.get_world_size(group) - my_rank = dist.get_rank(group) - - myreq = Request(group, device=inputs[0].device) - B_global, _ = inputs[0].size() - D_local_list = [e.size()[1] for e in inputs] - B_local, B_local_list = _get_split_lengths_by_len(world_size, my_rank, B_global) - - if out_split is not None: - dims_sum_per_rank = out_split - elif per_rank_split_lengths is not None: - # all the embs have the same dimension - dims_sum_per_rank = [s * D_local_list[0] for s in per_rank_split_lengths] +def all2all_sequence_sync( + pg: dist.ProcessGroup, + a2ai: All2AllSequenceInfo, + sharded_input_embeddings: Tensor, +) -> Tensor: + world_size = pg.size() + D = a2ai.embedding_dim + forward_recat_tensor = a2ai.forward_recat_tensor + variable_batch_size = a2ai.variable_batch_size + lengths_after_sparse_data_all2all = a2ai.lengths_after_sparse_data_all2all * D + input_splits = [i * D for i in a2ai.output_splits] + output_splits = [i * D for i in a2ai.input_splits] + + a2ai.input_splits = input_splits + a2ai.output_splits = output_splits + + local_T = lengths_after_sparse_data_all2all.shape[0] + if local_T > 0: + with record_function("## alltoall_seq_embedding_fwd_permute ##"): + if not variable_batch_size: + ( + permuted_lengths_after_sparse_data_all2all, + sharded_input_embeddings, + _, + ) = torch.ops.fbgemm.permute_2D_sparse_data( + forward_recat_tensor, + lengths_after_sparse_data_all2all.view(local_T * world_size, -1), + sharded_input_embeddings.view(-1), + None, + sharded_input_embeddings.numel(), + ) + else: + ( + permuted_lengths_after_sparse_data_all2all, + sharded_input_embeddings, + _, + ) = torch.ops.fbgemm.permute_1D_sparse_data( + forward_recat_tensor, + lengths_after_sparse_data_all2all.view(-1), + sharded_input_embeddings.view(-1), + None, + sharded_input_embeddings.numel(), + ) else: - raise RuntimeError("Need to specify either out_split or per_rank_split_lengths") - - a2ai = All2AllVInfo( - dims_sum_per_rank=dims_sum_per_rank, - B_local=B_local, - B_local_list=B_local_list, - D_local_list=D_local_list, - B_global=B_global, - codecs=codecs, - ) - - # pyre-fixme[16]: `All2Allv_Req` has no attribute `apply`. - All2Allv_Req.apply(group, myreq, a2ai, inputs) + # Variable is not used in sync mode, left for conformity with async path + permuted_lengths_after_sparse_data_all2all = None # noqa: F841 + + if a2ai.codecs is not None: + codecs = none_throws(a2ai.codecs) + qcomm_ctx = codecs.forward.create_context() + # pyre-ignore [16] + sharded_input_embeddings = a2ai.codecs.forward.encode( + sharded_input_embeddings, qcomm_ctx + ) + output_splits = [ + a2ai.codecs.forward.calc_quantized_size(x, qcomm_ctx) for x in output_splits + ] + input_splits = [ + a2ai.codecs.forward.calc_quantized_size(x, qcomm_ctx) for x in input_splits + ] + else: + qcomm_ctx = None + + with record_function("## alltoall_seq_embedding_fwd_single ##"): + sharded_output_embeddings = AllToAllSingle.apply( + sharded_input_embeddings, + output_splits, + input_splits, + pg_name(pg), + pg.size(), + get_gradient_division(), + ) - return myreq + if a2ai.codecs is not None: + codecs = none_throws(a2ai.codecs) + sharded_output_embeddings = codecs.forward.decode( + sharded_output_embeddings, qcomm_ctx + ) + return sharded_output_embeddings.view(-1, D) def reduce_scatter_pooled( @@ -466,8 +754,9 @@ def reduce_scatter_pooled( Args: inputs (List[Tensor]): list of tensors to scatter, one per rank. 
- group (Optional[dist.ProcessGroup]): The process group to work on. If None, the + group (Optional[dist.ProcessGroup]): the process group to work on. If None, the default process group will be used. + codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. Returns: Awaitable[Tensor]: async work handle (Awaitable), which can be `wait()` later to get the resulting tensor. @@ -479,31 +768,58 @@ def reduce_scatter_pooled( if group is None: group = dist.distributed_c10d._get_default_group() - if dist.get_world_size(group) <= 1: - return NoWait(inputs[dist.get_rank(group)]) + if group.size() <= 1: + return NoWait(inputs[group.rank()]) - myreq = Request(group, device=inputs[0].device) rsi = ReduceScatterInfo( input_sizes=[tensor.size() for tensor in inputs], codecs=codecs ) - # pyre-fixme[16] + + if get_use_sync_collectives(): + return NoWait(reduce_scatter_sync(group, rsi, *inputs)) + + myreq = Request(group, device=inputs[0].device) ReduceScatter_Req.apply(group, myreq, rsi, *inputs) return myreq +def reduce_scatter_sync( + pg: dist.ProcessGroup, + rsi: ReduceScatterInfo, + *inputs: Any, +) -> Tensor: + if rsi.codecs is not None: + # pyre-ignore + inputs = [rsi.codecs.forward.encode(input) for input in inputs] + + with record_function("## reduce_scatter ##"): + output = torch.ops.torchrec.reduce_scatter_tensor( + torch.cat(inputs), + reduceOp="sum", + group_size=pg.size(), + group_name=pg_name(pg), + gradient_division=get_gradient_division(), + ) + if rsi.codecs is not None: + output = rsi.codecs.forward.decode(output) + return output + + def reduce_scatter_base_pooled( - inputs: Tensor, + input: Tensor, group: Optional[dist.ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None, ) -> Awaitable[Tensor]: """ - Reduces then scatters a flattened pooled embeddings tensor to all processes in a group. - Input tensor is of size output tensor size times world size. + Reduces then scatters a flattened pooled embeddings tensor to all processes in a + group. + Input tensor is of size `output_tensor_size * world_size`. Args: - inputs (Tensor): flattened tensor to scatter, . - group (Optional[dist.ProcessGroup]): The process group to work on. If None, the + input (Tensor): flattened tensor to scatter. + group (Optional[dist.ProcessGroup]): the process group to work on. If None, the default process group will be used. + codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. Returns: Awaitable[Tensor]: async work handle (Awaitable), which can be `wait()` later to get the resulting tensor. 
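For orientation, the pooled reduce-scatter wrappers above all rely on the same contract: every rank contributes a flattened tensor whose length is world_size times the local output length, the collective sums the per-rank contributions element-wise, and each rank keeps only its own slot. The single-process sketch below simulates just that shape/summation contract with plain torch ops; simulate_reduce_scatter is an illustrative helper name, not a TorchRec or torch.distributed API, and no process group is involved.

# Single-process simulation of the contract behind reduce_scatter_sync /
# reduce_scatter_base_pooled (illustrative only; no collectives are issued).
import torch

def simulate_reduce_scatter(per_rank_inputs):
    # per_rank_inputs[r] is rank r's flattened input of length world_size * N.
    world_size = len(per_rank_inputs)
    chunks = [t.chunk(world_size) for t in per_rank_inputs]
    # Rank r receives the element-wise sum of every rank's r-th chunk.
    return [
        torch.stack([chunks[src][r] for src in range(world_size)]).sum(dim=0)
        for r in range(world_size)
    ]

world_size, n = 2, 4
inputs = [
    torch.arange(world_size * n, dtype=torch.float32) + 10 * r
    for r in range(world_size)
]
outputs = simulate_reduce_scatter(inputs)
assert all(out.shape == (n,) for out in outputs)
assert torch.equal(outputs[0], inputs[0][:n] + inputs[1][:n])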
@@ -516,27 +832,54 @@ def reduce_scatter_base_pooled( group = dist.distributed_c10d._get_default_group() if dist.get_world_size(group) <= 1: - return NoWait(inputs) + return NoWait(input) - myreq = Request(group, device=inputs.device) - rsi = ReduceScatterBaseInfo(input_sizes=inputs.size(), codecs=codecs) - # pyre-fixme[16] - ReduceScatterBase_Req.apply(group, myreq, rsi, inputs) + rsi = ReduceScatterBaseInfo(input_sizes=input.size(), codecs=codecs) + + if get_use_sync_collectives(): + return NoWait(reduce_scatter_base_sync(group, rsi, input)) + + myreq = Request(group, device=input.device) + ReduceScatterBase_Req.apply(group, myreq, rsi, input) return myreq +def reduce_scatter_base_sync( + pg: dist.ProcessGroup, + rsi: ReduceScatterBaseInfo, + inputs: Tensor, +) -> Tensor: + my_size = pg.size() + assert inputs.size(0) % my_size == 0 + if rsi.codecs is not None: + inputs = rsi.codecs.forward.encode(inputs) + + with record_function("## reduce_scatter_base ##"): + output = torch.ops.torchrec.reduce_scatter_tensor( + inputs, + reduceOp="sum", + group_size=pg.size(), + group_name=pg_name(pg), + gradient_division=get_gradient_division(), + ) + if rsi.codecs is not None: + output = rsi.codecs.forward.decode(output) + return output + + def all_gather_base_pooled( input: Tensor, group: Optional[dist.ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None, ) -> Awaitable[Tensor]: """ - All-gathers tensors from all processes in a group to form a flattened pooled embeddings tensor. - Input tensor is of size output tensor size divided by world size. + All-gathers tensors from all processes in a group to form a flattened pooled + embeddings tensor. + Input tensor is of size `output_tensor_size / world_size`. Args: - input (Tensor): tensor to gather, . - group (Optional[dist.ProcessGroup]): The process group to work on. If None, the + input (Tensor): tensor to gather. + group (Optional[dist.ProcessGroup]): the process group to work on. If None, the default process group will be used. Returns: @@ -549,16 +892,39 @@ def all_gather_base_pooled( if group is None: group = dist.distributed_c10d._get_default_group() + agi = AllGatherBaseInfo(input_size=input.size(), codecs=codecs) if dist.get_world_size(group) <= 1: return NoWait(input) + if get_use_sync_collectives(): + return NoWait(all_gather_base_sync(group, agi, input)) + myreq = Request(group, device=input.device) - agi = AllGatherBaseInfo(input_size=input.size(), codecs=codecs) - # pyre-fixme[16] AllGatherBase_Req.apply(group, myreq, agi, input) return myreq +def all_gather_base_sync( + pg: dist.ProcessGroup, + agi: AllGatherBaseInfo, + input: Tensor, +) -> Tensor: + if agi.codecs is not None: + input = agi.codecs.forward.encode(input) + + with record_function("## all_gather_base ##"): + outputs = torch.ops.torchrec.all_gather_into_tensor( + input, + gather_dim=0, + group_name=pg_name(pg), + group_size=pg.size(), + gradient_division=get_gradient_division(), + ) + if agi.codecs is not None: + outputs = agi.codecs.forward.decode(outputs) + return outputs + + def reduce_scatter_v_pooled( input: Tensor, input_splits: List[int], @@ -566,14 +932,14 @@ def reduce_scatter_v_pooled( codecs: Optional[QuantizedCommCodecs] = None, ) -> Awaitable[Tensor]: """ - Performs reduce-scatter-v operation for a pooled embeddings tensor split unevenly into world - size number of chunks. The result of the reduce operation gets scattered to all - processes in the group according to input_splits. 
+ Performs reduce-scatter-v operation for a pooled embeddings tensor split unevenly + into world size number of chunks. The result of the reduce operation gets scattered + to all processes in the group according to `input_splits`. Args: - input (Tensor): tensors to scatter, one per rank. + input (Tensor): tensor to scatter. input_splits (List[int]): input splits. - group (Optional[dist.ProcessGroup]): The process group to work on. If None, the + group (Optional[dist.ProcessGroup]): the process group to work on. If None, the default process group will be used. Returns: @@ -589,15 +955,16 @@ def reduce_scatter_v_pooled( if dist.get_world_size(group) <= 1: return NoWait(input) - myreq = Request(group, device=input.device) input_size = list(input.size()) input_sizes = [ - torch.Size( - [ip_split if d == 0 else input_size[d] for d in range(len(input_size))] - ) + [ip_split if d == 0 else input_size[d] for d in range(len(input_size))] for ip_split in input_splits ] - equal_splits = all(ip_split == input_splits[0] for ip_split in input_splits) + + equal_splits = False + if not is_torchdynamo_compiling(): + # We can not check during tracing equality of splits -> fallback on general + equal_splits = all(ip_split == input_splits[0] for ip_split in input_splits) rsvi = ReduceScatterVInfo( input_sizes=input_sizes, @@ -606,7 +973,117 @@ def reduce_scatter_v_pooled( total_input_size=input_size, codecs=codecs, ) - # pyre-fixme[16]: `ReduceScatterV_Req` has no attribute `apply`. + + if get_use_sync_collectives(): + return NoWait(reduce_scatter_v_sync(group, rsvi, input)) + + myreq = Request(group, device=input.device) + ReduceScatterV_Req.apply(group, myreq, rsvi, input) + return myreq + + +def reduce_scatter_v_sync( + pg: dist.ProcessGroup, + rsi: ReduceScatterVInfo, + input: Tensor, +) -> Tensor: + world_size = pg.size() + rank = pg.rank() + + if rsi.codecs is not None: + input = rsi.codecs.forward.encode(input) + + if rsi.equal_splits: + with record_function("## reduce_scatter_base ##"): + output = torch.ops.torchrec.reduce_scatter_tensor( + input, + reduceOp="sum", + group_size=pg.size(), + group_name=pg_name(pg), + gradient_division=get_gradient_division(), + ) + else: + with record_function("## reduce_scatter_v_via_all_to_all_single ##"): + input_splits = rsi.input_splits + output_splits = [rsi.input_splits[rank]] * world_size + # TODO(ivankobzarev): Replace with _functional_collectives.reduce_scatter_v when it is added + a2a_output = AllToAllSingle.apply( + input, + output_splits, + input_splits, + pg_name(pg), + pg.size(), + get_gradient_division(), + ) + output = torch.sum( + torch.stack(torch.split(a2a_output, output_splits)), dim=0 + ) + + if rsi.codecs is not None: + output = rsi.codecs.forward.decode(output) + + return output + + +def reduce_scatter_v_per_feature_pooled( + input: Tensor, + batch_size_per_rank_per_feature: List[List[int]], + embedding_dims: List[int], + group: Optional[dist.ProcessGroup] = None, + codecs: Optional[QuantizedCommCodecs] = None, +) -> Awaitable[Tensor]: + """ + Performs reduce-scatter-v operation for a 1-d pooled embeddings tensor of variable + batch size per feature split unevenly into world size number of chunks. The result + of the reduce operation gets scattered to all processes in the group. + + Args: + input (Tensor): tensors to scatter, one per rank. + batch_size_per_rank_per_feature (List[List[int]]): batch size per rank per + feature used to determine input splits. 
+ embedding_dims (List[int]): embedding dimensions per feature used to determine + input splits. + group (Optional[dist.ProcessGroup]): The process group to work on. If None, the + default process group will be used. + codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. + + Returns: + Awaitable[Tensor]: async work handle (Awaitable), which can be `wait()` later to get the resulting tensor. + + .. warning:: + `reduce_scatter_v_per_feature_pooled` is experimental and subject to change. + """ + + if group is None: + group = dist.distributed_c10d._get_default_group() + + world_size = group.size() + if world_size <= 1: + return NoWait(input) + + input_splits = [0 for _ in range(world_size)] + if batch_size_per_rank_per_feature: + for rank in range(world_size): + rank_splits = 0 + for batch_size, emb_dim in zip( + batch_size_per_rank_per_feature[rank], embedding_dims + ): + rank_splits += batch_size * emb_dim + input_splits[rank] = rank_splits + input_sizes = [[s] for s in input_splits] + + rsvi = ReduceScatterVInfo( + input_sizes=input_sizes, + input_splits=input_splits, + equal_splits=False, + total_input_size=list(input.size()), + codecs=codecs, + ) + + if get_use_sync_collectives(): + return NoWait(reduce_scatter_v_sync(group, rsvi, input)) + + myreq = Request(group, device=input.device) ReduceScatterV_Req.apply(group, myreq, rsvi, input) return myreq @@ -673,32 +1150,46 @@ def forward( dim_sum_per_rank = a2ai.dim_sum_per_rank batch_size_per_rank = a2ai.batch_size_per_rank B_local = batch_size_per_rank[my_rank] - assert B_global == sum(batch_size_per_rank) - sharded_input_embeddings = input_embeddings.view(-1) + assert B_global == sum(batch_size_per_rank) if a2ai.codecs is not None: codecs = none_throws(a2ai.codecs) qcomm_ctx = codecs.forward.create_context() + padded_D_local_sum, padding_size = codecs.forward.padded_size( + input_embeddings, dim_sum_per_rank, my_rank, qcomm_ctx + ) + + if padding_size == 0: + sharded_input_embeddings = input_embeddings.view(-1) + else: + sharded_input_embeddings = input_embeddings sharded_input_embeddings = codecs.forward.encode( sharded_input_embeddings, qcomm_ctx, ) + padded_dim_sum_per_rank = ( + qcomm_ctx.padded_dim_sum_per_rank + if qcomm_ctx is not None + and qcomm_ctx.padded_dim_sum_per_rank is not None + else dim_sum_per_rank + ) output_split_sizes = [ codecs.forward.calc_quantized_size( B_local * D_rank_sum, qcomm_ctx, ) - for D_rank_sum in dim_sum_per_rank + for D_rank_sum in padded_dim_sum_per_rank ] input_split_sizes = [ codecs.forward.calc_quantized_size( - D_local_sum * B_rank, + padded_D_local_sum * B_rank, qcomm_ctx, ) for B_rank in batch_size_per_rank ] else: + sharded_input_embeddings = input_embeddings.view(-1) output_split_sizes = [ B_local * D_rank_sum for D_rank_sum in dim_sum_per_rank ] @@ -790,11 +1281,32 @@ def forward( myreq.qcomm_ctx, ) + padded_dim_sum_per_rank = ( + myreq.qcomm_ctx.padded_dim_sum_per_rank + if myreq.qcomm_ctx is not None + and myreq.qcomm_ctx.padded_dim_sum_per_rank is not None + else dim_sum_per_rank + ) outputs_by_rank = sharded_output_embeddings.split( - [B_local * D_rank_sum for D_rank_sum in dim_sum_per_rank] + [B_local * D_rank_sum for D_rank_sum in padded_dim_sum_per_rank] ) + final_dim_sum_per_rank = padded_dim_sum_per_rank + if ( + myreq.qcomm_ctx is not None + and myreq.qcomm_ctx.padded_dim_sum_per_rank is not None + ): + outputs_by_rank = [ + output.view(B_local, -1)[:, :dim_sum] + for output, dim_sum in zip(outputs_by_rank, dim_sum_per_rank) + ] + final_dim_sum_per_rank = 
dim_sum_per_rank + result = torch.cat( - [output.view(B_local, -1) for output in outputs_by_rank], dim=1 + [ + output.view(B_local, dim) + for output, dim in zip(outputs_by_rank, final_dim_sum_per_rank) + ], + dim=1, ) return result @@ -869,6 +1381,210 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]: return (None, None, myreq.dummy_tensor) +class Variable_Batch_All2All_Pooled_Req(Function): + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + # pyre-fixme[2]: Parameter must be annotated. + ctx, + pg: dist.ProcessGroup, + myreq: Request[Tensor], + a2ai: VariableBatchAll2AllPooledInfo, + input_embeddings: Tensor, + ) -> Tensor: + my_rank = dist.get_rank(pg) + + # get input splits + world_size = dist.get_world_size(pg) + input_split_sizes = [0 for _ in range(world_size)] + if a2ai.batch_size_per_rank_per_feature: + for i in range(world_size): + curr_size = 0 + for batch_size, emb_dim in zip( + a2ai.batch_size_per_rank_per_feature[i], + a2ai.emb_dim_per_rank_per_feature[my_rank], + ): + curr_size += batch_size * emb_dim + input_split_sizes[i] = curr_size + a2ai.input_splits = input_split_sizes + + # get output splits + output_split_sizes = [0 for _ in range(world_size)] + ind = 0 + for i in range(world_size): + curr_size = 0 + for emb_dim in a2ai.emb_dim_per_rank_per_feature[i]: + curr_size += a2ai.batch_size_per_feature_pre_a2a[ind] * emb_dim + ind += 1 + output_split_sizes[i] = curr_size + a2ai.output_splits = output_split_sizes + + sharded_input_embeddings = input_embeddings.view(-1) + qcomm_ctx = None + + if a2ai.codecs is not None: + codecs = none_throws(a2ai.codecs) + qcomm_ctx = codecs.forward.create_context() + sharded_input_embeddings = codecs.forward.encode( + sharded_input_embeddings, + qcomm_ctx, + ) + output_split_sizes = [ + codecs.forward.calc_quantized_size( + split, + qcomm_ctx, + ) + for split in output_split_sizes + ] + input_split_sizes = [ + codecs.forward.calc_quantized_size( + split, + qcomm_ctx, + ) + for split in input_split_sizes + ] + + sharded_output_embeddings = torch.empty( + sum(output_split_sizes), + dtype=sharded_input_embeddings.dtype, + device=sharded_input_embeddings.device, + ) + + with record_function("## alltoall_fwd_single ##"): + req = dist.all_to_all_single( + output=sharded_output_embeddings, + input=sharded_input_embeddings, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=pg, + async_op=True, + ) + + myreq.req = req + myreq.tensor = sharded_output_embeddings + myreq.qcomm_ctx = qcomm_ctx + myreq.a2ai = a2ai + myreq.wait_function = Variable_Batch_All2All_Pooled_Wait + ctx.myreq = myreq + ctx.pg = pg + return myreq.dummy_tensor + + @staticmethod + # pyre-fixme[2]: Parameter must be annotated. + # pyre-fixme[2]: Parameter must be annotated. 
+ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]: + myreq = ctx.myreq + a2ai = myreq.a2ai + assert myreq.req is not None + myreq.req.wait() + if isinstance(myreq.req, dist.Work): + myreq.req.wait() + + myreq.req = None + grad_output = myreq.tensor + + if a2ai.codecs is not None: + codecs = none_throws(a2ai.codecs) + grad_input = codecs.backward.decode(grad_output, myreq.qcomm_ctx) + else: + grad_input = grad_output + if GRADIENT_DIVISION: + grad_input.div_(dist.get_world_size(ctx.pg)) + myreq.tensor = None + myreq.dummy_tensor = None + return (None, None, None, grad_input) + + +class Variable_Batch_All2All_Pooled_Wait(Function): + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + # pyre-fixme[2]: Parameter must be annotated. + ctx, + pg: dist.ProcessGroup, + myreq: Request[Tensor], + *dummy_tensor: Tensor, + ) -> Tensor: + a2ai = myreq.a2ai + ctx.a2ai = a2ai + assert myreq.req is not None + if isinstance(myreq.req, dist.Work): + myreq.req.wait() + sharded_output_embeddings = myreq.tensor + myreq.req = None + myreq.tensor = None + ctx.pg = pg + ctx.myreq = myreq + + if a2ai.codecs is not None: + codecs = none_throws(a2ai.codecs) + sharded_output_embeddings = codecs.forward.decode( + sharded_output_embeddings, + myreq.qcomm_ctx, + ) + # the return result is a 1-d tensor, like: f_0_s_0, f_0_s1, ..., f_n_s_0, f_n_s_k + # f_0, f_1, ... , f_n are ordered by features on each rank + return sharded_output_embeddings + + @staticmethod + # pyre-fixme[14]: `backward` overrides method defined in `Function` inconsistently. + # pyre-fixme[2]: Parameter must be annotated. + def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]: + myreq = ctx.myreq + a2ai = ctx.a2ai + pg = ctx.pg + + assert a2ai.input_splits is not None + assert a2ai.output_splits is not None + input_split_sizes = a2ai.output_splits + output_split_sizes = a2ai.input_splits + + sharded_grad_output = grad_output.contiguous() + qcomm_ctx = None + + if a2ai.codecs is not None: + codecs = none_throws(a2ai.codecs) + qcomm_ctx = codecs.backward.create_context() + sharded_grad_output = codecs.backward.encode( + sharded_grad_output, + qcomm_ctx, + ) + input_split_sizes = [ + codecs.backward.calc_quantized_size( + split, + qcomm_ctx, + ) + for split in input_split_sizes + ] + output_split_sizes = [ + codecs.backward.calc_quantized_size( + split, + qcomm_ctx, + ) + for split in output_split_sizes + ] + + sharded_grad_input = torch.empty( + sum(output_split_sizes), + device=sharded_grad_output.device, + dtype=sharded_grad_output.dtype, + ) + with record_function("## alltoall_bwd_single ##"): + req = dist.all_to_all_single( + output=sharded_grad_input, + input=sharded_grad_output, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=pg, + async_op=True, + ) + myreq.req = req + myreq.tensor = sharded_grad_input + myreq.qcomm_ctx = qcomm_ctx + + return (None, None, myreq.dummy_tensor) + + class All2All_Seq_Req(Function): @staticmethod # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. 
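The split-size computation above (it appears both in variable_batch_all2all_pooled_sync and in Variable_Batch_All2All_Pooled_Req.forward) is pure integer bookkeeping: the number of elements this rank sends to rank i is rank i's batch sizes paired with the embedding dims this rank owns, and the number it receives from rank i is this rank's pre-a2a batch sizes paired with the dims rank i owns. The standalone sketch below restates that arithmetic with made-up shapes and no collectives; variable_batch_splits is a hypothetical helper name.

# Standalone restatement of the variable-batch split-size math used above.
from typing import List, Tuple

def variable_batch_splits(
    my_rank: int,
    batch_size_per_rank_per_feature: List[List[int]],
    batch_size_per_feature_pre_a2a: List[int],
    emb_dim_per_rank_per_feature: List[List[int]],
) -> Tuple[List[int], List[int]]:
    world_size = len(batch_size_per_rank_per_feature)
    # Elements sent to rank i: its batch sizes paired with this rank's feature dims.
    input_splits = [
        sum(
            b * d
            for b, d in zip(
                batch_size_per_rank_per_feature[i],
                emb_dim_per_rank_per_feature[my_rank],
            )
        )
        for i in range(world_size)
    ]
    # Elements received from rank i: pre-a2a batch sizes paired with rank i's dims.
    output_splits, ind = [], 0
    for i in range(world_size):
        size = 0
        for d in emb_dim_per_rank_per_feature[i]:
            size += batch_size_per_feature_pre_a2a[ind] * d
            ind += 1
        output_splits.append(size)
    return input_splits, output_splits

# Two ranks; rank 0 owns one 4-dim feature, rank 1 owns one 8-dim feature.
inp, out = variable_batch_splits(
    my_rank=0,
    batch_size_per_rank_per_feature=[[2], [3]],
    batch_size_per_feature_pre_a2a=[2, 2],
    emb_dim_per_rank_per_feature=[[4], [8]],
)
assert inp == [8, 12] and out == [8, 16]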
@@ -1013,6 +1729,8 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]: None, sharded_grad_input.numel(), ) + if GRADIENT_DIVISION: + sharded_grad_input.div_(dist.get_world_size(ctx.pg)) return (None, None, None, sharded_grad_input.view(-1, D)) @@ -1232,7 +1950,12 @@ def forward( device=inputs[my_rank].device, ) with record_function("## reduce_scatter ##"): - req = dist.reduce_scatter(output, list(inputs), group=pg, async_op=True) + req = dist.reduce_scatter( + output, + list(inputs), + group=pg, + async_op=True, + ) myreq.req = req myreq.tensor = output myreq.wait_function = ReduceScatter_Wait @@ -1333,8 +2056,13 @@ def forward( if rsi.codecs is not None: inputs = rsi.codecs.forward.encode(inputs) output = inputs.new_empty((inputs.size(0) // my_size, inputs.size(1))) - with record_function("## reduce_scatter_base ##"): - req = dist._reduce_scatter_base(output, inputs, group=pg, async_op=True) + with record_function("## reduce_scatter_tensor ##"): + req = dist.reduce_scatter_tensor( + output, + inputs, + group=pg, + async_op=True, + ) myreq.req = req myreq.tensor = output myreq.wait_function = ReduceScatterBase_Wait @@ -1397,7 +2125,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]: grad_output = rsi.codecs.backward.encode(grad_output) grad_inputs = grad_output.new_empty(rsi.input_sizes) with record_function("## reduce_scatter_base_bw (all_gather) ##"): - req = dist._all_gather_base( + req = dist.all_gather_into_tensor( grad_inputs, grad_output.contiguous(), group=ctx.pg, @@ -1425,8 +2153,13 @@ def forward( input = agi.codecs.forward.encode(input) outputs = input.new_empty((input.size(0) * my_size, input.size(1))) - with record_function("## all_gather_base ##"): - req = dist._all_gather_base(outputs, input, group=pg, async_op=True) + with record_function("## all_gather_into_tensor ##"): + req = dist.all_gather_into_tensor( + outputs, + input, + group=pg, + async_op=True, + ) myreq.req = req myreq.tensor = outputs myreq.wait_function = AllGatherBase_Wait @@ -1489,7 +2222,7 @@ def backward(ctx, grad_outputs: Tensor) -> Tuple[None, None, Tensor]: grad_outputs = agi.codecs.backward.encode(grad_outputs) grad_input = grad_outputs.new_empty(agi.input_size) with record_function("## all_gather_base_bw (reduce_scatter) ##"): - req = dist._reduce_scatter_base( + req = dist.reduce_scatter_tensor( grad_input, grad_outputs.contiguous(), group=ctx.pg, @@ -1519,11 +2252,16 @@ def forward( output = input.new_empty(rsi.input_sizes[my_rank]) - # Use dist._reduce_scatter_base when a vector reduce-scatter is not needed + # Use dist.reduce_scatter_tensor when a vector reduce-scatter is not needed # else use dist.reduce_scatter which internally supports vector reduce-scatter if rsi.equal_splits: - with record_function("## reduce_scatter_base ##"): - req = dist._reduce_scatter_base(output, input, group=pg, async_op=True) + with record_function("## reduce_scatter_tensor ##"): + req = dist.reduce_scatter_tensor( + output, + input, + group=pg, + async_op=True, + ) else: with record_function("## reduce_scatter_v ##"): req = dist.reduce_scatter( @@ -1599,7 +2337,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]: if rsi.equal_splits: with record_function("## reduce_scatter_base_bw (all_gather) ##"): - req = dist._all_gather_base( + req = dist.all_gather_into_tensor( grad_input, grad_output.contiguous(), group=ctx.pg, @@ -1616,3 +2354,204 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]: myreq.req = req myreq.tensor = grad_input return 
(None, None, myreq.dummy_tensor) + + +if not torch._running_with_deploy(): # noqa C901 + # Torch Library op def can not be used in Deploy + class AllToAllSingle(torch.autograd.Function): + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + # pyre-fixme[2]: Parameter must be annotated. + ctx, + input: Tensor, + output_split_sizes: List[int], + input_split_sizes: List[int], + group_name: str, + group_size: int, + gradient_division: bool, + ) -> Tensor: + ctx.output_split_sizes = input_split_sizes + ctx.input_split_sizes = output_split_sizes + ctx.group_name = group_name + ctx.group_size = group_size + ctx.gradient_division = gradient_division + return torch.distributed._functional_collectives.all_to_all_single( + input, output_split_sizes, input_split_sizes, group_name + ) + + @staticmethod + # pyre-ignore + def backward(ctx, grad): + grad = torch.distributed._functional_collectives.all_to_all_single( + grad, + ctx.output_split_sizes, + ctx.input_split_sizes, + ctx.group_name, + ) + if ctx.gradient_division: + grad.div_(ctx.group_size) + + return grad, None, None, None, None, None + + # torchrec::reduce_scatter_tensor + @torch.library.custom_op("torchrec::reduce_scatter_tensor", mutates_args=()) + def reduce_scatter_tensor( + input: Tensor, + reduceOp: str, + group_size: int, + group_name: str, + gradient_division: bool, + ) -> Tensor: + out = torch.ops._c10d_functional.reduce_scatter_tensor( + input, + reduceOp, + group_size, + group_name, + ) + return torch.ops._c10d_functional.wait_tensor(out) + + @torch.library.register_fake("torchrec::reduce_scatter_tensor") + def reduce_scatter_tensor_fake( + input: Tensor, + reduceOp: str, + group_size: int, + group_name: str, + gradient_division: bool, + ) -> Tensor: + return torch.ops._c10d_functional.reduce_scatter_tensor( + input, + reduceOp, + group_size, + group_name, + ) + + # pyre-ignore + def reduce_scatter_tensor_setup_context(ctx, inputs, output) -> None: + _, _, group_size, group_name, gradient_division = inputs + ctx.group_size = group_size + ctx.group_name = group_name + ctx.gradient_division = gradient_division + + # pyre-ignore + def reduce_scatter_tensor_backward(ctx, grad): + # TODO(ivankobzarev): Support codecs(quantization) on backward + out = torch.ops._c10d_functional.all_gather_into_tensor( + grad, + ctx.group_size, + ctx.group_name, + ) + grad = torch.ops._c10d_functional.wait_tensor(out) + if ctx.gradient_division: + grad.div_(ctx.group_size) + + return grad, None, None, None, None, None + + torch.library.register_autograd( + "torchrec::reduce_scatter_tensor", + reduce_scatter_tensor_backward, + setup_context=reduce_scatter_tensor_setup_context, + ) + + # torchrec::all_gather_into_tensor + @torch.library.custom_op("torchrec::all_gather_into_tensor", mutates_args=()) + def all_gather_into_tensor( + shard: Tensor, + gather_dim: int, + group_size: int, + group_name: str, + gradient_division: bool, + ) -> Tensor: + out = torch.ops._c10d_functional.all_gather_into_tensor( + shard, group_size, group_name + ) + return torch.ops._c10d_functional.wait_tensor(out) + + @torch.library.register_fake("torchrec::all_gather_into_tensor") + def all_gather_into_tensor_fake( + shard: Tensor, + gather_dim: int, + group_size: int, + group_name: str, + gradient_division: bool, + ) -> Tensor: + return torch.ops._c10d_functional.all_gather_into_tensor( + shard, group_size, group_name + ) + + # pyre-ignore + def all_gather_into_tensor_setup_context(ctx, inputs, output) -> None: + _, 
gather_dim, group_size, group_name, gradient_division = inputs + ctx.group_size = group_size + ctx.group_name = group_name + ctx.gradient_division = gradient_division + + # pyre-ignore + def all_gather_into_tensor_backward(ctx, grad): + # TODO(ivankobzarev): Support codecs(quantization) on backward + out = torch.ops._c10d_functional.reduce_scatter_tensor( + grad, + "sum", + ctx.group_size, + ctx.group_name, + ) + grad = torch.ops._c10d_functional.wait_tensor(out) + if ctx.gradient_division: + grad.div_(ctx.group_size) + + return grad, None, None, None, None + + torch.library.register_autograd( + "torchrec::all_gather_into_tensor", + all_gather_into_tensor_backward, + setup_context=all_gather_into_tensor_setup_context, + ) + + @torch.library.custom_op("torchrec::_split_1d_cat_2d", mutates_args=()) + def _split_1d_cat_2d_impl( + t: torch.Tensor, dim0: int, dim1s: List[int] + ) -> torch.Tensor: + torch._check_is_size(dim0) + [torch._check_is_size(dim1) for dim1 in dim1s] + splits: List[torch.Tensor] = t.split([dim0 * dim1 for dim1 in dim1s]) + return torch.cat( + [s.reshape(dim0, dim1) for s, dim1 in zip(splits, dim1s)], + dim=1, + ) + + @torch.library.register_fake("torchrec::_split_1d_cat_2d") + def _split_1d_cat_2d_impl_abstract( + t: torch.Tensor, dim0: int, dim1s: List[int] + ) -> torch.Tensor: + return t.new_empty([dim0, sum(dim1s)]) + + @torch.library.custom_op( + "torchrec::_split_1d_cat_2d_backward_impl", mutates_args=() + ) + def _split_1d_cat_2d_backward_impl( + grad: torch.Tensor, dim1s: List[int] + ) -> torch.Tensor: + splits = grad.split(dim1s, dim=1) + return torch.cat([s.reshape(-1) for s in splits], dim=0) + + @torch.library.register_fake("torchrec::_split_1d_cat_2d_backward_impl") + def _split_1d_cat_2d_backward_impl_fake( + grad: torch.Tensor, dim1s: List[int] + ) -> torch.Tensor: + return grad.new_empty([grad.numel()]) + + # pyre-ignore + def _split_1d_cat_2d_backward(ctx, grad): + ret = torch.ops.torchrec._split_1d_cat_2d_backward_impl(grad, ctx.dim1s) + return ret, None, None + + # pyre-ignore + def _split_1d_cat_2d_setup_context(ctx, inputs, output): + (x, dim0, dim1s) = inputs + ctx.dim1s = dim1s + + torch.library.register_autograd( + "torchrec::_split_1d_cat_2d", + _split_1d_cat_2d_backward, + setup_context=_split_1d_cat_2d_setup_context, + ) diff --git a/torchrec/distributed/composable/__init__.py b/torchrec/distributed/composable/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/distributed/composable/table_batched_embedding_slice.py b/torchrec/distributed/composable/table_batched_embedding_slice.py new file mode 100644 index 000000000..000ff6d19 --- /dev/null +++ b/torchrec/distributed/composable/table_batched_embedding_slice.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Dict, Optional + +import torch + +from torch import nn + + +class TableBatchedEmbeddingSlice(nn.Parameter): + """ + Parameter to represent a slice of a table batched embedding. 
The slice will be + a view of the TBE of shape (num_embeddings, embedding_dim) and contain consistent .grad + + unlike nn.Parameter, requires_grad is not present and follows requires_grad of TBE.data + + Args: + data (torch.Tensor): original Data (of a TBE) to make a slice of + start_offset (int): + end_offset (int): + num_embeddings (int): + embedding_dim (int): + """ + + __slots__ = [ + "_original_tensor", + "_start_offset", + "_end_offset", + "_num_embeddings", + "_embedding_dim", + ] + + def __init__( + self, + data: torch.Tensor, + start_offset: int, + end_offset: int, + num_embeddings: int, + embedding_dim: int, + ) -> None: + super().__init__() + self._original_tensor: torch.Tensor = data + self._start_offset: int = start_offset + self._end_offset: int = end_offset + self._num_embeddings: int = num_embeddings + self._embedding_dim: int = embedding_dim + self._init_grad: Optional[torch.Tensor] = None + if self._original_tensor.requires_grad: + self.retain_grad() + + def __new__( + cls, + data: torch.Tensor, + start_offset: int, + end_offset: int, + num_embeddings: int, + embedding_dim: int, + ) -> "TableBatchedEmbeddingSlice": + _slice = data[start_offset:end_offset].view(num_embeddings, embedding_dim) + return _slice.as_subclass(cls) + + def __deepcopy__( + self, memo: Dict[int, "TableBatchedEmbeddingSlice"] + ) -> "TableBatchedEmbeddingSlice": + if id(self) in memo: + return memo[id(self)] + else: + result = TableBatchedEmbeddingSlice( + self._original_tensor.clone(memory_format=torch.preserve_format), + self._start_offset, + self._end_offset, + self._num_embeddings, + self._embedding_dim, + ) + memo[id(self)] = result + return result + + @property + def grad(self) -> Optional[torch.Tensor]: + if self._original_tensor.grad is None: + return self._init_grad + return self._original_tensor.grad[self._start_offset : self._end_offset].view( + self._num_embeddings, self._embedding_dim + ) + + @grad.setter + def grad(self, set_grad: torch.Tensor) -> None: + self._init_grad = set_grad + if set_grad is None: + self._original_tensor.grad = None + elif self._original_tensor.grad is not None: + self._original_tensor.grad[self._start_offset : self._end_offset].copy_( + set_grad.view(-1) + ) + + @property + def grad_fn(self) -> None: + # set as leaf node + return None diff --git a/torchrec/distributed/composable/tests/__init__.py b/torchrec/distributed/composable/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/distributed/composable/tests/test_ddp.py b/torchrec/distributed/composable/tests/test_ddp.py new file mode 100644 index 000000000..60472bef3 --- /dev/null +++ b/torchrec/distributed/composable/tests/test_ddp.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
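TableBatchedEmbeddingSlice above hands out a per-table (num_embeddings, embedding_dim) view into one flat fused weight and redirects .grad reads and writes to the matching slice of the flat tensor's gradient. The offset/view arithmetic it wraps can be seen with plain tensors; the sketch below only illustrates that arithmetic, it is not the class itself.

# Plain-tensor illustration of the slicing behaviour TableBatchedEmbeddingSlice wraps.
import torch

# One fused buffer holding two "tables": 4x2 and 3x2.
shapes = [(4, 2), (3, 2)]
flat = torch.randn(sum(n * d for n, d in shapes), requires_grad=True)

offset, views = 0, []
for n, d in shapes:
    views.append(flat[offset : offset + n * d].view(n, d))
    offset += n * d

# A loss touching only the second table still writes its gradient into the
# shared flat buffer, at that table's offset.
views[1].sum().backward()
split = shapes[0][0] * shapes[0][1]
assert torch.equal(flat.grad[split:].view(*shapes[1]), torch.ones(*shapes[1]))
assert flat.grad[:split].abs().sum() == 0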
+ +# pyre-strict + +#!/usr/bin/env python3 + +import tempfile +import unittest + +import torch +from torch.distributed._composable import replicate +from torch.distributed._shard.api import ShardedTensor +from torch.distributed.checkpoint import ( + FileSystemReader, + FileSystemWriter, + load_state_dict, + save_state_dict, +) +from torchrec.distributed.shard import shard as trec_shard, shard_modules +from torchrec.distributed.sharding_plan import column_wise +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.test_utils.test_model import ModelInput, TestSparseNN +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.test_utils import skip_if_asan + + +class DDPTest(MultiProcessTestBase): + @classmethod + def _run_init(cls, rank: int, world_size: int) -> None: + with MultiProcessContext(rank, world_size, "nccl") as ctx: + num_float_features = 32 + + tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4 * world_size, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(3) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4 * world_size, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(2) + ] + m = TestSparseNN( + tables=tables, + num_float_features=num_float_features, + weighted_tables=weighted_tables, + dense_device=ctx.device, + ) + # Put all tensors on meta device, then init_params should + # materialize them. + for name, param in m._parameters.items(): + if isinstance(param, torch.Tensor): + m._parameters[name] = torch.nn.Parameter( + torch.empty_like(param, device="meta"), + requires_grad=param.requires_grad, + ) + + shard_modules(m, device=ctx.device, init_params=True) + # init_params should move m to `device` + for p in m.parameters(): + assert p.device == ctx.device + + @classmethod + def _run(cls, rank: int, world_size: int, path: str) -> None: + with MultiProcessContext(rank, world_size, "nccl") as ctx: + num_float_features = 32 + + tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4 * world_size, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(3) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4 * world_size, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(2) + ] + m = TestSparseNN( + tables=tables, + num_float_features=num_float_features, + weighted_tables=weighted_tables, + dense_device=ctx.device, + ) + # pyre-ignore + m.sparse.ebc = trec_shard( + module=m.sparse.ebc, + device=ctx.device, + plan=column_wise(ranks=list(range(world_size))), + ) + # pyre-ignore + m.sparse.weighted_ebc = trec_shard( + module=m.sparse.weighted_ebc, + device=ctx.device, + plan=column_wise(ranks=list(range(world_size))), + ) + m.over = replicate(m.over) + m.dense = replicate(m.dense) + + ######## run one iteration ######## + _, local_batch = ModelInput.generate( + batch_size=8, + world_size=world_size, + num_float_features=num_float_features, + tables=tables, + weighted_tables=weighted_tables, + ) + batch = local_batch[0].to(ctx.device) + m(batch)[1].sum().backward() + + state_dict = m.state_dict() + writer = FileSystemWriter(path=path) + reader = FileSystemReader(path=path) + save_state_dict(state_dict, writer) + + p_sum = 
torch.zeros(1, device=ctx.device) + for p in m.parameters(): + with torch.no_grad(): + if isinstance(p, ShardedTensor): + if not p.local_shards(): + continue + p = p.local_tensor() + p_sum += p.sum() + p.zero_() + assert p.sum() == 0 + load_state_dict(state_dict, reader) + m.load_state_dict(state_dict) + + p_sum_loaded = torch.zeros(1, device=ctx.device) + for p in m.parameters(): + with torch.no_grad(): + if isinstance(p, ShardedTensor): + if not p.local_shards(): + continue + p = p.local_tensor() + p_sum_loaded += p.sum() + # TODO: debug why failing on OSS + # assert p_sum.allclose(p_sum_loaded) + + @skip_if_asan + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `torch.cuda.device_count() <= 1` to decorator factory `unittest.skipIf`. + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_checkpoint(self) -> None: + with tempfile.TemporaryDirectory() as path: + self._run_multi_process_test( + callable=self._run, + path=path, + ) + + @skip_if_asan + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `torch.cuda.device_count() <= 1` to decorator factory `unittest.skipIf`. + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_init_params(self) -> None: + self._run_multi_process_test( + callable=self._run_init, + ) diff --git a/torchrec/distributed/composable/tests/test_embedding.py b/torchrec/distributed/composable/tests/test_embedding.py new file mode 100644 index 000000000..88d81f96b --- /dev/null +++ b/torchrec/distributed/composable/tests/test_embedding.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
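The checkpoint test above follows a standard round-trip pattern: summarize the parameters, save, clobber them, load, and check that the summary is restored. The snippet below is a deliberately simplified single-process analogue using a vanilla nn.Linear with torch.save/torch.load, not the torch.distributed.checkpoint reader/writer or ShardedTensor handling that the test itself exercises.

# Single-process analogue of the save -> zero -> load check in test_ddp.py.
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
p_sum = sum(p.sum().item() for p in model.parameters())

with tempfile.TemporaryDirectory() as path:
    ckpt = os.path.join(path, "model.pt")
    torch.save(model.state_dict(), ckpt)

    # Clobber the parameters so a successful load is observable.
    with torch.no_grad():
        for p in model.parameters():
            p.zero_()
    assert sum(p.sum().item() for p in model.parameters()) == 0

    model.load_state_dict(torch.load(ckpt))

p_sum_loaded = sum(p.sum().item() for p in model.parameters())
assert abs(p_sum - p_sum_loaded) < 1e-6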
+ +# pyre-strict + +import unittest +from typing import Dict, List, Optional + +import hypothesis.strategies as st +import torch +import torch.nn as nn +from hypothesis import given, settings, Verbosity +from torch.distributed._tensor.api import DTensor +from torch.distributed.optim import ( + _apply_optimizer_in_backward as apply_optimizer_in_backward, +) +from torchrec import distributed as trec_dist +from torchrec.distributed.embedding import ( + EmbeddingCollectionSharder, + ShardedEmbeddingCollection, +) + +from torchrec.distributed.shard import _shard_modules +from torchrec.distributed.sharding_plan import ( + column_wise, + construct_module_sharding_plan, + row_wise, + table_wise, +) +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.types import ShardedTensor, ShardingEnv, ShardingPlan +from torchrec.modules.embedding_configs import EmbeddingConfig +from torchrec.modules.embedding_modules import EmbeddingCollection + +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from torchrec.test_utils import skip_if_asan_class + + +def _test_sharding( # noqa C901 + tables: List[EmbeddingConfig], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + backend: str, + local_size: Optional[int] = None, + use_apply_optimizer_in_backward: bool = False, + use_index_dedup: bool = False, +) -> None: + trec_dist.comm_ops.set_gradient_division(False) + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + sharder = EmbeddingCollectionSharder(use_index_dedup=use_index_dedup) + kjt_input_per_rank = [kjt.to(ctx.device) for kjt in kjt_input_per_rank] + + unsharded_model = EmbeddingCollection( + tables=tables, + device=ctx.device, + need_indices=True, + ) + + # syncs model across ranks + torch.manual_seed(0) + for param in unsharded_model.parameters(): + nn.init.uniform_(param, -1, 1) + torch.manual_seed(0) + + if use_apply_optimizer_in_backward: + apply_optimizer_in_backward( + torch.optim.SGD, + unsharded_model.embeddings.parameters(), + {"lr": 1.0}, + ) + else: + unsharded_model_optimizer = torch.optim.SGD( + unsharded_model.parameters(), lr=1.0 + ) + + module_sharding_plan = construct_module_sharding_plan( + unsharded_model, + per_param_sharding={ + "table_0": table_wise(rank=0), + "table_1": row_wise(), + "table_2": column_wise(ranks=[0, 1]), + }, + local_size=local_size, + world_size=world_size, + device_type=ctx.device.type, + # pyre-ignore + sharder=sharder, + ) + + sharded_model = _shard_modules( + module=unsharded_model, + plan=ShardingPlan({"": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + env=ShardingEnv.from_process_group(ctx.pg), + # pyre-fixme[6]: For 4th argument expected + # `Optional[List[ModuleSharder[Module]]]` but got + # `List[EmbeddingCollectionSharder]`. 
+ sharders=[sharder], + device=ctx.device, + ) + + if not use_apply_optimizer_in_backward: + sharded_model_optimizer = torch.optim.SGD( + sharded_model.parameters(), lr=1.0 + ) + + assert isinstance(sharded_model, ShardedEmbeddingCollection) + + if not use_apply_optimizer_in_backward: + unsharded_model_optimizer.zero_grad() + sharded_model_optimizer.zero_grad() + + unsharded_model_pred_jt_dict = [] + for unsharded_rank in range(ctx.world_size): + # simulate the unsharded model run on the entire batch + unsharded_model_pred_jt_dict.append( + unsharded_model(kjt_input_per_rank[unsharded_rank]) + ) + + # sharded model + # each rank gets a subbatch + sharded_model_pred_jts_dict: Dict[str, JaggedTensor] = sharded_model( + kjt_input_per_rank[ctx.rank] + ) + + unsharded_model_pred_jt_dict_this_rank: Dict[str, JaggedTensor] = ( + unsharded_model_pred_jt_dict[ctx.rank] + ) + + embedding_names = unsharded_model_pred_jt_dict_this_rank.keys() + assert set(unsharded_model_pred_jt_dict_this_rank.keys()) == set( + sharded_model_pred_jts_dict.keys() + ) + + unsharded_loss = [] + sharded_loss = [] + for embedding_name in embedding_names: + unsharded_jt = unsharded_model_pred_jt_dict_this_rank[embedding_name] + sharded_jt = sharded_model_pred_jts_dict[embedding_name] + + torch.testing.assert_close(unsharded_jt.values(), sharded_jt.values()) + torch.testing.assert_close(unsharded_jt.lengths(), sharded_jt.lengths()) + torch.testing.assert_close(unsharded_jt.offsets(), sharded_jt.offsets()) + torch.testing.assert_close( + unsharded_jt.weights_or_none(), sharded_jt.weights_or_none() + ) + + sharded_loss.append(sharded_jt.values().view(-1)) + + for rank in range(ctx.world_size): + for embedding_name in embedding_names: + unsharded_loss.append( + unsharded_model_pred_jt_dict[rank][embedding_name].values().view(-1) + ) + + torch.cat(sharded_loss).sum().backward() + torch.cat(unsharded_loss).sum().backward() + + if not use_apply_optimizer_in_backward: + unsharded_model_optimizer.step() + sharded_model_optimizer.step() + + for fqn in unsharded_model.state_dict(): + unsharded_state = unsharded_model.state_dict()[fqn] + sharded_state = sharded_model.state_dict()[fqn] + + sharded_param = ( + torch.zeros(size=unsharded_state.shape, device=ctx.device) + if ctx.rank == 0 + else None + ) + if isinstance(sharded_state, ShardedTensor): + sharded_state.gather(out=sharded_param) + elif isinstance(sharded_state, DTensor): + sharded_param = sharded_state.full_tensor() + else: + sharded_param = sharded_state + + if ctx.rank == 0: + torch.testing.assert_close( + unsharded_state, + sharded_param, + msg=f"Did not match for {fqn=} after backward", + ) + + +@skip_if_asan_class +class ShardedEmbeddingCollectionParallelTest(MultiProcessTestBase): + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + # pyre-ignore + @given( + use_apply_optimizer_in_backward=st.booleans(), + use_index_dedup=st.booleans(), + ) + def test_sharding_ebc( + self, + use_apply_optimizer_in_backward: bool, + use_index_dedup: bool, + ) -> None: + + WORLD_SIZE = 2 + + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=4, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_0", "feature_1"], + embedding_dim=8, + num_embeddings=4, + ), + EmbeddingConfig( + name="table_2", + feature_names=["feature_0", "feature_1"], + embedding_dim=8, + 
num_embeddings=4, + ), + ] + + # Rank 0 + # instance 0 instance 1 instance 2 + # "feature_0" [0, 1] None [2] + # "feature_1" [0, 1] None [2] + + # Rank 1 + + # instance 0 instance 1 instance 2 + # "feature_0" [3, 2] [1,2] [0,1,2,3] + # "feature_1" [2, 3] None [2] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor([0, 1, 2, 0, 1, 2]), + lengths=torch.LongTensor([2, 0, 1, 2, 0, 1]), + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor([3, 2, 1, 2, 0, 1, 2, 3, 2, 3, 2]), + lengths=torch.LongTensor([2, 2, 4, 2, 0, 1]), + ), + ] + self._run_multi_process_test( + callable=_test_sharding, + world_size=WORLD_SIZE, + tables=embedding_config, + kjt_input_per_rank=kjt_input_per_rank, + backend="nccl", + use_apply_optimizer_in_backward=use_apply_optimizer_in_backward, + use_index_dedup=use_index_dedup, + ) diff --git a/torchrec/distributed/composable/tests/test_embeddingbag.py b/torchrec/distributed/composable/tests/test_embeddingbag.py new file mode 100644 index 000000000..b7469ceed --- /dev/null +++ b/torchrec/distributed/composable/tests/test_embeddingbag.py @@ -0,0 +1,403 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import copy +import unittest + +from functools import partial +from typing import Any, Dict, List, Optional + +import hypothesis.strategies as st +import torch +import torch.nn as nn + +from hypothesis import assume, given, settings, Verbosity +from torch.distributed._tensor.api import DTensor +from torch.distributed.optim import ( + _apply_optimizer_in_backward as apply_optimizer_in_backward, +) +from torchrec import distributed as trec_dist +from torchrec.distributed.embeddingbag import ( + EmbeddingBagCollectionSharder, + ShardedEmbeddingBagCollection, +) +from torchrec.distributed.planner import ( + EmbeddingShardingPlanner, + ParameterConstraints, + Topology, +) + +from torchrec.distributed.shard import shard +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.test_utils.test_sharding import copy_state_dict +from torchrec.distributed.types import ( + BoundsCheckMode, + CacheAlgorithm, + CacheParams, + DataType, + ModuleSharder, + QuantizedCommCodecs, + ShardingEnv, + ShardingPlan, + ShardingType, +) +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection + +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.test_utils import ( + assert_state_buffers_parameters_equal, + skip_if_asan_class, +) + + +def _optional_equals(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool: + if t1 is None: + return t2 is None + return t2 is not None and torch.equal(t1, t2) + + +def _test_sharding( # noqa C901 + tables: List[EmbeddingBagConfig], + initial_state_dict: Dict[str, Any], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + sharder: ModuleSharder[nn.Module], + backend: str, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + local_size: Optional[int] = None, + is_data_parallel: bool = False, + use_apply_optimizer_in_backward: bool = False, +) -> None: + trec_dist.comm_ops.set_gradient_division(False) + + 
with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + kjt_input_per_rank = [kjt.to(ctx.device) for kjt in kjt_input_per_rank] + initial_state_dict = { + fqn: tensor.to(ctx.device) for fqn, tensor in initial_state_dict.items() + } + + planner = EmbeddingShardingPlanner( + topology=Topology( + world_size, ctx.device.type, local_world_size=ctx.local_size + ), + constraints=constraints, + ) + model = EmbeddingBagCollection( + tables=tables, + device=ctx.device, + ) + unsharded_model = EmbeddingBagCollection( + tables=tables, + device=ctx.device, + ) + + if use_apply_optimizer_in_backward: + apply_optimizer_in_backward( + torch.optim.SGD, + model.embedding_bags["table_0"].parameters(), + {"lr": 1.0}, + ) + apply_optimizer_in_backward( + torch.optim.SGD, + model.embedding_bags["table_1"].parameters(), + {"lr": 4.0}, + ) + apply_optimizer_in_backward( + torch.optim.SGD, + unsharded_model.embedding_bags["table_0"].parameters(), + {"lr": 1.0}, + ) + apply_optimizer_in_backward( + torch.optim.SGD, + unsharded_model.embedding_bags["table_1"].parameters(), + {"lr": 4.0}, + ) + plan: ShardingPlan = planner.collective_plan(model, [sharder], ctx.pg) + sharded_model = shard( + module=model, + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + env=ShardingEnv.from_process_group(ctx.pg), + plan=plan.get_plan_for_module(""), + sharder=sharder, + device=ctx.device, + ) + + if not use_apply_optimizer_in_backward: + unsharded_model_optimizer = torch.optim.SGD( + unsharded_model.parameters(), lr=0.01 + ) + sharded_model_optimizer = torch.optim.SGD( + sharded_model.parameters(), lr=0.01 + ) + + assert isinstance(sharded_model, ShardedEmbeddingBagCollection) + + unsharded_model.load_state_dict(copy.deepcopy(initial_state_dict)) + copy_state_dict(sharded_model.state_dict(), copy.deepcopy(initial_state_dict)) + + feature_keys = [] + for table in tables: + feature_keys.extend(table.feature_names) + + for _it in range(5): + unsharded_model_params = dict(unsharded_model.named_parameters()) + + if not use_apply_optimizer_in_backward: + unsharded_model_optimizer.zero_grad() + sharded_model_optimizer.zero_grad() + + if is_data_parallel: + for fqn, param in sharded_model.named_parameters(): + assert _optional_equals( + param.grad, unsharded_model_params[fqn].grad + ) + + unsharded_model_pred_kt = [] + for unsharded_rank in range(ctx.world_size): + # simulate the unsharded model run on the entire batch + unsharded_model_pred_kt.append( + unsharded_model(kjt_input_per_rank[unsharded_rank]) + ) + + all_unsharded_preds = [] + for unsharded_rank in range(ctx.world_size): + unsharded_model_pred_kt_mini_batch = unsharded_model_pred_kt[ + unsharded_rank + ].to_dict() + + all_unsharded_preds.extend( + [ + unsharded_model_pred_kt_mini_batch[feature] + for feature in feature_keys + ] + ) + if unsharded_rank == ctx.rank: + unsharded_model_pred = torch.stack( + [ + unsharded_model_pred_kt_mini_batch[feature] + for feature in feature_keys + ] + ) + # sharded model + # each rank gets a subbatch + sharded_model_pred_kt = sharded_model( + kjt_input_per_rank[ctx.rank] + ).to_dict() + sharded_model_pred = torch.stack( + [sharded_model_pred_kt[feature] for feature in feature_keys] + ) + + # cast to CPU because when casting unsharded_model.to on the same module, there could some race conditions + # in normal author modelling code this won't be an issue because each rank would individually create + # their model. output from sharded_pred is correctly on the correct device. 
+ # Compare predictions of sharded vs unsharded models. + torch.testing.assert_close( + sharded_model_pred.cpu(), unsharded_model_pred.cpu() + ) + + sharded_model_pred.sum().backward() + + all_unsharded_preds = torch.stack(all_unsharded_preds) + _sum = all_unsharded_preds.sum() + if is_data_parallel: + _sum /= world_size + _sum.backward() + + if is_data_parallel: + for fqn, param in sharded_model.named_parameters(): + assert _optional_equals( + param.grad, unsharded_model_params[fqn].grad + ) + + if not use_apply_optimizer_in_backward: + unsharded_model_optimizer.step() + sharded_model_optimizer.step() + + # check nn.Module APIs look the same + assert_state_buffers_parameters_equal(unsharded_model, sharded_model) + + for fqn in unsharded_model.state_dict(): + unsharded_state = unsharded_model.state_dict()[fqn] + sharded_state = sharded_model.state_dict()[fqn] + if is_data_parallel: + torch.testing.assert_close(unsharded_state, sharded_state) + else: + out = ( + torch.zeros(size=unsharded_state.shape, device=ctx.device) + if ctx.rank == 0 + else None + ) + if isinstance(sharded_state, DTensor): + out = sharded_state.full_tensor() + else: + sharded_state.gather(out=out) + + if ctx.rank == 0: + torch.testing.assert_close( + unsharded_state, + out, + ) + + +class TestEmbeddingBagCollectionSharder(EmbeddingBagCollectionSharder): + def __init__( + self, + sharding_type: str, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + self._sharding_type = sharding_type + + """ + Restricts sharding to single type only. + """ + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [self._sharding_type] + + +@skip_if_asan_class +class ShardedEmbeddingBagCollectionParallelTest(MultiProcessTestBase): + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.DATA_PARALLEL.value, + ] + ), + use_apply_optimizer_in_backward=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_sharding_ebc( + self, + sharding_type: str, + use_apply_optimizer_in_backward: bool, + ) -> None: + + # TODO DistributedDataParallel needs full support of registering fused optims before we can enable this. 
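        # Explanatory note (comment only): _apply_optimizer_in_backward fuses the
        # optimizer step for the given parameters into backward() and records the
        # handle on each parameter as `param._in_backward_optimizers`, so no
        # explicit step() is issued for those tables. hypothesis.assume() below
        # discards (rather than fails) the generated example when DATA_PARALLEL is
        # combined with the in-backward optimizer, per the TODO above.
        # Minimal standalone sketch of the contract, using plain torch modules:
        #   lin = torch.nn.Linear(4, 4, bias=False)
        #   apply_optimizer_in_backward(torch.optim.SGD, lin.parameters(), {"lr": 1.0})
        #   lin(torch.randn(2, 4)).sum().backward()  # SGD step already applied here
        #   assert hasattr(lin.weight, "_in_backward_optimizers")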
+ assume( + not ( + use_apply_optimizer_in_backward + and sharding_type == ShardingType.DATA_PARALLEL.value + ), + ) + + WORLD_SIZE = 2 + + embedding_bag_config = [ + EmbeddingBagConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=4, + num_embeddings=4, + init_fn=partial(torch.nn.init.normal_, mean=0.0, std=1.5), + ), + EmbeddingBagConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=4, + num_embeddings=4, + init_fn=partial(torch.nn.init.uniform_, a=-0.036, b=0.036), + ), + ] + + constraints = { + "table_0": ParameterConstraints( + cache_params=CacheParams( + algorithm=CacheAlgorithm.LRU, + load_factor=0.1, + reserved_memory=8.0, + precision=DataType.FP16, + ), + enforce_hbm=True, + stochastic_rounding=False, + bounds_check_mode=BoundsCheckMode.IGNORE, + ), + "table_1": ParameterConstraints( + cache_params=CacheParams( + algorithm=CacheAlgorithm.LFU, + load_factor=0.2, + reserved_memory=0.0, + precision=DataType.FP16, + ), + enforce_hbm=False, + stochastic_rounding=True, + bounds_check_mode=BoundsCheckMode.NONE, + ), + } + + # Rank 0 + # instance 0 instance 1 instance 2 + # "feature_0" [0, 1] None [2] + # "feature_1" [0, 1] None [2] + + # Rank 1 + + # instance 0 instance 1 instance 2 + # "feature_0" [3, 2] [1,2] [0,1,2,3] + # "feature_1" [2, 3] None [2] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor([0, 1, 2, 0, 1, 2]), + lengths=torch.LongTensor([2, 0, 1, 2, 0, 1]), + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor([3, 2, 1, 2, 0, 1, 2, 3, 2, 3, 2]), + lengths=torch.LongTensor([2, 2, 4, 2, 0, 1]), + ), + ] + self._run_multi_process_test( + callable=_test_sharding, + world_size=WORLD_SIZE, + tables=embedding_bag_config, + initial_state_dict={ + "embedding_bags.table_0.weight": torch.Tensor( + [ + [1, 1, 1, 1], + [2, 2, 2, 2], + [4, 4, 4, 4], + [8, 8, 8, 8], + ] + ), + "embedding_bags.table_1.weight": torch.Tensor( + [ + [101, 101, 101, 101], + [102, 102, 102, 102], + [104, 104, 104, 104], + [108, 108, 108, 108], + ] + ), + }, + kjt_input_per_rank=kjt_input_per_rank, + sharder=TestEmbeddingBagCollectionSharder(sharding_type=sharding_type), + backend=( + "nccl" + if (torch.cuda.is_available() and torch.cuda.device_count() >= 2) + else "gloo" + ), + constraints=constraints, + is_data_parallel=(sharding_type == ShardingType.DATA_PARALLEL.value), + use_apply_optimizer_in_backward=use_apply_optimizer_in_backward, + ) diff --git a/torchrec/distributed/composable/tests/test_fsdp.py b/torchrec/distributed/composable/tests/test_fsdp.py new file mode 100644 index 000000000..538b5382f --- /dev/null +++ b/torchrec/distributed/composable/tests/test_fsdp.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
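# Summary of the test below: a TestSparseNN model is assembled with row-wise
# sharded EmbeddingBagCollections (trec_shard + row_wise()), FSDP-wrapped dense
# and over-arch modules, and RowWiseAdagrad fused into backward for the sparse
# parameters; warmup schedulers wrap both the fused and the dense optimizers.
# After one training step, model and optimizer state are written out with
# torch.distributed.checkpoint (FileSystemWriter plus FSDP.optim_state_dict),
# the live tensors are zeroed in place, and the checkpoint is read back to
# verify that parameter and optimizer-state sums are restored.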
+ +# pyre-strict + +#!/usr/bin/env python3 + +import tempfile +import unittest + +import torch +from torch import nn + +# from torch.distributed._composable import fully_shard +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._tensor import DTensor + +from torch.distributed.checkpoint import ( + FileSystemReader, + FileSystemWriter, + load_state_dict, + save_state_dict, +) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.distributed.optim import ( + _apply_optimizer_in_backward as apply_optimizer_in_backward, +) +from torchrec.distributed.shard import shard as trec_shard +from torchrec.distributed.sharding_plan import row_wise +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.test_utils.test_model import ModelInput, TestSparseNN +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper +from torchrec.optim.optimizers import in_backward_optimizer_filter +from torchrec.optim.rowwise_adagrad import RowWiseAdagrad +from torchrec.optim.warmup import WarmupOptimizer, WarmupPolicy, WarmupStage +from torchrec.test_utils import skip_if_asan + + +class FullyShardTest(MultiProcessTestBase): + @classmethod + def _run( # noqa + cls, rank: int, world_size: int, param_path: str, opt_path: str + ) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(3) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(2) + ] + with MultiProcessContext(rank, world_size, "nccl") as ctx: + num_float_features = 32 + + m = TestSparseNN( + tables=tables, + num_float_features=num_float_features, + weighted_tables=weighted_tables, + dense_device=ctx.device, + ) + apply_optimizer_in_backward( + RowWiseAdagrad, + m.sparse.parameters(), + {"lr": 0.01}, + ) + # pyre-ignore + m.sparse.ebc = trec_shard( + module=m.sparse.ebc, + device=ctx.device, + plan=row_wise(), + ) + # pyre-ignore + m.sparse.weighted_ebc = trec_shard( + module=m.sparse.weighted_ebc, + device=ctx.device, + plan=row_wise(), + ) + m.dense = FSDP( # pyre-ignore + m.dense, + auto_wrap_policy=ModuleWrapPolicy({nn.Linear}), + device_id=ctx.device.index, + ) + m.over = FSDP( + m.over, + auto_wrap_policy=ModuleWrapPolicy({nn.Linear}), + device_id=ctx.device.index, + ) + + dense_opt = KeyedOptimizerWrapper( + dict(in_backward_optimizer_filter(m.named_parameters(), include=False)), + lambda params: torch.optim.Adam( + params, + lr=0.01, + betas=(0.9, 0.999), + eps=1e-5, + weight_decay=1e-05, + ), + ) + optims = [] + sparse_grad_parameter_names = set() + for name, p in in_backward_optimizer_filter( + m.named_parameters(), include=True + ): + # Add learning rate scheduler + warmup = WarmupOptimizer( + # pyre-ignore + p._in_backward_optimizers[0], + [ + WarmupStage( + policy=WarmupPolicy.LINEAR, + max_iters=1000, + value=0.1, + lr_scale=1.0, + ) + ], + lr=0.01, # initial learning rate + param_name="__sparse_warmup", + ) + optims.append((name, warmup)) + sparse_grad_parameter_names.add(name) + assert 
len(sparse_grad_parameter_names) == 5 + fused_opt_scheduled = CombinedOptimizer(optims) + dense_opt_scheduled = WarmupOptimizer( + dense_opt, + [ + WarmupStage( + policy=WarmupPolicy.LINEAR, + max_iters=1000, + value=0.15, + lr_scale=1.0, + ) + ], + lr=0.01, + param_name="__dense_warmup", + ) + opt: CombinedOptimizer = CombinedOptimizer( + [fused_opt_scheduled, (dense_opt_scheduled)] + ) + # Runs a dummy optimizer step, which allows to initialize + # optimizer state, which is typically lazy. + # This allows us to do in-place loading of optimizer state from a checkpoint. + # Remark that fused optimizer needs special case as its states are ShardedTensors. + # This is the reason we need to pass the sparse_grad_parameter_names as parameters. + opt.init_state(sparse_grad_parameter_names) + opt.save_param_groups(True) + model_param_names = set(dict(m.named_parameters()).keys()) + opt_param_keys = set(opt.params.keys()) + assert model_param_names.issubset(opt_param_keys) + + ######## run one iteration ######## + _, local_batch = ModelInput.generate( + batch_size=8, + world_size=world_size, + num_float_features=num_float_features, + tables=tables, + weighted_tables=weighted_tables, + ) + batch = local_batch[0].to(ctx.device) + m(batch)[1].sum().backward() + opt.step() + + state_dict = m.state_dict() + param_writer = FileSystemWriter(path=param_path) + param_reader = FileSystemReader(path=param_path) + save_state_dict(state_dict, param_writer) + + # use FSDP.optim_state_dict() API + opt_state_dict = FullyShardedDataParallel.optim_state_dict(m, opt) + opt_writer = FileSystemWriter(path=opt_path) + opt_reader = FileSystemReader(path=opt_path) + # use Distributed checkpointing API + save_state_dict(opt_state_dict, opt_writer) + + p_sum = torch.zeros(1, device=ctx.device) + for p in m.parameters(): + with torch.no_grad(): + if isinstance(p, ShardedTensor): + if not p.local_shards(): + continue + p = p.local_tensor() + if isinstance(p, DTensor): + if not p.to_local().local_shards(): + continue + p = p.to_local().local_shards()[0] + p_sum += p.sum() + p.zero_() + assert p.sum() == 0 + o_sum = torch.zeros(1, device=ctx.device) + for p_v in opt.state_dict()["state"].values(): + for name, t in p_v.items(): + if name == "step": + continue + if isinstance(t, ShardedTensor): + if not t.local_shards(): + continue + t = t.local_tensor() + if isinstance(t, DTensor): + if not t.to_local().local_shards(): # pyre-ignore[16] + continue + t = t.to_local().local_shards()[0] + o_sum += t.sum() + t.zero_() + assert t.sum() == 0 + + load_state_dict(state_dict, param_reader) + missing, unexpected = m.load_state_dict(state_dict) + assert len(missing) == 0 and len(unexpected) == 0 + + load_state_dict(opt_state_dict, opt_reader) + # use FSDP.optim_state_dict_to_load() API + new_opt_state_dict = FullyShardedDataParallel.optim_state_dict_to_load( + m, opt, opt_state_dict, is_named_optimizer=True + ) + opt.load_state_dict(new_opt_state_dict) + + p_sum_loaded = torch.zeros(1, device=ctx.device) + for p in m.parameters(): + with torch.no_grad(): + if isinstance(p, ShardedTensor): + if not p.local_shards(): + continue + p = p.local_tensor() + p_sum_loaded += p.sum() + if isinstance(p, DTensor): + if not p.to_local().local_shards(): + continue + p = p.to_local().local_shards()[0] + assert p_sum.allclose(p_sum_loaded) + + o_sum_loaded = torch.zeros(1, device=ctx.device) + for p_v in opt.state_dict()["state"].values(): + for name, t in p_v.items(): + if name == "step": + continue + if isinstance(t, ShardedTensor): + if not 
t.local_shards(): + continue + t = t.local_tensor() + if isinstance(t, DTensor): + if not t.to_local().local_shards(): + continue + t = t.to_local().local_shards()[0] + o_sum_loaded += t.sum() + assert o_sum.allclose(o_sum_loaded) + + @skip_if_asan + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_composable_checkpoint(self) -> None: + with tempfile.TemporaryDirectory() as param_path, tempfile.TemporaryDirectory() as opt_path: + self._run_multi_process_test( + callable=self._run, + param_path=param_path, + opt_path=opt_path, + ) diff --git a/torchrec/distributed/composable/tests/test_fused_optim.py b/torchrec/distributed/composable/tests/test_fused_optim.py new file mode 100644 index 000000000..bdd234d9c --- /dev/null +++ b/torchrec/distributed/composable/tests/test_fused_optim.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import os +import unittest + +import torch +from torch import distributed as dist +from torch.distributed.optim import ( + _apply_optimizer_in_backward as apply_optimizer_in_backward, +) +from torchrec.distributed.shard import shard +from torchrec.distributed.sharding_plan import ( + apply_to_all, + construct_module_sharding_plan, + table_wise, +) +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.optim.rowwise_adagrad import RowWiseAdagrad +from torchrec.optim.warmup import WarmupOptimizer, WarmupPolicy, WarmupStage +from torchrec.test_utils import get_free_port + + +class TestFusedOptim(unittest.TestCase): + def setUp(self) -> None: + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["LOCAL_WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + os.environ["NCCL_SOCKET_IFNAME"] = "lo" + if torch.cuda.is_available(): + self.curr_device = torch.device("cuda:0") + torch.cuda.set_device(self.curr_device) + backend = "nccl" + else: + self.curr_device = torch.device("cpu") + backend = "gloo" + dist.init_process_group(backend=backend) + + def tearDown(self) -> None: + dist.destroy_process_group() + + def test_opt_state_correct(self) -> None: + num_features = 4 + + tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta")) + apply_optimizer_in_backward( + RowWiseAdagrad, + ebc.parameters(), + {"lr": 0.01}, + ) + plan = construct_module_sharding_plan( + ebc, + apply_to_all(ebc, table_wise(rank=0)), + ) + ebc = shard( + module=ebc, + plan=plan, + device=self.curr_device, + ) + for name, param in ebc.named_parameters(): + table_name = name[len("embedding_bags.") : -len("weight") - 1] + self.assertEqual( + param._in_backward_optimizers[0] + .state_dict()["state"][""][f"{table_name}.momentum1"] + .local_tensor() + .data_ptr(), + ebc._optim.state_dict()["state"][f"embedding_bags.{table_name}.weight"][ + f"{table_name}.momentum1" + ] + .local_tensor() + .data_ptr(), + ) + + def test_set_learning_rate(self) -> None: + num_features = 1 + + tables = [ + 
EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta")) + apply_optimizer_in_backward( + RowWiseAdagrad, + ebc.parameters(), + {"lr": 0.01}, + ) + plan = construct_module_sharding_plan( + ebc, + apply_to_all(ebc, table_wise(rank=0)), + ) + ebc = shard( + module=ebc, + plan=plan, + device=self.curr_device, + ) + for param in ebc.parameters(): + param._in_backward_optimizers = [ + WarmupOptimizer( + param._in_backward_optimizers[0], + [ + WarmupStage( + policy=WarmupPolicy.LINEAR, + max_iters=10000, + value=0.5, + lr_scale=1.0, + ) + ], + param_name="__warmup_state", + ) + ] + param._in_backward_optimizers[0].step() + param._in_backward_optimizers[0].step() + warmup_state = param._in_backward_optimizers[0].state_dict()["state"][ + "__warmup_state" + ] + _iter, _ = warmup_state["warmup"] + self.assertEqual(_iter, 2) + self.assertEqual( + param._in_backward_optimizers[0].param_groups[0]["lr"], 0.05001 + ) diff --git a/torchrec/distributed/composable/tests/test_fused_optim_nccl.py b/torchrec/distributed/composable/tests/test_fused_optim_nccl.py new file mode 100644 index 000000000..52eb6ef03 --- /dev/null +++ b/torchrec/distributed/composable/tests/test_fused_optim_nccl.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import List + +import torch +from torchrec.distributed.shard import shard +from torchrec.distributed.sharding_plan import ( + column_wise, + construct_module_sharding_plan, + row_wise, +) +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward +from torchrec.optim.optimizers import PartialRowWiseAdam +from torchrec.optim.rowwise_adagrad import RowWiseAdagrad + + +class ShardedFusedOptimizerStateDictTest(MultiProcessTestBase): + @staticmethod + def _test_sharded_fused_optimizer_state_dict( + tables: List[EmbeddingBagConfig], + rank: int, + local_size: int, + world_size: int, + backend: str, + ) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta")) + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + ebc.embedding_bags["table_0"].weight, + ebc.embedding_bags["table_1"].weight, + ], + {"lr": 0.01}, + ) + apply_optimizer_in_backward( + PartialRowWiseAdam, + [ + ebc.embedding_bags["table_2"].weight, + ebc.embedding_bags["table_3"].weight, + ], + {"lr": 0.02}, + ) + + parameter_sharding_plan = construct_module_sharding_plan( + ebc, + per_param_sharding={ + "table_0": column_wise(ranks=[0, 0, 1, 1]), + "table_1": row_wise(), + "table_2": column_wise(ranks=[0, 1, 0, 1]), + "table_3": row_wise(), + }, + world_size=ctx.world_size, + local_size=ctx.local_size, + device_type=ctx.device.type, + ) + + ebc = shard( + module=ebc, + plan=parameter_sharding_plan, + device=ctx.device, + ) + + ebc.embedding_bags["table_0"].weight._in_backward_optimizers[ + 0 + 
].state_dict()["state"][""]["table_0.momentum1"].gather( + dst=0, + out=( + None + if ctx.rank != 0 + # sharded column, each shard will have rowwise state + else torch.empty((4 * tables[0].num_embeddings,), device=ctx.device) + ), + ) + + ebc.embedding_bags["table_1"].weight._in_backward_optimizers[ + 0 + ].state_dict()["state"][""]["table_1.momentum1"].gather( + dst=0, + out=( + None + if ctx.rank != 0 + # sharded rowwise + else torch.empty((tables[1].num_embeddings,), device=ctx.device) + ), + ) + + ebc.embedding_bags["table_2"].weight._in_backward_optimizers[ + 0 + ].state_dict()["state"][""]["table_2.momentum1"].gather( + dst=0, + out=( + None + if ctx.rank != 0 + # Column wise - with partial rowwise adam, first state is point wise + else torch.empty( + (tables[2].num_embeddings, tables[2].embedding_dim), + device=ctx.device, + ) + ), + ) + + ebc.embedding_bags["table_2"].weight._in_backward_optimizers[ + 0 + ].state_dict()["state"][""]["table_2.exp_avg_sq"].gather( + dst=0, + out=( + None + if ctx.rank != 0 + # Column wise - with partial rowwise adam, first state is point wise + else torch.empty((4 * tables[2].num_embeddings,), device=ctx.device) + ), + ) + + ebc.embedding_bags["table_3"].weight._in_backward_optimizers[ + 0 + ].state_dict()["state"][""]["table_3.momentum1"].gather( + dst=0, + out=( + None + if ctx.rank != 0 + # Row wise - with partial rowwise adam, first state is point wise + else torch.empty( + (tables[3].num_embeddings, tables[3].embedding_dim), + device=ctx.device, + ) + ), + ) + + ebc.embedding_bags["table_3"].weight._in_backward_optimizers[ + 0 + ].state_dict()["state"][""]["table_3.exp_avg_sq"].gather( + dst=0, + out=( + None + if ctx.rank != 0 + # Column wise - with partial rowwise adam, first state is point wise + else torch.empty((tables[2].num_embeddings,), device=ctx.device) + ), + ) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_sharded_fused_optimizer_state_dict(self) -> None: + WORLD_SIZE = 2 + LOCAL_SIZE = 2 + tables = [ + EmbeddingBagConfig( + num_embeddings=64, + embedding_dim=64, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + # test different optimizer table datatypes to ensure optimizer dtype is consistent + data_type=DataType.FP16 if i > 1 else DataType.FP32, + ) + for i in range(4) + ] + + self._run_multi_process_test( + callable=self._test_sharded_fused_optimizer_state_dict, + tables=tables, + backend="nccl", + local_size=LOCAL_SIZE, + world_size=WORLD_SIZE, + ) diff --git a/torchrec/distributed/composable/tests/test_table_batched_embedding_slice.py b/torchrec/distributed/composable/tests/test_table_batched_embedding_slice.py new file mode 100644 index 000000000..b33994c60 --- /dev/null +++ b/torchrec/distributed/composable/tests/test_table_batched_embedding_slice.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import copy +import unittest + +import torch + +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + DenseTableBatchedEmbeddingBagsCodegen, +) +from torchrec.distributed.composable.table_batched_embedding_slice import ( + TableBatchedEmbeddingSlice, +) + + +class TestTableBatchedEmbeddingSlice(unittest.TestCase): + def test_is_view(self) -> None: + device = "cpu" if not torch.cuda.is_available() else "cuda" + emb = DenseTableBatchedEmbeddingBagsCodegen( + [(2, 4), (2, 4)], use_cpu=device == "cpu" + ) + first_table = TableBatchedEmbeddingSlice(emb.weights, 0, 8, 2, 4) + self.assertEqual(first_table.data_ptr(), emb.weights.data_ptr()) + + def test_copy(self) -> None: + device = "cpu" if not torch.cuda.is_available() else "cuda" + emb = DenseTableBatchedEmbeddingBagsCodegen( + [(2, 4), (2, 4)], use_cpu=device == "cpu" + ) + first_table = TableBatchedEmbeddingSlice(emb.weights, 0, 8, 2, 4) + copied = copy.deepcopy(first_table) + self.assertNotEqual(first_table.data_ptr(), copied.data_ptr()) diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py index d9faadcf9..4c66511ef 100644 --- a/torchrec/distributed/dist_data.py +++ b/torchrec/distributed/dist_data.py @@ -5,9 +5,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import itertools import logging -from typing import Callable, List, Optional +from typing import Callable, Dict, List, Optional import torch import torch.distributed as dist @@ -18,10 +20,16 @@ alltoall_pooled, alltoall_sequence, reduce_scatter_base_pooled, + reduce_scatter_v_per_feature_pooled, reduce_scatter_v_pooled, + variable_batch_alltoall_pooled, ) -from torchrec.distributed.types import Awaitable, NoWait, QuantizedCommCodecs -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.distributed.embedding_types import KJTList +from torchrec.distributed.global_settings import get_propogate_device +from torchrec.distributed.types import Awaitable, QuantizedCommCodecs, rank_device +from torchrec.fx.utils import fx_marker +from torchrec.pt2.checks import is_torchdynamo_compiling +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @@ -35,10 +43,11 @@ # OSS try: - import fbgemm_gpu # @manual # noqa + pass except ImportError: pass + logger: logging.Logger = logging.getLogger() @@ -48,20 +57,21 @@ def _get_recat( stagger: int = 1, device: Optional[torch.device] = None, batch_size_per_rank: Optional[List[int]] = None, -) -> torch.Tensor: +) -> Optional[torch.Tensor]: """ Calculates relevant recat indices required to reorder AlltoAll collective. Args: - local_split (int): how many features in local split. - num_splits (int): how many splits (typically WORLD_SIZE). + local_split (int): number of features in local split. + num_splits (int): number of splits (typically WORLD_SIZE). stagger (int): secondary reordering, (typically 1, but `WORLD_SIZE/LOCAL_WORLD_SIZE` for TWRW). - device (torch.device): device on which buffer will be allocated. - batch_size_per_rank: batch size per rank, needed for variable batch size. + device (Optional[torch.device]): device on which buffer will be allocated. + batch_size_per_rank (Optional[List[int]]): batch size per rank, needed for + variable batch size. Returns: - torch.Tensor: recat tensor. + Optional[torch.Tensor]: recat tensor, None if local rank is empty. 
Example:: @@ -69,12 +79,15 @@ def _get_recat( # [0, 2, 4, 6, 1, 3, 5, 7] _recat(2, 4, 2) # [0, 4, 2, 6, 1, 5, 3, 7] + _recat(0, 4, 2) + # None """ - with record_function("## all2all_data:recat_permute_gen ##"): - recat: List[int] = [] + with record_function("## all2all_data:recat_permute_gen ##"): if local_split == 0: - return torch.tensor(recat, device=device, dtype=torch.int32) + return None + + recat: List[int] = [] feature_order: List[int] = [ x + num_splits // stagger * y @@ -86,11 +99,19 @@ def _get_recat( for j in feature_order: # range(num_splits): recat.append(i + j * local_split) - # variable batch size - if batch_size_per_rank is not None: + vb_per_rank_condition: bool = False + if not is_torchdynamo_compiling(): + vb_per_rank_condition = batch_size_per_rank is not None and any( + bs != batch_size_per_rank[0] for bs in batch_size_per_rank + ) + + # variable batch per rank + if vb_per_rank_condition: batch_size_per_feature = list( itertools.chain.from_iterable( - itertools.repeat(x, local_split) for x in batch_size_per_rank + itertools.repeat(x, local_split) + # pyre-ignore + for x in batch_size_per_rank ) ) permuted_batch_size_per_feature = [batch_size_per_feature[r] for r in recat] @@ -124,109 +145,267 @@ def _get_recat( return torch.tensor(recat, device=device, dtype=torch.int32) -def _split_lengths( - splits: List[int], keys: List[str], offset_per_key: List[int] -) -> List[int]: - # Calculates lengths [x1, x2, x3, ..., y1, y2], splits [3, ..., 2] - # -> [x1+x2+x3, ..., y1+y2] - length_per_split: List[int] = [] - i = 0 - offset = 0 - for split in splits: - new_offset = offset_per_key[i + split] - length_per_split.append(new_offset - offset) - i += split - offset = new_offset - return length_per_split +class _MergePooledEmbeddingsModuleImpl(torch.nn.Module): + """ + Does the merge_pooled_embeddings operation. Separate module necessary for lowering. + + Args: + device (torch.device): device for fbgemm.merge_pooled_embeddings + """ + + current_device: torch.device + + def __init__( + self, + device: torch.device, + ) -> None: + super().__init__() + self.current_device = device + # the method will be used by inference application to update the + # device information + @torch.jit.export + def set_device(self, device_str: str) -> None: + self.current_device = torch.device(device_str) -class KJTAllToAllIndicesAwaitable(Awaitable[KeyedJaggedTensor]): + def forward(self, tensors: List[torch.Tensor], cat_dim: int) -> torch.Tensor: + """ + Here we assume input tensors are: + [TBE_output_0, ..., TBE_output_(n-1)] + """ + B = tensors[0].size(1 - cat_dim) + return torch.ops.fbgemm.merge_pooled_embeddings( + tensors, + B, + self.current_device, + cat_dim, + ) + + +class MergePooledEmbeddingsModule(torch.nn.Module): """ - Awaitable for KJT indices and weights All2All. + This module is used for merge_pooled_embedding_optimization. + _MergePooledEmbeddingsModuleImpl provides the `set_device` API + to set device at model loading time. Args: - pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. - input (KeyedJaggedTensor): input KJT tensor. - lengths (torch.Tensor): output lengths tensor. - splits (List[int]): List of len(pg.size()) which indicates how many features to + device (torch.device): device for fbgemm.merge_pooled_embeddings + """ + + def __init__(self, device: torch.device) -> None: + super().__init__() + self.impl = _MergePooledEmbeddingsModuleImpl(device) + + # This method can be used by an inference runtime to update the + # device information for this module. 
+ @torch.jit.export + def set_device(self, device_str: str) -> None: + self.impl.set_device(device_str) + + def forward(self, tensors: List[torch.Tensor], cat_dim: int) -> torch.Tensor: + """ + Calls _MergePooledEmbeddingsModuleImpl with tensors and cat_dim. + + Args: + tensors (List[torch.Tensor]): list of embedding tensors. + cat_dim (int): which dimension you would like to concatenate on. + + Returns: + torch.Tensor: merged embeddings. + """ + merged_embeddings = self.impl(tensors, cat_dim) + return merged_embeddings + + +class SplitsAllToAllAwaitable(Awaitable[List[List[int]]]): + """ + Awaitable for splits AlltoAll. + + Args: + input_tensors (List[torch.Tensor]): tensor of splits to redistribute. + pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. + """ + + def __init__( + self, + input_tensors: List[torch.Tensor], + pg: dist.ProcessGroup, + ) -> None: + super().__init__() + self.num_workers: int = pg.size() + + if is_torchdynamo_compiling(): + # TODO(ivankobzarev) Remove this dynamo condition once dynamo functional collectives remapping does not emit copy_ + # https://github.com/pytorch/pytorch/issues/122788 + with record_function("## all2all_data:kjt splits ##"): + input_tensor = torch.stack(input_tensors, dim=1).flatten() + if pg._get_backend_name() == "custom": + self._output_tensor = torch.empty( + [self.num_workers * len(input_tensors)], + device=input_tensors[0].device, + dtype=input_tensors[0].dtype, + ) + + self._output_tensor = input_tensor[ + : input_tensor.size(0) // 2 + ].repeat(2) + else: + self._output_tensor = ( + dist._functional_collectives.all_to_all_single( + input_tensor, + output_split_sizes=None, + input_split_sizes=None, + group=pg, + ) + ) + # To avoid hasattr in _wait_impl to check self._splits_awaitable + # pyre-ignore + self._splits_awaitable = None + else: + with record_function("## all2all_data:kjt splits ##"): + self._output_tensor: torch.Tensor = torch.empty( + [self.num_workers * len(input_tensors)], + device=input_tensors[0].device, + dtype=input_tensors[0].dtype, + ) + input_tensor = torch.stack(input_tensors, dim=1).flatten() + self._splits_awaitable: dist.Work = dist.all_to_all_single( + output=self._output_tensor, + input=input_tensor, + group=pg, + async_op=not is_torchdynamo_compiling(), + ) + + def _wait_impl(self) -> List[List[int]]: + # Can not use is_torchdynamo_compiling(), as every such condition should be independent for compilation with graph breaks. + if isinstance(self._splits_awaitable, dist.Work): + self._splits_awaitable.wait() + + ret = self._output_tensor.view(self.num_workers, -1).T.tolist() + + if not torch.jit.is_scripting() and is_torchdynamo_compiling(): + for i in range(len(ret)): + for j in range(len(ret[i])): + torch._check_is_size(ret[i][j]) + + return ret + + +class KJTAllToAllTensorsAwaitable(Awaitable[KeyedJaggedTensor]): + """ + Awaitable for KJT tensors AlltoAll. + + Args: + pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. + input (KeyedJaggedTensor): input KJT. + splits (List[int]): list of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the `KeyedJaggedTensor` is ordered by destination rank. Same for all ranks. + input_splits (List[List[int]]): input splits (number of values each rank will + get) for each tensor in AlltoAll. + output_splits (List[List[int]]): output splits (number of values per rank in + output) for each tensor in AlltoAll. + input_tensors (List[torch.Tensor]): provided KJT tensors (ie. 
lengths, values) + to redistribute according to splits. + labels (List[str]): labels for each provided tensor. keys (List[str]): KJT keys after AlltoAll. - recat (torch.Tensor): recat tensor for reordering tensor order after AlltoAll. - in_lengths_per_worker (List[int]): number of indices each rank will get. - out_lengths_per_worker (List[int]): number of indices per rank in output. - batch_size_per_rank (List[int]): batch size per rank, need to support variable - batch size. + device (torch.device): device on which buffers will be allocated. + stagger (int): stagger value to apply to recat tensor. + stride_per_rank (Optional[List[int]]): stride per rank in the non variable + batch per feature case. """ def __init__( self, pg: dist.ProcessGroup, input: KeyedJaggedTensor, - lengths: torch.Tensor, splits: List[int], + input_splits: List[List[int]], + output_splits: List[List[int]], + input_tensors: List[torch.Tensor], + labels: List[str], keys: List[str], - recat: torch.Tensor, - in_lengths_per_worker: List[int], - out_lengths_per_worker: List[int], - batch_size_per_rank: List[int], + device: torch.device, + stagger: int, + stride_per_rank: Optional[List[int]], ) -> None: super().__init__() self._workers: int = pg.size() - self._device: torch.device = input.values().device - self._recat = recat - self._splits = splits self._pg: dist.ProcessGroup = pg - self._keys = keys - self._lengths: torch.Tensor = lengths - self._in_lengths_per_worker: List[int] = [] - self._out_lengths_per_worker: List[int] = [] + self._device: torch.device = device self._input = input - self._batch_size_per_rank = batch_size_per_rank + self._splits = splits + self._input_splits_list = input_splits + self._output_splits_list = output_splits + self._input_splits: Dict[str, List[int]] = dict(zip(labels, input_splits)) + self._output_splits: Dict[str, List[int]] = dict(zip(labels, output_splits)) + self._keys = keys + self._stagger = stagger + self._stride_per_rank = stride_per_rank + self._recat: Optional[torch.Tensor] = _get_recat( + local_split=splits[pg.rank()], + num_splits=len(splits), + stagger=stagger, + device=device, + batch_size_per_rank=self._stride_per_rank, + ) if self._workers == 1: return - self._in_lengths_per_worker = in_lengths_per_worker - self._out_lengths_per_worker = out_lengths_per_worker - - in_values = self._input.values().view(-1) - out_values = torch.empty( - sum(self._out_lengths_per_worker), - device=self._device, - dtype=in_values.dtype, - ) - with record_function("## all2all_data:indices ##"): - self._values_awaitable: dist.Work = dist.all_to_all_single( - output=out_values, - input=in_values, - output_split_sizes=self._out_lengths_per_worker, - input_split_sizes=self._in_lengths_per_worker, - group=self._pg, - async_op=True, - ) - - self._values: torch.Tensor = out_values - - self._weights_awaitable: Optional[dist.Work] = None - self._weights: Optional[torch.Tensor] = None - - if self._input.weights_or_none() is not None: - in_weights = self._input.weights().view(-1) - out_weights = torch.empty( - sum(self._out_lengths_per_worker), - device=self._device, - dtype=in_weights.dtype, - ) - with record_function("## all2all_data:weights ##"): - self._weights_awaitable: dist.Work = dist.all_to_all_single( - output=out_weights, - input=in_weights, - output_split_sizes=self._out_lengths_per_worker, - input_split_sizes=self._in_lengths_per_worker, - group=self._pg, - async_op=True, + self._output_tensors: List[torch.Tensor] = [] + self._awaitables: List[dist.Work] = [] + self._world_size: int = 
self._pg.size() + rank = dist.get_rank(self._pg) + + for input_split, output_split, input_tensor, label in zip( + input_splits, + output_splits, + input_tensors, + labels, + ): + if is_torchdynamo_compiling(): + # TODO(ivankobzarev) Remove this dynamo condition once dynamo functional collectives remapping does not emit copy_ + # https://github.com/pytorch/pytorch/issues/122788 + with record_function(f"## all2all_data:kjt {label} ##"): + if self._pg._get_backend_name() == "custom": + output_tensor = torch.empty( + sum(output_split), + device=self._device, + dtype=input_tensor.dtype, + ) + _l = sum(output_split[:rank]) + _r = _l + output_split[rank] + torch._check(_r < input_tensor.size(0)) + torch._check(_l < input_tensor.size(0)) + torch._check(_l <= _r) + torch._check(2 * (_r - _l) == output_tensor.size(0)) + output_tensor.copy_( + input_tensor[_l:_r].repeat(self._world_size) + ) + else: + output_tensor = dist._functional_collectives.all_to_all_single( + input_tensor, + output_split, + input_split, + pg, + ) + self._output_tensors.append(output_tensor) + else: + output_tensor = torch.empty( + sum(output_split), device=self._device, dtype=input_tensor.dtype ) - self._weights: torch.Tensor = out_weights + with record_function(f"## all2all_data:kjt {label} ##"): + awaitable = dist.all_to_all_single( + output=output_tensor, + input=input_tensor, + output_split_sizes=output_split, + input_split_sizes=input_split, + group=self._pg, + async_op=not is_torchdynamo_compiling(), + ) + + self._output_tensors.append(output_tensor) + self._awaitables.append(awaitable) def _wait_impl(self) -> KeyedJaggedTensor: """ @@ -240,64 +419,37 @@ def _wait_impl(self) -> KeyedJaggedTensor: self._input.sync() return self._input - self._values_awaitable.wait() - - if self._weights_awaitable: - self._weights_awaitable.wait() - - keys = self._keys - lengths = self._lengths - values = self._values - weights = self._weights - - with record_function("## all2all_data:recat_values ##"): - if self._recat.numel() > 0: - if self._recat.numel() == lengths.numel(): # variable batch size - lengths, values, weights = torch.ops.fbgemm.permute_1D_sparse_data( - self._recat, - lengths.view(-1), - values, - weights, - values.numel(), - ) - else: - lengths, values, weights = torch.ops.fbgemm.permute_2D_sparse_data( - self._recat, - lengths.view(self._workers * self._splits[self._pg.rank()], -1), - values, - weights, - values.numel(), - ) - lengths = lengths.view(-1) - - ret = KeyedJaggedTensor.from_lengths_sync( - keys=keys, - values=values, - weights=weights, - lengths=lengths, - stride=sum(self._batch_size_per_rank), + if not is_torchdynamo_compiling(): + for awaitable in self._awaitables: + awaitable.wait() + + return type(self._input).dist_init( + keys=self._keys, + tensors=self._output_tensors, + variable_stride_per_key=self._input.variable_stride_per_key(), + num_workers=self._workers, + recat=self._recat, + stride_per_rank=self._stride_per_rank, + stagger=self._stagger, ) - return ret -class KJTAllToAllLengthsAwaitable(Awaitable[KJTAllToAllIndicesAwaitable]): +class KJTAllToAllSplitsAwaitable(Awaitable[KJTAllToAllTensorsAwaitable]): """ - Awaitable for KJT's lengths AlltoAll. - - wait() waits on lengths AlltoAll, then instantiates `KJTAllToAllIndicesAwaitable` - awaitable where indices and weights AlltoAll will be issued. + Awaitable for KJT tensors splits AlltoAll. Args: pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. - input (KeyedJaggedTensor): input KJT tensor + input (KeyedJaggedTensor): input KJT. 
splits (List[int]): list of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the `KeyedJaggedTensor` is ordered by destination rank. Same for all ranks. - keys (List[str]): KJT keys after AlltoAll - stagger (int): stagger value to apply to recat tensor, see `_recat` function for - more detail. - recat (torch.Tensor): recat tensor for reordering tensor order after AlltoAll. - variable_batch_size (bool): whether variable batch size is enabled. + tensor_splits (Dict[str, List[int]]): tensor splits provided by input KJT. + input_tensors (List[torch.Tensor]): provided KJT tensors (ie. lengths, values) + to redistribute according to splits. + keys (List[str]): KJT keys after AlltoAll. + device (torch.device): device on which buffers will be allocated. + stagger (int): stagger value to apply to recat tensor. """ def __init__( @@ -305,127 +457,89 @@ def __init__( pg: dist.ProcessGroup, input: KeyedJaggedTensor, splits: List[int], + labels: List[str], + tensor_splits: List[List[int]], + input_tensors: List[torch.Tensor], keys: List[str], + device: torch.device, stagger: int, - recat: torch.Tensor, - variable_batch_size: bool = False, ) -> None: super().__init__() self._workers: int = pg.size() self._pg: dist.ProcessGroup = pg - self._device: torch.device = input.values().device + self._device: torch.device = device self._input = input - self._keys = keys - self._lengths: torch.Tensor = input.lengths() self._splits = splits - self._recat: torch.Tensor = recat - self._in_lengths_per_worker: List[int] = [] - self._variable_batch_size = variable_batch_size - dim_0 = splits[pg.rank()] - dim_1 = input.stride() - self._batch_size_per_rank: List[int] = [dim_1] * self._workers + self._labels = labels + self._input_splits = tensor_splits + self._input_tensors = input_tensors + self._keys = keys + self._stagger = stagger + self._output_splits: List[List[int]] = self._input_splits + self._stride_per_rank: Optional[List[int]] = ( + None + if self._input.variable_stride_per_key() + else [self._input.stride()] * self._workers + ) if self._workers == 1: return - if variable_batch_size: - batch_size_per_rank_tensor = torch.empty( - self._workers, - device=self._device, - dtype=torch.torch.int32, - ) - local_batch_sizes = torch.tensor( - [dim_1] * self._workers, - device=self._device, - dtype=torch.torch.int32, - ) - with record_function("## all2all_data: Batch size ##"): - dist.all_to_all_single( - output=batch_size_per_rank_tensor, - input=local_batch_sizes, - output_split_sizes=[1] * self._workers, - input_split_sizes=[1] * self._workers, - group=self._pg, - async_op=False, - ) - self._batch_size_per_rank = batch_size_per_rank_tensor.cpu().tolist() - self._recat = _get_recat( - local_split=dim_0, - num_splits=len(splits), - stagger=stagger, - device=self._device, - batch_size_per_rank=self._batch_size_per_rank, + input_tensors = [ + torch.tensor(splits, device=device) for splits in self._input_splits + ] + if not self._input.variable_stride_per_key(): + input_tensors.append( + torch.tensor([input.stride()] * self._workers, device=device) ) - else: - assert self._recat is not None - in_lengths = input.lengths().view(-1) - out_lengths = torch.empty( - dim_0 * sum(self._batch_size_per_rank), - device=self._device, - dtype=in_lengths.dtype, + self._splits_awaitable = SplitsAllToAllAwaitable( + input_tensors, + self._pg, ) - self._lengths = out_lengths - with record_function("## all2all_data:split length ##"): - self._in_lengths_per_worker = _split_lengths( - splits, 
input.keys(), input.offset_per_key() - ) - - self._output_split_sizes: List[int] = [ - dim_0 * B_rank for B_rank in self._batch_size_per_rank - ] - with record_function("## all2all_data:lengths ##"): - self._lengths_awaitable: dist.Work = dist.all_to_all_single( - output=out_lengths, - input=in_lengths, - output_split_sizes=self._output_split_sizes, - input_split_sizes=[split * dim_1 for split in self._splits], - group=self._pg, - async_op=True, - ) - def _wait_impl(self) -> KJTAllToAllIndicesAwaitable: + def _wait_impl(self) -> KJTAllToAllTensorsAwaitable: """ Overwrites wait function as we don't handle callbacks here. Returns: - KJTAllToAllIndicesAwaitable. + KJTAllToAllTensorsAwaitable. """ - kjt = self._input - out_lengths_per_worker: List[int] = [] + if self._workers > 1: - self._lengths_awaitable.wait() - if self._variable_batch_size: - with record_function("## all2all_data:split length for a2a ##"): - lengths_per_rank: List[torch.Tensor] = list( - self._lengths.split(self._output_split_sizes) - ) - out_lengths_per_worker = ( - torch.cat( - [ - length.sum(keepdim=True, dim=0) - for length in lengths_per_rank - ], - ) - .cpu() - .tolist() - ) + output_list = self._splits_awaitable.wait() + if self._input.variable_stride_per_key(): + self._output_splits = output_list else: - out_lengths_per_worker = ( - self._lengths.view(self._workers, -1).sum(dim=1).cpu().tolist() - ) + self._output_splits = output_list[:-1] + self._stride_per_rank = output_list[-1] + + if is_torchdynamo_compiling(): + rank: int = self._pg.rank() + for i in range(len(self._output_splits)): + for j in range(len(self._output_splits[i])): + torch._check_is_size(self._output_splits[i][j]) + torch._check( + self._output_splits[i][rank] == self._input_splits[i][rank] + ) + if self._stride_per_rank is not None: + # pyre-ignore + for i in range(len(self._stride_per_rank)): + # pyre-ignore + torch._check_is_size(self._stride_per_rank[i]) - ret = KJTAllToAllIndicesAwaitable( + return KJTAllToAllTensorsAwaitable( pg=self._pg, - input=kjt, - lengths=self._lengths, + input=self._input, splits=self._splits, + input_splits=self._input_splits, + output_splits=self._output_splits, + input_tensors=self._input_tensors, + labels=self._labels, keys=self._keys, - recat=self._recat, - in_lengths_per_worker=self._in_lengths_per_worker, - out_lengths_per_worker=out_lengths_per_worker, - batch_size_per_rank=self._batch_size_per_rank, + device=self._device, + stagger=self._stagger, + stride_per_rank=self._stride_per_rank, ) - return ret class KJTAllToAll(nn.Module): @@ -433,24 +547,26 @@ class KJTAllToAll(nn.Module): Redistributes `KeyedJaggedTensor` to a `ProcessGroup` according to splits. Implementation utilizes AlltoAll collective as part of torch.distributed. - Requires two collective calls, one to transmit final tensor lengths (to allocate - correct space), and one to transmit actual sparse values. + + The input provides the necessary tensors and input splits to distribute. + The first collective call in `KJTAllToAllSplitsAwaitable` will transmit output + splits (to allocate correct space for tensors) and batch size per rank. The + following collective calls in `KJTAllToAllTensorsAwaitable` will transmit the actual + tensors asynchronously. Args: pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. splits (List[int]): List of len(pg.size()) which indicates how many features to - send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by + send to each pg.rank(). 
It is assumed the `KeyedJaggedTensor` is ordered by destination rank. Same for all ranks. - device (Optional[torch.device]): device on which buffers will be allocated. - stagger (int): stagger value to apply to recat tensor, see `_recat` function for - more detail. - variable_batch_size (bool): whether variable batch size in each rank is enabled. + stagger (int): stagger value to apply to recat tensor, see `_get_recat` function + for more detail. Example:: keys=['A','B','C'] splits=[2,1] - kjtA2A = KJTAllToAll(pg, splits, device) + kjtA2A = KJTAllToAll(pg, splits) awaitable = kjtA2A(rank0_input) # where: @@ -486,41 +602,29 @@ def __init__( self, pg: dist.ProcessGroup, splits: List[int], - device: Optional[torch.device] = None, stagger: int = 1, - variable_batch_size: bool = False, ) -> None: super().__init__() - assert len(splits) == pg.size() + torch._check(len(splits) == pg.size()) self._pg: dist.ProcessGroup = pg self._splits = splits - self._no_dist: bool = all(s == 0 for s in splits) self._splits_cumsum: List[int] = [0] + list(itertools.accumulate(splits)) self._stagger = stagger - self._variable_batch_size = variable_batch_size - self.register_buffer( - "_recat", - _get_recat( - local_split=splits[pg.rank()], - num_splits=len(splits), - stagger=stagger, - device=device, - ), - ) def forward( self, input: KeyedJaggedTensor - ) -> Awaitable[KJTAllToAllIndicesAwaitable]: + ) -> Awaitable[KJTAllToAllTensorsAwaitable]: """ Sends input to relevant `ProcessGroup` ranks. - First wait will have lengths results and issue indices/weights AlltoAll. - Second wait will have indices/weights results. + + The first wait will get the output splits for the provided tensors and issue + tensors AlltoAll. The second wait will get the tensors. Args: - input (KeyedJaggedTensor): input KeyedJaggedTensor of values to distribute. + input (KeyedJaggedTensor): `KeyedJaggedTensor` of values to distribute. Returns: - Awaitable[KeyedJaggedTensor]: awaitable of a KeyedJaggedTensor. + Awaitable[KJTAllToAllTensorsAwaitable]: awaitable of a `KJTAllToAllTensorsAwaitable`. """ with torch.no_grad(): @@ -530,14 +634,16 @@ def forward( self._splits_cumsum[rank] : self._splits_cumsum[rank + 1] ] - return KJTAllToAllLengthsAwaitable( + return KJTAllToAllSplitsAwaitable( pg=self._pg, input=input, splits=self._splits, + labels=input.dist_labels(), + tensor_splits=input.dist_splits(self._splits), + input_tensors=input.dist_tensors(), keys=local_keys, + device=input.device(), stagger=self._stagger, - recat=self._recat, - variable_batch_size=self._variable_batch_size, ) @@ -552,19 +658,33 @@ class KJTOneToAll(nn.Module): splits (List[int]): lengths of features to split the `KeyJaggedTensor` features into before copying them. world_size (int): number of devices in the topology. + device (torch.device): the device on which the KJTs will be allocated. 
""" def __init__( self, splits: List[int], world_size: int, + device: Optional[torch.device] = None, ) -> None: super().__init__() self._splits = splits self._world_size = world_size + + if get_propogate_device(): + self._device_type: str = ( + "cpu" if device is None else device.type + ) # TODO: replace hardcoded cpu with DEFAULT_DEVICE_TYPE in torchrec.distributed.types when torch package issue resolved + else: + # BUG: device will default to cuda if cpu specified + self._device_type: str = ( + device.type + if device is not None and device.type in {"meta", "cuda", "mtia"} + else "cuda" + ) assert self._world_size == len(splits) - def forward(self, kjt: KeyedJaggedTensor) -> Awaitable[List[KeyedJaggedTensor]]: + def forward(self, kjt: KeyedJaggedTensor) -> KJTList: """ Splits features first and then sends the slices to the corresponding devices. @@ -574,13 +694,22 @@ def forward(self, kjt: KeyedJaggedTensor) -> Awaitable[List[KeyedJaggedTensor]]: Returns: Awaitable[List[KeyedJaggedTensor]]: awaitable of `KeyedJaggedTensor` splits. """ - + fx_marker("KJT_ONE_TO_ALL_FORWARD_BEGIN", kjt) kjts: List[KeyedJaggedTensor] = kjt.split(self._splits) dist_kjts = [ - split_kjt.to(torch.device("cuda", rank), non_blocking=True) - for rank, split_kjt in enumerate(kjts) + ( + kjts[rank] + if self._device_type == "meta" + else kjts[rank].to( + rank_device(self._device_type, rank), + non_blocking=True, + ) + ) + for rank in range(self._world_size) ] - return NoWait(dist_kjts) + ret = KJTList(dist_kjts) + fx_marker("KJT_ONE_TO_ALL_FORWARD_END", kjt) + return ret class PooledEmbeddingsAwaitable(Awaitable[torch.Tensor]): @@ -630,6 +759,7 @@ class PooledEmbeddingsAllToAll(nn.Module): device (Optional[torch.device]): device on which buffers will be allocated. callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]): callback functions. + codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. Example:: @@ -672,7 +802,9 @@ def __init__( ) def forward( - self, local_embs: torch.Tensor, batch_size_per_rank: Optional[List[int]] = None + self, + local_embs: torch.Tensor, + batch_size_per_rank: Optional[List[int]] = None, ) -> PooledEmbeddingsAwaitable: """ Performs AlltoAll pooled operation on pooled embeddings tensor. @@ -686,9 +818,7 @@ def forward( PooledEmbeddingsAwaitable: awaitable of pooled embeddings. """ - if local_embs.numel() == 0: - local_embs.view(local_embs.size(0) * self._pg.size(), 0) - if batch_size_per_rank is None: + if not batch_size_per_rank: B_global = local_embs.size(0) assert ( B_global % self._pg.size() == 0 @@ -717,6 +847,169 @@ def callbacks(self) -> List[Callable[[torch.Tensor], torch.Tensor]]: return self._callbacks +class VariableBatchPooledEmbeddingsAllToAll(nn.Module): + """ + Shards batches and collects keys of tensor with a `ProcessGroup` according to + `dim_sum_per_rank`. + + Implementation utilizes `variable_batch_alltoall_pooled` operation. + + Args: + pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. + emb_dim_per_rank_per_feature (List[List[int]]): embedding dimensions per rank + per feature. + device (Optional[torch.device]): device on which buffers will be allocated. + callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]): callback + functions. + codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. 
+ + Example:: + + kjt_split = [1, 2] + emb_dim_per_rank_per_feature = [[2], [3, 3]] + a2a = VariableBatchPooledEmbeddingsAllToAll( + pg, emb_dim_per_rank_per_feature, device + ) + + t0 = torch.rand(6) # 2 * (2 + 1) + t1 = torch.rand(24) # 3 * (1 + 3) + 3 * (2 + 2) + # r0_batch_size r1_batch_size + # f_0: 2 1 + ----------------------------------------- + # f_1: 1 2 + # f_2: 3 2 + r0_batch_size_per_rank_per_feature = [[2], [1]] + r1_batch_size_per_rank_per_feature = [[1, 3], [2, 2]] + r0_batch_size_per_feature_pre_a2a = [2, 1, 3] + r1_batch_size_per_feature_pre_a2a = [1, 2, 2] + + rank0_output = a2a( + t0, r0_batch_size_per_rank_per_feature, r0_batch_size_per_feature_pre_a2a + ).wait() + rank1_output = a2a( + t1, r1_batch_size_per_rank_per_feature, r1_batch_size_per_feature_pre_a2a + ).wait() + + # input splits: + # r0: [2*2, 1*2] + # r1: [1*3 + 3*3, 2*3 + 2*3] + + # output splits: + # r0: [2*2, 1*3 + 3*3] + # r1: [1*2, 2*3 + 2*3] + + print(rank0_output.size()) + # torch.Size([16]) + # 2*2 + 1*3 + 3*3 + print(rank1_output.size()) + # torch.Size([14]) + # 1*2 + 2*3 + 2*3 + """ + + def __init__( + self, + pg: dist.ProcessGroup, + emb_dim_per_rank_per_feature: List[List[int]], + device: Optional[torch.device] = None, + callbacks: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None, + codecs: Optional[QuantizedCommCodecs] = None, + ) -> None: + super().__init__() + torch._C._log_api_usage_once("torchrec.distributed.vbe") + self._pg = pg + self._emb_dim_per_rank_per_feature = emb_dim_per_rank_per_feature + self._callbacks: List[Callable[[torch.Tensor], torch.Tensor]] = [] + if callbacks is not None: + self._callbacks = callbacks + self._codecs = codecs + + def forward( + self, + local_embs: torch.Tensor, + batch_size_per_rank_per_feature: List[List[int]], + batch_size_per_feature_pre_a2a: List[int], + ) -> PooledEmbeddingsAwaitable: + """ + Performs AlltoAll pooled operation with variable batch size per feature on a + pooled embeddings tensor. + + Args: + local_embs (torch.Tensor): tensor of values to distribute. + batch_size_per_rank_per_feature (List[List[int]]): batch size per rank per + feature, post a2a. Used to get the input splits. + batch_size_per_feature_pre_a2a (List[int]): local batch size before + scattering, used to get the output splits. + Ordered by rank_0 feature, rank_1 feature, ... + + Returns: + PooledEmbeddingsAwaitable: awaitable of pooled embeddings. + """ + + tensor_awaitable = variable_batch_alltoall_pooled( + a2a_pooled_embs_tensor=local_embs, + batch_size_per_rank_per_feature=batch_size_per_rank_per_feature, + batch_size_per_feature_pre_a2a=batch_size_per_feature_pre_a2a, + emb_dim_per_rank_per_feature=self._emb_dim_per_rank_per_feature, + group=self._pg, + codecs=self._codecs, + ) + + pooled_embedding_awaitable = PooledEmbeddingsAwaitable( + tensor_awaitable=tensor_awaitable, + ) + pooled_embedding_awaitable.callbacks.extend(self._callbacks) + + return pooled_embedding_awaitable + + @property + def callbacks(self) -> List[Callable[[torch.Tensor], torch.Tensor]]: + return self._callbacks + + +class EmbeddingsAllToOneReduce(nn.Module): + """ + Merges the pooled embedding tensor on each device into single tensor. + + Args: + device (torch.device): device on which buffer will be allocated. + world_size (int): number of devices in the topology. 
+ """ + + def __init__( + self, + device: torch.device, + world_size: int, + ) -> None: + super().__init__() + self._device = device + self._world_size = world_size + + # This method can be used by an inference runtime to update the + # device information for this module. + @torch.jit.export + def set_device(self, device_str: str) -> None: + self._device = torch.device(device_str) + + def forward( + self, + tensors: List[torch.Tensor], + ) -> torch.Tensor: + """ + Performs AlltoOne operation with Reduce on pooled embeddings tensors. + + Args: + tensors (List[torch.Tensor]): list of embedding tensors. + + Returns: + Awaitable[torch.Tensor]: awaitable of the reduced embeddings. + """ + assert len(tensors) == self._world_size + return torch.ops.fbgemm.sum_reduce_to_one( + tensors, + self._device, + ) + + class EmbeddingsAllToOne(nn.Module): """ Merges the pooled/sequence embedding tensor on each device into single tensor. @@ -735,11 +1028,19 @@ def __init__( cat_dim: int, ) -> None: super().__init__() - self._device = device + self.merge_pooled_embeddings = MergePooledEmbeddingsModule(device) self._world_size = world_size self._cat_dim = cat_dim + self._device = device - def forward(self, tensors: List[torch.Tensor]) -> Awaitable[torch.Tensor]: + # This method can be used by an inference runtime to update the + # device information for this module. + @torch.jit.export + def set_device(self, device_str: str) -> None: + self._device = torch.device(device_str) + self.merge_pooled_embeddings.set_device(device_str) + + def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor: """ Performs AlltoOne operation on pooled/sequence embeddings tensors. @@ -749,17 +1050,19 @@ def forward(self, tensors: List[torch.Tensor]) -> Awaitable[torch.Tensor]: Returns: Awaitable[torch.Tensor]: awaitable of the merged embeddings. """ + assert len(tensors) <= self._world_size - assert len(tensors) == self._world_size - non_cat_size = tensors[0].size(1 - self._cat_dim) - return NoWait( - torch.ops.fbgemm.merge_pooled_embeddings( + is_target_device_cpu: bool = self._device.type == "cpu" + + if self._world_size == 1: + merge = tensors[0] + else: + merge = self.merge_pooled_embeddings( tensors, - non_cat_size, - self._device, self._cat_dim, ) - ) + + return merge if not is_target_device_cpu else merge.to(self._device) class SeqEmbeddingsAllToOne(nn.Module): @@ -782,7 +1085,13 @@ def __init__( self._device = device self._world_size = world_size - def forward(self, tensors: List[torch.Tensor]) -> Awaitable[List[torch.Tensor]]: + # This method can be used by an inference runtime to update the + # device information for this module. + @torch.jit.export + def set_device(self, device_str: str) -> None: + self._device = torch.device(device_str) + + def forward(self, tensors: List[torch.Tensor]) -> List[torch.Tensor]: """ Performs AlltoOne operation on pooled embeddings tensors. @@ -794,11 +1103,9 @@ def forward(self, tensors: List[torch.Tensor]) -> Awaitable[List[torch.Tensor]]: """ assert len(tensors) == self._world_size - return NoWait( - torch.ops.fbgemm.all_to_one_device( - tensors, - self._device, - ) + return torch.ops.fbgemm.all_to_one_device( + tensors, + self._device, ) @@ -808,18 +1115,18 @@ class PooledEmbeddingsReduceScatter(nn.Module): embedding communication in row-wise and twrw sharding. For pooled embeddings, we have a local model-parallel output tensor with a layout of - [num_buckets x batch_size, dimension]. We need to sum over num_buckets dimension - across batches. 
We split tensor along the first dimension into unequal chunks (tensor - slices of different buckets) according to input_splits and reduce them into the output - tensor and scatter the results for corresponding ranks. + `[num_buckets x batch_size, dimension]`. We need to sum over `num_buckets` dimension + across batches. We split the tensor along the first dimension into unequal chunks + (tensor slices of different buckets) according to `input_splits` and reduce them + into the output tensor and scatter the results for corresponding ranks. The class returns the async `Awaitable` handle for pooled embeddings tensor. - The reduce-scatter-v is only available for NCCL backend. + The `reduce-scatter-v` operation is only available for NCCL backend. Args: - pg (dist.ProcessGroup): The process group that the reduce-scatter communication + pg (dist.ProcessGroup): the process group that the reduce-scatter communication happens within. - codecs (Optional[QuantizedCommCodecs]): Quantization codec + codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. Example:: @@ -848,38 +1155,123 @@ def forward( Performs reduce scatter operation on pooled embeddings tensor. Args: - local_embs (torch.Tensor): tensor of shape [num_buckets x batch_size, dimension]. - input_splits (Optional[List[int]]): list of splits for local_embs dim0. + local_embs (torch.Tensor): tensor of shape + `[num_buckets * batch_size, dimension]`. + input_splits (Optional[List[int]]): list of splits for `local_embs` dim 0. Returns: PooledEmbeddingsAwaitable: awaitable of pooled embeddings of tensor of shape [batch_size, dimension]. """ - if input_splits and len(set(input_splits)) > 1: - tensor_awaitable = reduce_scatter_v_pooled( - local_embs, input_splits, self._pg, codecs=self._codecs - ) + # Dynamo can not trace through data dependent condition: len(set(input_splits)) > 1 + if is_torchdynamo_compiling(): + if input_splits is not None: + tensor_awaitable = reduce_scatter_v_pooled( + local_embs, input_splits, self._pg, codecs=self._codecs + ) + else: + tensor_awaitable = reduce_scatter_base_pooled( + local_embs, self._pg, codecs=self._codecs + ) else: - tensor_awaitable = reduce_scatter_base_pooled( - local_embs, self._pg, codecs=self._codecs - ) + if input_splits and len(set(input_splits)) > 1: + tensor_awaitable = reduce_scatter_v_pooled( + local_embs, input_splits, self._pg, codecs=self._codecs + ) + else: + tensor_awaitable = reduce_scatter_base_pooled( + local_embs, self._pg, codecs=self._codecs + ) + return PooledEmbeddingsAwaitable(tensor_awaitable=tensor_awaitable) + + +class VariableBatchPooledEmbeddingsReduceScatter(nn.Module): + """ + The module class that wraps reduce-scatter communication primitives for pooled + embedding communication of variable batch in rw and twrw sharding. + + For variable batch per feature pooled embeddings, we have a local model-parallel + output tensor with a 1d layout of the total sum of batch sizes per rank per feature + multiplied by corresponding embedding dim `[batch_size_r0_f0 * emb_dim_f0 + ...)]`. + We split the tensor into unequal chunks by rank according to + `batch_size_per_rank_per_feature` and corresponding `embedding_dims` and reduce them + into the output tensor and scatter the results for corresponding ranks. + + The class returns the async `Awaitable` handle for pooled embeddings tensor. + The `reduce-scatter-v` operation is only available for NCCL backend. + + Args: + pg (dist.ProcessGroup): the process group that the reduce-scatter communication + happens within. 
+ codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. + + Example:: + + init_distributed(rank=rank, size=2, backend="nccl") + pg = dist.new_group(backend="nccl") + input = torch.randn(52) + batch_size_per_rank_per_feature = [[1, 3], [2, 2]] + embedding_dims = [4, 8] + m = VariableBatchPooledEmbeddingsReduceScatter(pg) + output = m(input, batch_size_per_rank_per_feature) + tensor = output.wait() + """ + + def __init__( + self, + pg: dist.ProcessGroup, + codecs: Optional[QuantizedCommCodecs] = None, + ) -> None: + super().__init__() + self._pg = pg + self._codecs = codecs + + def forward( + self, + local_embs: torch.Tensor, + batch_size_per_rank_per_feature: List[List[int]], + embedding_dims: List[int], + ) -> PooledEmbeddingsAwaitable: + """ + Performs reduce scatter operation on pooled embeddings tensor. + + Args: + local_embs (torch.Tensor): tensor of shape + `[num_buckets * batch_size, dimension]`. + batch_size_per_rank_per_feature (List[List[int]]): batch size per rank per + feature used to determine input splits. + embedding_dims (List[int]): embedding dimensions per feature used to + determine input splits. + + Returns: + PooledEmbeddingsAwaitable: awaitable of pooled embeddings of tensor of shape [batch_size, dimension]. + """ + + tensor_awaitable = reduce_scatter_v_per_feature_pooled( + input=local_embs, + batch_size_per_rank_per_feature=batch_size_per_rank_per_feature, + embedding_dims=embedding_dims, + group=self._pg, + codecs=self._codecs, + ) return PooledEmbeddingsAwaitable(tensor_awaitable=tensor_awaitable) class PooledEmbeddingsAllGather(nn.Module): """ - The module class that wraps all-gather communication primitive for pooled - embedding communication + The module class that wraps the all-gather communication primitive for pooled + embedding communication. - We have a local input tensor with a layout of - [batch_size, dimension]. We need to gather input tensors from all ranks into a flatten output tensor. + Provided a local input tensor with a layout of `[batch_size, dimension]`, we want to + gather input tensors from all ranks into a flattened output tensor. The class returns the async `Awaitable` handle for pooled embeddings tensor. The all-gather is only available for NCCL backend. Args: - pg (dist.ProcessGroup): The process group that the all-gather communication + pg (dist.ProcessGroup): the process group that the all-gather communication happens within. + codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. Example:: @@ -905,7 +1297,8 @@ def forward(self, local_emb: torch.Tensor) -> PooledEmbeddingsAwaitable: Performs reduce scatter operation on pooled embeddings tensor. Args: - local_embs (torch.Tensor): tensor of shape [num_buckets x batch_size, dimension]. + local_emb (torch.Tensor): tensor of shape + `[num_buckets x batch_size, dimension]`. Returns: PooledEmbeddingsAwaitable: awaitable of pooled embeddings of tensor of shape [batch_size, dimension]. @@ -937,15 +1330,13 @@ def __init__( ) -> None: super().__init__() self._tensor_awaitable = tensor_awaitable - self._unbucketize_permute_tensor = unbucketize_permute_tensor - self._embedding_dim = embedding_dim - if self._unbucketize_permute_tensor is not None: + if unbucketize_permute_tensor is not None: self.callbacks.append( lambda ret: torch.index_select( - ret.view(-1, self._embedding_dim), + ret.view(-1, embedding_dim), 0, - self._unbucketize_permute_tensor, + unbucketize_permute_tensor, ) ) @@ -970,6 +1361,7 @@ class SequenceEmbeddingsAllToAll(nn.Module): happens within. 
features_per_rank (List[int]): list of number of features per rank. device (Optional[torch.device]): device on which buffers will be allocated. + codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. Example:: @@ -1037,10 +1429,11 @@ def forward( lengths (torch.Tensor): lengths of sparse features after AlltoAll. input_splits (List[int]): input splits of AlltoAll. output_splits (List[int]): output splits of AlltoAll. - unbucketize_permute_tensor (Optional[torch.Tensor]): stores the permute order - of the KJT bucketize (for row-wise sharding only). - batch_size_per_rank: (Optional[List[int]]): batch size per rank - sparse_features_recat (Optional[torch.Tensor]): recat tensor used for sparse feature input dist + unbucketize_permute_tensor (Optional[torch.Tensor]): stores the permute + order of the KJT bucketize (for row-wise sharding only). + batch_size_per_rank: (Optional[List[int]]): batch size per rank. + sparse_features_recat (Optional[torch.Tensor]): recat tensor used for sparse + feature input dist. Must be provided if using variable batch size. Returns: SequenceEmbeddingsAwaitable: awaitable of sequence embeddings. @@ -1049,15 +1442,6 @@ def forward( variable_batch_size = ( batch_size_per_rank is not None and len(set(batch_size_per_rank)) > 1 ) - if variable_batch_size: - if sparse_features_recat is None: - sparse_features_recat = _get_recat( - local_split=self._local_split, - num_splits=self._num_splits, - device=local_embs.device, - stagger=1, - batch_size_per_rank=batch_size_per_rank, - ) if sparse_features_recat is not None: forward_recat_tensor = torch.ops.fbgemm.invert_permute( @@ -1084,3 +1468,368 @@ def forward( unbucketize_permute_tensor=unbucketize_permute_tensor, embedding_dim=local_embs.shape[1], ) + + +class JaggedTensorAllToAll(Awaitable[JaggedTensor]): + """ + Redistributes `JaggedTensor` to a `ProcessGroup` along the batch dimension according + to the number of items to send and receive. The number of items to send + must be known ahead of time on each rank. This is currently used for sharded + KeyedJaggedTensorPool, after distributing the number of IDs to lookup or update on + each rank. + + Implementation utilizes AlltoAll collective as part of torch.distributed. + + Args: + jt (JaggedTensor): JaggedTensor to distribute. + num_items_to_send (int): Number of items to send. + num_items_to_receive (int): Number of items to receive from all other ranks. + This must be known ahead of time on each rank, usually via another AlltoAll. + pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. + """ + + def __init__( + self, + jt: JaggedTensor, + num_items_to_send: torch.Tensor, + num_items_to_receive: torch.Tensor, + pg: dist.ProcessGroup, + ) -> None: + super().__init__() + self._workers: int = pg.size() + + self._dist_lengths: torch.Tensor = torch.empty( + sum(num_items_to_receive), + device=jt.lengths().device, + dtype=jt.lengths().dtype, + ) + + dist.all_to_all_single( + self._dist_lengths, + jt.lengths(), + output_split_sizes=num_items_to_receive.tolist(), + input_split_sizes=num_items_to_send.tolist(), + group=pg, + async_op=False, + ) + + # below will calculate chunks sums e.g. 
+ # num_batches_to_receive = [2,2] + # lengths = [2,3,1,1] + # output_splits = [5,2] + dist_id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + num_items_to_receive + ) + value_output_splits = torch.ops.fbgemm.segment_sum_csr( + 1, + dist_id_offsets, + self._dist_lengths, + ).tolist() + + self._dist_values: torch.Tensor = torch.empty( + sum(value_output_splits), + dtype=jt.values().dtype, + device=jt.values().device, + ) + + # same as above, calculate chunk sums + id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(num_items_to_send) + value_input_splits = torch.ops.fbgemm.segment_sum_csr( + 1, + id_offsets, + jt.lengths(), + ).tolist() + + self._dist_values_req: dist.Work = dist.all_to_all_single( + self._dist_values, + jt.values(), + output_split_sizes=value_output_splits, + input_split_sizes=value_input_splits, + group=pg, + async_op=True, + ) + + self._dist_weights: Optional[torch.Tensor] = None + self._dist_weights_req: Optional[dist.Work] = None + if jt.weights_or_none() is not None: + self._dist_weights = torch.empty( + sum(value_output_splits), + dtype=jt.weights().dtype, + device=jt.weights().device, + ) + + self._dist_weights_req = dist.all_to_all_single( + self._dist_weights, + jt.weights(), + output_split_sizes=value_output_splits, + input_split_sizes=value_input_splits, + group=pg, + async_op=True, + ) + + def _wait_impl(self) -> JaggedTensor: + self._dist_values_req.wait() + if self._dist_weights_req is not None: + self._dist_weights_req.wait() + + return JaggedTensor( + values=self._dist_values, + lengths=self._dist_lengths, + weights=self._dist_weights, + ) + + +class TensorAllToAllValuesAwaitable(Awaitable[torch.Tensor]): + def __init__( + self, + pg: dist.ProcessGroup, + input: torch.Tensor, + input_splits: torch.Tensor, + output_splits: torch.Tensor, + device: torch.device, + ) -> None: + super().__init__() + self._workers: int = pg.size() + self._device: torch.device = device + self._input = input + + self._dist_values: torch.Tensor + if self._workers == 1: + self._dist_values = input + return + else: + if input.dim() > 1: + self._dist_values = torch.empty( + (sum(output_splits), input.shape[1]), + device=self._device, + dtype=input.dtype, + ) + else: + self._dist_values = torch.empty( + sum(output_splits), device=self._device, dtype=input.dtype + ) + + with record_function("## all2all_data:ids ##"): + self._values_awaitable: dist.Work = dist.all_to_all_single( + output=self._dist_values, + input=input, + output_split_sizes=output_splits.tolist(), + input_split_sizes=input_splits.tolist(), + group=pg, + async_op=True, + ) + + def _wait_impl(self) -> torch.Tensor: + if self._workers > 1: + self._values_awaitable.wait() + return self._dist_values + + +class TensorAllToAllSplitsAwaitable(Awaitable[TensorAllToAllValuesAwaitable]): + def __init__( + self, + pg: dist.ProcessGroup, + input: torch.Tensor, + splits: torch.Tensor, + device: torch.device, + ) -> None: + super().__init__() + self._workers: int = pg.size() + self._pg: dist.ProcessGroup = pg + self._device: torch.device = device + self._input = input + self._input_splits = splits + + self._output_splits: torch.Tensor + if self._workers == 1: + self._output_splits = splits + return + else: + self._output_splits = torch.empty( + [self._workers], + device=device, + dtype=torch.int, + ) + + with record_function("## all2all_data:ids splits ##"): + self._num_ids_awaitable: dist.Work = dist.all_to_all_single( + output=self._output_splits, + input=splits, + group=pg, + async_op=True, + ) + + def 
_wait_impl(self) -> TensorAllToAllValuesAwaitable: + if self._workers > 1: + self._num_ids_awaitable.wait() + + return TensorAllToAllValuesAwaitable( + pg=self._pg, + input=self._input, + input_splits=self._input_splits, + output_splits=self._output_splits, + device=self._device, + ) + + +class TensorValuesAllToAll(nn.Module): + """ + Redistributes torch.Tensor to a `ProcessGroup` according to input and output splits. + + Implementation utilizes AlltoAll collective as part of torch.distributed. + + Args: + pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. + + Example:: + tensor_vals_A2A = TensorValuesAllToAll(pg) + input_splits = torch.Tensor([1,2]) on rank0 and torch.Tensor([1,1]) on rank1 + output_splits = torch.Tensor([1,1]) on rank0 and torch.Tensor([2,1]) on rank1 + awaitable = tensor_vals_A2A(rank0_input, input_splits, output_splits) + + where: + rank0_input is 3 x 3 torch.Tensor holding + [ + [V1, V2, V3], + [V4, V5, V6], + [V7, V8, V9], + ] + + rank1_input is 2 x 3 torch.Tensor holding + [ + [V10, V11, V12], + [V13, V14, V15], + ] + + rank0_output = awaitable.wait() + + # where: + # rank0_output is torch.Tensor holding + [ + [V1, V2, V3], + [V10, V11, V12], + ] + + # rank1_output is torch.Tensor holding + [ + [V1, V2, V3], + [V4, V5, V6], + [V7, V8, V9], + ] + """ + + def __init__( + self, + pg: dist.ProcessGroup, + ) -> None: + super().__init__() + self._pg: dist.ProcessGroup = pg + + def forward( + self, + input: torch.Tensor, + input_splits: torch.Tensor, + output_splits: torch.Tensor, + ) -> TensorAllToAllValuesAwaitable: + """ + Sends tensor to relevant `ProcessGroup` ranks. + + Args: + input (torch.Tensor): `torch.Tensor` of values to distribute. + input_splits (torch.Tensor): tensor containing number of rows + to be sent to each rank. len(input_splits) must equal self._pg.size() + output_splits (torch.Tensor): tensor containing number of rows + to be received from each rank. len(output_splits) must equal self._pg.size() + + Returns: `TensorAllToAllValuesAwaitable` + """ + with torch.no_grad(): + return TensorAllToAllValuesAwaitable( + pg=self._pg, + input=input, + input_splits=input_splits, + output_splits=output_splits, + device=input.device, + ) + + +class TensorAllToAll(nn.Module): + """ + Redistributes a 1D tensor to a `ProcessGroup` according to splits. + + Implementation utilizes AlltoAll collective as part of torch.distributed. + + The first collective call in `TensorAllToAllSplitsAwaitable` will transmit + splits to allocate correct space for the tensor values. The following collective + calls in `TensorAllToAllValuesAwaitable` will transmit the actual + tensor values asynchronously. + + Args: + pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. 
+ + Example:: + tensor_A2A = TensorAllToAll(pg) + splits = torch.Tensor([1,1]) on rank0 and rank1 + awaitable = tensor_A2A(rank0_input, splits) + + where: + rank0_input is torch.Tensor holding + [ + [V1, V2, V3], + [V4, V5, V6], + ] + + rank1_input is torch.Tensor holding + [ + [V7, V8, V9], + [V10, V11, V12], + ] + + rank0_output = awaitable.wait().wait() + + # where: + rank0_input is torch.Tensor holding + [ + [V1, V2, V3], + [V7, V8, V9], + ] + + rank1_input is torch.Tensor holding + [ + [V4, V5, V6], + [V10, V11, V12], + ] + """ + + def __init__( + self, + pg: dist.ProcessGroup, + ) -> None: + super().__init__() + self._pg: dist.ProcessGroup = pg + + def forward( + self, + input: torch.Tensor, + splits: torch.Tensor, + ) -> TensorAllToAllSplitsAwaitable: + """ + Sends tensor to relevant `ProcessGroup` ranks. + + The first wait will get the splits for the provided tensors and issue + tensors AlltoAll. The second wait will get the tensors. + + Args: + input (torch.Tensor): `torch.Tensor` of values to distribute. + + Returns: + Awaitable[TensorAllToAllValuesAwaitable]: awaitable of a `TensorAllToAllValuesAwaitable`. + """ + with torch.no_grad(): + temp = TensorAllToAllSplitsAwaitable( + pg=self._pg, + input=input, + splits=splits, + device=input.device, + ) + return temp diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 72f293278..2f006758e 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -5,41 +5,44 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict import copy -import functools +import logging +import warnings from collections import defaultdict, deque, OrderedDict -from dataclasses import dataclass, field +from itertools import accumulate from typing import ( Any, cast, Dict, - Iterator, List, MutableMapping, Optional, - Set, Tuple, Type, - Union, + Union as TypeUnion, ) import torch -from torch import nn -from torch.nn.modules.module import _IncompatibleKeys +from tensordict import TensorDict +from torch import distributed as dist, nn +from torch.autograd.profiler import record_function +from torch.distributed._shard.sharding_spec import EnumerableShardingSpec +from torch.distributed._tensor import DTensor +from torch.nn.parallel import DistributedDataParallel +from torchrec.distributed.comm import get_local_size from torchrec.distributed.embedding_sharding import ( EmbeddingSharding, EmbeddingShardingInfo, - SparseFeaturesIndicesAwaitable, - SparseFeaturesListAwaitable, - SparseFeaturesListIndicesAwaitable, + KJTListSplitsAwaitable, ) from torchrec.distributed.embedding_types import ( BaseEmbeddingSharder, EmbeddingComputeKernel, + KJTList, + ShardedEmbeddingModule, ShardingType, - SparseFeatures, - SparseFeaturesList, ) from torchrec.distributed.sharding.cw_sequence_sharding import ( CwSequenceEmbeddingSharding, @@ -55,21 +58,27 @@ from torchrec.distributed.sharding.tw_sequence_sharding import ( TwSequenceEmbeddingSharding, ) +from torchrec.distributed.shards_wrapper import LocalShardsWrapper from torchrec.distributed.types import ( Awaitable, + EmbeddingEvent, + EmbeddingModuleShardingPlan, LazyAwaitable, Multistreamable, ParameterSharding, QuantizedCommCodecs, - ShardedModule, ShardedTensor, ShardingEnv, + ShardingEnv2D, ShardMetadata, ) from torchrec.distributed.utils import ( - append_prefix, - filter_state_dict, + add_params_from_parameter_sharding, + convert_to_fbgemm_types, + 
create_global_tensor_shape_stride_from_metadata, + maybe_annotate_embedding_event, merge_fused_params, + none_throws, optimizer_type_to_emb_opt_type, ) from torchrec.modules.embedding_configs import ( @@ -81,9 +90,11 @@ EmbeddingCollection, EmbeddingCollectionInterface, ) -from torchrec.optim.fused import FusedOptimizerModule +from torchrec.modules.utils import construct_jagged_tensors, SequenceVBEContext +from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer -from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor +from torchrec.sparse.tensor_dict import maybe_td_to_kjt try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @@ -92,60 +103,62 @@ pass -def create_embedding_sharding( - sharding_type: str, - sharding_infos: List[EmbeddingShardingInfo], - env: ShardingEnv, - device: Optional[torch.device] = None, - qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, - variable_batch_size: bool = False, -) -> EmbeddingSharding[ - SequenceShardingContext, SparseFeatures, torch.Tensor, torch.Tensor -]: - if sharding_type == ShardingType.TABLE_WISE.value: - return TwSequenceEmbeddingSharding( - sharding_infos=sharding_infos, - env=env, - device=device, - qcomm_codecs_registry=qcomm_codecs_registry, - variable_batch_size=variable_batch_size, - ) - elif sharding_type == ShardingType.ROW_WISE.value: - return RwSequenceEmbeddingSharding( - sharding_infos=sharding_infos, - env=env, - device=device, - qcomm_codecs_registry=qcomm_codecs_registry, - variable_batch_size=variable_batch_size, - ) - elif sharding_type == ShardingType.DATA_PARALLEL.value: - return DpSequenceEmbeddingSharding( - sharding_infos=sharding_infos, - env=env, - device=device, - ) - elif sharding_type == ShardingType.COLUMN_WISE.value: - return CwSequenceEmbeddingSharding( - sharding_infos=sharding_infos, - env=env, - device=device, - qcomm_codecs_registry=qcomm_codecs_registry, - variable_batch_size=variable_batch_size, - ) +logger: logging.Logger = logging.getLogger(__name__) + + +EC_INDEX_DEDUP: bool = False + + +def get_device_from_parameter_sharding( + ps: ParameterSharding, +) -> TypeUnion[str, Tuple[str, ...]]: + """ + Returns list of device type per shard if table is sharded across different device type + else reutrns single device type for the table parameter + """ + if not isinstance(ps.sharding_spec, EnumerableShardingSpec): + raise ValueError("Expected EnumerableShardingSpec as input to the function") + + device_type_list: Tuple[str, ...] 
= tuple( + # pyre-fixme[16]: `Optional` has no attribute `device` + [shard.placement.device().type for shard in ps.sharding_spec.shards] + ) + if len(set(device_type_list)) == 1: + return device_type_list[0] else: - raise ValueError(f"Sharding not supported {sharding_type}") + assert ( + ps.sharding_type == "row_wise" + ), "Only row_wise sharding supports sharding across multiple device types for a table" + return device_type_list + + +def set_ec_index_dedup(val: bool) -> None: + warnings.warn( + "Please set use_index_dedup in EmbeddingCollectionSharder during __init__ instead", + DeprecationWarning, + stacklevel=2, + ) + global EC_INDEX_DEDUP + EC_INDEX_DEDUP = val -def create_sharding_infos_by_sharding( +def get_ec_index_dedup() -> bool: + global EC_INDEX_DEDUP + return EC_INDEX_DEDUP + + +def create_sharding_infos_by_sharding_device_group( module: EmbeddingCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], fused_params: Optional[Dict[str, Any]], -) -> Dict[str, List[EmbeddingShardingInfo]]: +) -> Dict[Tuple[str, TypeUnion[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]: if fused_params is None: fused_params = {} - sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {} + sharding_type_device_group_to_sharding_infos: Dict[ + Tuple[str, TypeUnion[str, Tuple[str, ...]]], List[EmbeddingShardingInfo] + ] = {} # state_dict returns parameter.Tensor, which loses parameter level attributes parameter_by_name = dict(module.named_parameters()) # QuantEBC registers weights as buffers (since they are INT8), and so we need to grab it there @@ -170,18 +183,40 @@ def create_sharding_infos_by_sharding( assert param_name in parameter_by_name or param_name in state_dict param = parameter_by_name.get(param_name, state_dict[param_name]) - if parameter_sharding.sharding_type not in sharding_type_to_sharding_infos: - sharding_type_to_sharding_infos[parameter_sharding.sharding_type] = [] - - optimizer_params = getattr(param, "_optimizer_kwargs", {}) - optimizer_class = getattr(param, "_optimizer_class", None) + device_group: TypeUnion[str, Tuple[str, ...]] = ( + get_device_from_parameter_sharding(parameter_sharding) + ) + if ( + parameter_sharding.sharding_type, + device_group, + ) not in sharding_type_device_group_to_sharding_infos: + sharding_type_device_group_to_sharding_infos[ + (parameter_sharding.sharding_type, device_group) + ] = [] + + optimizer_params = getattr(param, "_optimizer_kwargs", [{}]) + optimizer_classes = getattr(param, "_optimizer_classes", [None]) + + assert ( + len(optimizer_classes) == 1 and len(optimizer_params) == 1 + ), f"Only support 1 optimizer, given {len(optimizer_classes)}" + + optimizer_class = optimizer_classes[0] + optimizer_params = optimizer_params[0] if optimizer_class: optimizer_params["optimizer"] = optimizer_type_to_emb_opt_type( optimizer_class ) - fused_params = merge_fused_params(fused_params, optimizer_params) - sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append( + per_table_fused_params = merge_fused_params(fused_params, optimizer_params) + per_table_fused_params = add_params_from_parameter_sharding( + per_table_fused_params, parameter_sharding + ) + per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params) + + sharding_type_device_group_to_sharding_infos[ + (parameter_sharding.sharding_type, device_group) + ].append( ( EmbeddingShardingInfo( embedding_config=EmbeddingTableConfig( @@ -199,64 +234,76 @@ def create_sharding_infos_by_sharding( ), 
param_sharding=parameter_sharding, param=param, - fused_params=fused_params, + fused_params=per_table_fused_params, ) ) ) - return sharding_type_to_sharding_infos - - -def _construct_jagged_tensors( - embeddings: torch.Tensor, - features: KeyedJaggedTensor, - embedding_names: List[str], - need_indices: bool = False, - features_to_permute_indices: Optional[Dict[str, List[int]]] = None, -) -> Dict[str, JaggedTensor]: - ret: Dict[str, JaggedTensor] = {} - stride = features.stride() - length_per_key = features.length_per_key() - values = features.values() - - lengths = features.lengths().view(-1, stride) - lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0) - embeddings_list = torch.split(embeddings, length_per_key, dim=0) - values_list = torch.split(values, length_per_key) if need_indices else None - - key_indices = defaultdict(list) - for i, key in enumerate(embedding_names): - key_indices[key].append(i) - for key, indices in key_indices.items(): - # combines outputs in correct order for CW sharding - indices = ( - _permute_indices(indices, features_to_permute_indices[key]) - if features_to_permute_indices and key in features_to_permute_indices - else indices - ) - ret[key] = JaggedTensor( - lengths=lengths_tuple[indices[0]], - values=embeddings_list[indices[0]] - if len(indices) == 1 - else torch.cat([embeddings_list[i] for i in indices], dim=1), - weights=values_list[indices[0]] if values_list else None, - ) - return ret + return sharding_type_device_group_to_sharding_infos + + +def pad_vbe_kjt_lengths(features: KeyedJaggedTensor) -> KeyedJaggedTensor: + max_stride = max(features.stride_per_key()) + new_lengths = torch.zeros( + max_stride * len(features.keys()), + device=features.device(), + dtype=features.lengths().dtype, + ) + cum_stride = 0 + for i, stride in enumerate(features.stride_per_key()): + new_lengths[i * max_stride : i * max_stride + stride] = features.lengths()[ + cum_stride : cum_stride + stride + ] + cum_stride += stride + + return KeyedJaggedTensor( + keys=features.keys(), + values=features.values(), + lengths=new_lengths, + stride=max_stride, + length_per_key=features.length_per_key(), + offset_per_key=features.offset_per_key(), + ) -def _permute_indices(indices: List[int], permute: List[int]) -> List[int]: - permuted_indices = [0] * len(indices) - for i, permuted_index in enumerate(permute): - permuted_indices[i] = indices[permuted_index] - return permuted_indices +def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor: + """ + TODO: remove and import from `jagged_tensor.py` once packaging issue is resolved + """ + return ( + tensor.pin_memory().to(device=device, non_blocking=True) + if device.type == "cuda" and tensor.device.type == "cpu" + else tensor.to(device=device, non_blocking=True) + ) -@dataclass class EmbeddingCollectionContext(Multistreamable): - sharding_contexts: List[SequenceShardingContext] = field(default_factory=list) + # Torch Dynamo does not support default_factory=list: + # https://github.com/pytorch/pytorch/issues/120108 + # TODO(ivankobzarev): Make this a dataclass once supported + + def __init__( + self, + sharding_contexts: Optional[List[SequenceShardingContext]] = None, + input_features: Optional[List[KeyedJaggedTensor]] = None, + reverse_indices: Optional[List[torch.Tensor]] = None, + seq_vbe_ctx: Optional[List[SequenceVBEContext]] = None, + ) -> None: + super().__init__() + self.sharding_contexts: List[SequenceShardingContext] = sharding_contexts or [] + self.input_features: List[KeyedJaggedTensor] = 
input_features or [] + self.reverse_indices: List[torch.Tensor] = reverse_indices or [] + self.seq_vbe_ctx: List[SequenceVBEContext] = seq_vbe_ctx or [] + self.table_name_to_unpruned_hash_sizes: Dict[str, int] = {} - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + def record_stream(self, stream: torch.Stream) -> None: for ctx in self.sharding_contexts: ctx.record_stream(stream) + for f in self.input_features: + f.record_stream(stream) + for r in self.reverse_indices: + r.record_stream(stream) + for s in self.seq_vbe_ctx: + s.record_stream(stream) class EmbeddingCollectionAwaitable(LazyAwaitable[Dict[str, JaggedTensor]]): @@ -265,8 +312,11 @@ def __init__( awaitables_per_sharding: List[Awaitable[torch.Tensor]], features_per_sharding: List[KeyedJaggedTensor], embedding_names_per_sharding: List[List[str]], + ctx: EmbeddingCollectionContext, need_indices: bool = False, features_to_permute_indices: Optional[Dict[str, List[int]]] = None, + module_fqn: Optional[str] = None, + sharding_types: Optional[List[str]] = None, ) -> None: super().__init__() self._awaitables_per_sharding = awaitables_per_sharding @@ -274,29 +324,58 @@ def __init__( self._need_indices = need_indices self._features_to_permute_indices = features_to_permute_indices self._embedding_names_per_sharding = embedding_names_per_sharding + self._ctx = ctx + self._module_fqn = module_fqn + self._sharding_types = sharding_types def _wait_impl(self) -> Dict[str, JaggedTensor]: jt_dict: Dict[str, JaggedTensor] = {} - for w, f, e in zip( - self._awaitables_per_sharding, - self._features_per_sharding, - self._embedding_names_per_sharding, + for i, (w, f, e) in enumerate( + zip( + self._awaitables_per_sharding, + self._features_per_sharding, + self._embedding_names_per_sharding, + ) ): + original_features = ( + None + if i >= len(self._ctx.input_features) + else self._ctx.input_features[i] + ) + reverse_indices = ( + None + if i >= len(self._ctx.reverse_indices) + else self._ctx.reverse_indices[i] + ) + seq_vbe_ctx = ( + None if i >= len(self._ctx.seq_vbe_ctx) else self._ctx.seq_vbe_ctx[i] + ) + + with maybe_annotate_embedding_event( + EmbeddingEvent.OUTPUT_DIST_WAIT, + self._module_fqn, + self._sharding_types[i] if self._sharding_types else None, + ): + embeddings = w.wait() + jt_dict.update( - _construct_jagged_tensors( - embeddings=w.wait(), + construct_jagged_tensors( + embeddings=embeddings, features=f, embedding_names=e, need_indices=self._need_indices, features_to_permute_indices=self._features_to_permute_indices, + original_features=original_features, + reverse_indices=reverse_indices, + seq_vbe_ctx=seq_vbe_ctx, ) ) return jt_dict class ShardedEmbeddingCollection( - ShardedModule[ - SparseFeaturesList, + ShardedEmbeddingModule[ + KJTList, List[torch.Tensor], Dict[str, JaggedTensor], EmbeddingCollectionContext, @@ -317,43 +396,67 @@ def __init__( fused_params: Optional[Dict[str, Any]] = None, device: Optional[torch.device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, - variable_batch_size: bool = False, + use_index_dedup: bool = False, + module_fqn: Optional[str] = None, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) - sharding_type_to_sharding_infos = create_sharding_infos_by_sharding( + self._module_fqn = module_fqn + self._embedding_configs: List[EmbeddingConfig] = module.embedding_configs() + self._table_names: List[str] = [ + config.name for config in self._embedding_configs + ] + self._table_name_to_config: Dict[str, EmbeddingConfig] = { + 
config.name: config for config in self._embedding_configs + } + self.module_sharding_plan: EmbeddingModuleShardingPlan = cast( + EmbeddingModuleShardingPlan, + { + table_name: parameter_sharding + for table_name, parameter_sharding in table_name_to_parameter_sharding.items() + if table_name in self._table_names + }, + ) + self._env = env + # output parameters as DTensor in state dict + self._output_dtensor: bool = env.output_dtensor + # TODO get rid of get_ec_index_dedup global flag + self._use_index_dedup: bool = use_index_dedup or get_ec_index_dedup() + sharding_type_to_sharding_infos = self.create_grouped_sharding_infos( module, table_name_to_parameter_sharding, fused_params, ) - self._variable_batch_size = variable_batch_size + + self._sharding_types: List[str] = list(sharding_type_to_sharding_infos.keys()) + self._sharding_type_to_sharding: Dict[ str, EmbeddingSharding[ - SequenceShardingContext, SparseFeatures, torch.Tensor, torch.Tensor + SequenceShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor ], ] = { - sharding_type: create_embedding_sharding( + sharding_type: self.create_embedding_sharding( sharding_type=sharding_type, sharding_infos=embedding_confings, env=env, device=device, qcomm_codecs_registry=self.qcomm_codecs_registry, - variable_batch_size=self._variable_batch_size, ) for sharding_type, embedding_confings in sharding_type_to_sharding_infos.items() } self._device = device - self._input_dists: nn.ModuleList = nn.ModuleList() - self._lookups: nn.ModuleList = nn.ModuleList() + self._input_dists: List[nn.Module] = [] + self._lookups: List[nn.Module] = [] self._create_lookups() - self._output_dists: nn.ModuleList = nn.ModuleList() + self._output_dists: List[nn.Module] = [] self._create_output_dist() self._feature_splits: List[int] = [] self._features_order: List[int] = [] self._has_uninitialized_input_dist: bool = True + logger.info(f"EC index dedup enabled: {self._use_index_dedup}.") # Get all fused optimizers and combine them. optims = [] @@ -361,7 +464,9 @@ def __init__( for _, m in lookup.named_modules(): if isinstance(m, FusedOptimizerModule): # modify param keys to match EmbeddingCollection - params: MutableMapping[str, Union[torch.Tensor, ShardedTensor]] = {} + params: MutableMapping[ + str, TypeUnion[torch.Tensor, ShardedTensor] + ] = {} for param_key, weight in m.fused_optimizer.params.items(): params["embeddings." 
+ param_key] = weight m.fused_optimizer.params = params @@ -383,6 +488,508 @@ def __init__( module.embedding_configs(), table_name_to_parameter_sharding ) self._need_indices: bool = module.need_indices() + self._inverse_indices_permute_per_sharding: Optional[List[torch.Tensor]] = None + + for index, (sharding, lookup) in enumerate( + zip( + self._sharding_type_to_sharding.values(), + self._lookups, + ) + ): + # TODO: can move this into DpPooledEmbeddingSharding once all modules are composable + if isinstance(sharding, DpSequenceEmbeddingSharding): + self._lookups[index] = DistributedDataParallel( + module=lookup, + device_ids=( + [self._device] + if self._device is not None and self._device.type == "cuda" + else None + ), + process_group=env.process_group, + gradient_as_bucket_view=True, + broadcast_buffers=True, + static_graph=True, + ) + self._initialize_torch_state() + + if module.device != torch.device("meta"): + self.load_state_dict(module.state_dict()) + + @classmethod + def create_grouped_sharding_infos( + cls, + module: EmbeddingCollectionInterface, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + fused_params: Optional[Dict[str, Any]], + ) -> Dict[str, List[EmbeddingShardingInfo]]: + """ + convert ParameterSharding (table_name_to_parameter_sharding: Dict[str, ParameterSharding]) to + EmbeddingShardingInfo that are grouped by sharding_type, and propagate the configs/parameters + """ + if fused_params is None: + fused_params = {} + + sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {} + # state_dict returns parameter.Tensor, which loses parameter level attributes + parameter_by_name = dict(module.named_parameters()) + # QuantEBC registers weights as buffers (since they are INT8), and so we need to grab it there + state_dict = module.state_dict() + + for ( + config, + embedding_names, + ) in zip(module.embedding_configs(), module.embedding_names_by_table()): + table_name = config.name + assert table_name in table_name_to_parameter_sharding + + parameter_sharding = table_name_to_parameter_sharding[table_name] + if parameter_sharding.compute_kernel not in [ + kernel.value for kernel in EmbeddingComputeKernel + ]: + raise ValueError( + f"Compute kernel not supported {parameter_sharding.compute_kernel}" + ) + + param_name = "embeddings." 
+ config.name + ".weight" + assert param_name in parameter_by_name or param_name in state_dict + param = parameter_by_name.get(param_name, state_dict[param_name]) + + if parameter_sharding.sharding_type not in sharding_type_to_sharding_infos: + sharding_type_to_sharding_infos[parameter_sharding.sharding_type] = [] + + optimizer_params = getattr(param, "_optimizer_kwargs", [{}]) + optimizer_classes = getattr(param, "_optimizer_classes", [None]) + + assert ( + len(optimizer_classes) == 1 and len(optimizer_params) == 1 + ), f"Only support 1 optimizer, given {len(optimizer_classes)}" + + optimizer_class = optimizer_classes[0] + optimizer_params = optimizer_params[0] + if optimizer_class: + optimizer_params["optimizer"] = optimizer_type_to_emb_opt_type( + optimizer_class + ) + + per_table_fused_params = merge_fused_params(fused_params, optimizer_params) + per_table_fused_params = add_params_from_parameter_sharding( + per_table_fused_params, parameter_sharding + ) + per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params) + + sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append( + ( + EmbeddingShardingInfo( + embedding_config=EmbeddingTableConfig( + num_embeddings=config.num_embeddings, + embedding_dim=config.embedding_dim, + name=config.name, + data_type=config.data_type, + feature_names=copy.deepcopy(config.feature_names), + pooling=PoolingType.NONE, + is_weighted=False, + has_feature_processor=False, + embedding_names=embedding_names, + weight_init_max=config.weight_init_max, + weight_init_min=config.weight_init_min, + ), + param_sharding=parameter_sharding, + param=param, + fused_params=per_table_fused_params, + ) + ) + ) + return sharding_type_to_sharding_infos + + @classmethod + def create_embedding_sharding( + cls, + sharding_type: str, + sharding_infos: List[EmbeddingShardingInfo], + env: ShardingEnv, + device: Optional[torch.device] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> EmbeddingSharding[ + SequenceShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor + ]: + """ + This is the main function to generate `EmbeddingSharding` instances based on sharding_type + so that the same sharding_type in one EC would be fused. 
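+
+        Example::
+
+            # Minimal sketch, assuming `infos`, `env` and `device` were built
+            # elsewhere; a table-wise plan dispatches to TwSequenceEmbeddingSharding.
+            sharding = ShardedEmbeddingCollection.create_embedding_sharding(
+                sharding_type=ShardingType.TABLE_WISE.value,
+                sharding_infos=infos,
+                env=env,
+                device=device,
+            )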
+ """ + if sharding_type == ShardingType.TABLE_WISE.value: + return TwSequenceEmbeddingSharding( + sharding_infos=sharding_infos, + env=env, + device=device, + qcomm_codecs_registry=qcomm_codecs_registry, + ) + elif sharding_type == ShardingType.ROW_WISE.value: + return RwSequenceEmbeddingSharding( + sharding_infos=sharding_infos, + env=env, + device=device, + qcomm_codecs_registry=qcomm_codecs_registry, + ) + elif sharding_type == ShardingType.DATA_PARALLEL.value: + return DpSequenceEmbeddingSharding( + sharding_infos=sharding_infos, + env=env, + device=device, + ) + elif sharding_type == ShardingType.COLUMN_WISE.value: + return CwSequenceEmbeddingSharding( + sharding_infos=sharding_infos, + env=env, + device=device, + qcomm_codecs_registry=qcomm_codecs_registry, + ) + else: + raise ValueError(f"Sharding not supported {sharding_type}") + + @staticmethod + def _pre_state_dict_hook( + self: "ShardedEmbeddingCollection", + prefix: str = "", + keep_vars: bool = False, + ) -> None: + for lookup in self._lookups: + while isinstance(lookup, DistributedDataParallel): + lookup = lookup.module + lookup.flush() + + @staticmethod + def _pre_load_state_dict_hook( + self: "ShardedEmbeddingCollection", + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> None: + """ + Modify the destination state_dict for model parallel + to transform from ShardedTensors/DTensors into tensors + """ + for table_name in self._model_parallel_name_to_local_shards.keys(): + key = f"{prefix}embeddings.{table_name}.weight" + # gather model shards from both DTensor and ShardedTensor maps + model_shards_sharded_tensor = self._model_parallel_name_to_local_shards[ + table_name + ] + model_shards_dtensor = self._model_parallel_name_to_shards_wrapper[ + table_name + ] + # If state_dict[key] is already a ShardedTensor, use its local shards + if isinstance(state_dict[key], ShardedTensor): + local_shards = state_dict[key].local_shards() + if len(local_shards) == 0: + state_dict[key] = torch.empty(0) + else: + dim = state_dict[key].metadata().shards_metadata[0].shard_sizes[1] + # CW multiple shards are merged + if len(local_shards) > 1: + state_dict[key] = torch.cat( + [s.tensor.view(-1) for s in local_shards], dim=0 + ).view(-1, dim) + else: + state_dict[key] = local_shards[0].tensor.view(-1, dim) + elif isinstance(state_dict[key], DTensor): + shards_wrapper = state_dict[key].to_local() + local_shards = shards_wrapper.local_shards() + dim = shards_wrapper.local_sizes()[0][1] + if len(local_shards) == 0: + state_dict[key] = torch.empty(0) + elif len(local_shards) > 1: + state_dict[key] = torch.cat( + [s.view(-1) for s in local_shards], dim=0 + ).view(-1, dim) + else: + state_dict[key] = local_shards[0].view(-1, dim) + elif isinstance(state_dict[key], torch.Tensor): + local_shards = [] + if model_shards_sharded_tensor: + # splice according to sharded tensor metadata + for shard in model_shards_sharded_tensor: + # Extract shard size and offsets for splicing + shard_size = shard.metadata.shard_sizes + shard_offset = shard.metadata.shard_offsets + + # Prepare tensor by splicing and placing on appropriate device + spliced_tensor = state_dict[key][ + shard_offset[0] : shard_offset[0] + shard_size[0], + shard_offset[1] : shard_offset[1] + shard_size[1], + ] + + # Append spliced tensor into local shards + local_shards.append(spliced_tensor) + elif model_shards_dtensor: + # splice according to dtensor metadata + for tensor, shard_offset in zip( + model_shards_dtensor["local_tensors"], + model_shards_dtensor["local_offsets"], + ): + 
shard_size = tensor.size() + spliced_tensor = state_dict[key][ + shard_offset[0] : shard_offset[0] + shard_size[0], + shard_offset[1] : shard_offset[1] + shard_size[1], + ] + local_shards.append(spliced_tensor) + state_dict[key] = ( + torch.empty(0) + if not local_shards + else torch.cat(local_shards, dim=0) + ) + else: + raise RuntimeError( + f"Unexpected state_dict key type {type(state_dict[key])} found for {key}" + ) + + for lookup in self._lookups: + while isinstance(lookup, DistributedDataParallel): + lookup = lookup.module + lookup.purge() + + def _initialize_torch_state(self) -> None: # noqa + """ + This provides consistency between this class and the EmbeddingCollection's + nn.Module API calls (state_dict, named_modules, etc) + """ + + self.embeddings: nn.ModuleDict = nn.ModuleDict() + for table_name in self._table_names: + self.embeddings[table_name] = nn.Module() + self._model_parallel_name_to_local_shards = OrderedDict() + self._model_parallel_name_to_shards_wrapper = OrderedDict() + self._model_parallel_name_to_sharded_tensor = OrderedDict() + self._model_parallel_name_to_dtensor = OrderedDict() + _model_parallel_name_to_compute_kernel: Dict[str, str] = {} + for ( + table_name, + parameter_sharding, + ) in self.module_sharding_plan.items(): + if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value: + # Don't need to use sharded/distributed state tensor for DATA_PARALLEL + # because each rank has a full copy of the table in DATA_PARALLEL + continue + _model_parallel_name_to_compute_kernel[table_name] = ( + parameter_sharding.compute_kernel + ) + if ( + parameter_sharding.compute_kernel + == EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value + ): + # Skip state_dict handling for CUSTOMIZED_KERNEL, this should be implemented + # in child class for the CUSTOMIZED_KERNEL + continue + self._model_parallel_name_to_local_shards[table_name] = [] + self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict( + [("local_tensors", []), ("local_offsets", [])] + ) + + self._name_to_table_size = {} + for table in self._embedding_configs: + self._name_to_table_size[table.name] = ( + table.num_embeddings, + table.embedding_dim, + ) + + for sharding_type, lookup in zip( + self._sharding_type_to_sharding.keys(), self._lookups + ): + if sharding_type == ShardingType.DATA_PARALLEL.value: + # unwrap DDP + lookup = lookup.module + else: + # save local_shards for transforming MP params to shardedTensor + for key, v in lookup.state_dict().items(): + table_name = key[: -len(".weight")] + if ( + _model_parallel_name_to_compute_kernel[table_name] + == EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value + ): + continue + if isinstance(v, DTensor): + shards_wrapper = self._model_parallel_name_to_shards_wrapper[ + table_name + ] + local_shards_wrapper = v._local_tensor + shards_wrapper["local_tensors"].extend( + # pyre-ignore[16] + local_shards_wrapper.local_shards() + ) + shards_wrapper["local_offsets"].extend( + # pyre-ignore[16] + local_shards_wrapper.local_offsets() + ) + shards_wrapper["global_size"] = v.size() + shards_wrapper["global_stride"] = v.stride() + shards_wrapper["placements"] = v.placements + elif isinstance(v, ShardedTensor): + self._model_parallel_name_to_local_shards[table_name].extend( + v.local_shards() + ) + for ( + table_name, + tbe_slice, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `named_parameters_by_table`. 
+ ) in lookup.named_parameters_by_table(): + self.embeddings[table_name].register_parameter("weight", tbe_slice) + for table_name in self._model_parallel_name_to_local_shards.keys(): + local_shards = self._model_parallel_name_to_local_shards[table_name] + shards_wrapper_map = self._model_parallel_name_to_shards_wrapper[table_name] + + # for shards that don't exist on this rank, register with empty tensor + if not hasattr(self.embeddings[table_name], "weight"): + self.embeddings[table_name].register_parameter( + "weight", nn.Parameter(torch.empty(0)) + ) + if ( + _model_parallel_name_to_compute_kernel[table_name] + != EmbeddingComputeKernel.DENSE.value + ): + self.embeddings[table_name].weight._in_backward_optimizers = [ + EmptyFusedOptimizer() + ] + + if self._output_dtensor: + assert _model_parallel_name_to_compute_kernel[table_name] not in { + EmbeddingComputeKernel.KEY_VALUE.value + } + if shards_wrapper_map["local_tensors"]: + self._model_parallel_name_to_dtensor[table_name] = ( + DTensor.from_local( + local_tensor=LocalShardsWrapper( + local_shards=shards_wrapper_map["local_tensors"], + local_offsets=shards_wrapper_map["local_offsets"], + ), + device_mesh=self._env.device_mesh, + placements=shards_wrapper_map["placements"], + shape=shards_wrapper_map["global_size"], + stride=shards_wrapper_map["global_stride"], + run_check=False, + ) + ) + else: + shape, stride = create_global_tensor_shape_stride_from_metadata( + none_throws(self.module_sharding_plan[table_name]), + ( + self._env.node_group_size + if isinstance(self._env, ShardingEnv2D) + else get_local_size(self._env.world_size) + ), + ) + # empty shard case + self._model_parallel_name_to_dtensor[table_name] = ( + DTensor.from_local( + local_tensor=LocalShardsWrapper( + local_shards=[], + local_offsets=[], + ), + device_mesh=self._env.device_mesh, + run_check=False, + shape=shape, + stride=stride, + ) + ) + else: + # created ShardedTensors once in init, use in post_state_dict_hook + # note: at this point kvstore backed tensors don't own valid snapshots, so no read + # access is allowed on them. + self._model_parallel_name_to_sharded_tensor[table_name] = ( + ShardedTensor._init_from_local_shards( + local_shards, + self._name_to_table_size[table_name], + process_group=( + self._env.sharding_pg + if isinstance(self._env, ShardingEnv2D) + else self._env.process_group + ), + ) + ) + + def extract_sharded_kvtensors( + module: ShardedEmbeddingCollection, + ) -> OrderedDict[str, ShardedTensor]: + # retrieve all kvstore backed tensors + ret = OrderedDict() + for ( + table_name, + sharded_t, + ) in module._model_parallel_name_to_sharded_tensor.items(): + if _model_parallel_name_to_compute_kernel[table_name] in { + EmbeddingComputeKernel.KEY_VALUE.value + }: + ret[table_name] = sharded_t + return ret + + def post_state_dict_hook( + module: ShardedEmbeddingCollection, + destination: Dict[str, torch.Tensor], + prefix: str, + _local_metadata: Dict[str, Any], + ) -> None: + # Adjust dense MP + for ( + table_name, + sharded_t, + ) in module._model_parallel_name_to_sharded_tensor.items(): + destination_key = f"{prefix}embeddings.{table_name}.weight" + destination[destination_key] = sharded_t + for ( + table_name, + d_tensor, + ) in module._model_parallel_name_to_dtensor.items(): + destination_key = f"{prefix}embeddings.{table_name}.weight" + destination[destination_key] = d_tensor + + # kvstore backed tensors do not have a valid backing snapshot at this point. Fill in a valid + # snapshot for read access. 
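+            # The deep-copied ShardedTensors below are re-pointed at each lookup's
+            # get_named_split_embedding_weights_snapshot() output before being written
+            # into the destination state_dict.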
+ sharded_kvtensors = extract_sharded_kvtensors(module) + if len(sharded_kvtensors) == 0: + return + + sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors) + for lookup, sharding_type in zip( + module._lookups, module._sharding_type_to_sharding.keys() + ): + if sharding_type != ShardingType.DATA_PARALLEL.value: + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + for key, v in lookup.get_named_split_embedding_weights_snapshot(): + assert key in sharded_kvtensors_copy + sharded_kvtensors_copy[key].local_shards()[0].tensor = v + for ( + table_name, + sharded_kvtensor, + ) in sharded_kvtensors_copy.items(): + destination_key = f"{prefix}embeddings.{table_name}.weight" + destination[destination_key] = sharded_kvtensor + + self.register_state_dict_pre_hook(self._pre_state_dict_hook) + self._register_state_dict_hook(post_state_dict_hook) + self._register_load_state_dict_pre_hook( + self._pre_load_state_dict_hook, with_module=True + ) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self._device and self._device.type == "meta": + return + # Initialize embedding weights with init_fn + for table_config in self._embedding_configs: + if self.module_sharding_plan[table_config.name].compute_kernel in { + EmbeddingComputeKernel.KEY_VALUE.value, + }: + continue + assert table_config.init_fn is not None + param = self.embeddings[f"{table_config.name}"].weight + # pyre-ignore + table_config.init_fn(param) + + sharding_type = self.module_sharding_plan[table_config.name].sharding_type + if sharding_type == ShardingType.DATA_PARALLEL.value: + pg = self._env.process_group + with torch.no_grad(): + dist.broadcast(param.data, src=0, group=pg) def _generate_permute_indices_per_feature( self, @@ -434,16 +1041,73 @@ def _generate_permute_indices_per_feature( else: self._features_to_permute_indices[feature_name] = permute_indices + def _create_hash_size_info( + self, + feature_names: List[str], + ctx: Optional[EmbeddingCollectionContext] = None, + ) -> None: + feature_index = 0 + table_to_unpruned_size_mapping: Optional[Dict[str, int]] = None + if ( + ctx is not None + and getattr(ctx, "table_name_to_unpruned_hash_sizes", None) + and len(ctx.table_name_to_unpruned_hash_sizes) > 0 + ): + table_to_unpruned_size_mapping = ctx.table_name_to_unpruned_hash_sizes + for i, sharding in enumerate(self._sharding_type_to_sharding.values()): + feature_hash_size: List[int] = [] + feature_hash_size_lengths: List[int] = [] + for table in sharding.embedding_tables(): + table_hash_size = [0] * table.num_features() + if table_to_unpruned_size_mapping and table.name: + table_hash_size[-1] = table_to_unpruned_size_mapping[table.name] + else: + table_hash_size[-1] = table.num_embeddings + feature_hash_size.extend(table_hash_size) + + table_hash_size = [0] * table.num_features() + table_hash_size[0] = table.num_features() + feature_hash_size_lengths.extend(table_hash_size) + + # Sanity check for feature orders + for f in range(table.num_features()): + assert feature_names[feature_index + f] == table.feature_names[f] + feature_index += table.num_features() + + feature_hash_size_cumsum: List[int] = [0] + list( + accumulate(feature_hash_size) + ) + feature_hash_size_offset: List[int] = [0] + list( + accumulate(feature_hash_size_lengths) + ) + + # Register buffers for this shard + self.register_buffer( + f"_hash_size_cumsum_tensor_{i}", + torch.tensor( + feature_hash_size_cumsum, device=self._device, dtype=torch.int64 + ), + persistent=False, + ) + self.register_buffer( + 
f"_hash_size_offset_tensor_{i}", + torch.tensor( + feature_hash_size_offset, device=self._device, dtype=torch.int64 + ), + persistent=False, + ) + def _create_input_dist( self, input_feature_names: List[str], + ctx: Optional[EmbeddingCollectionContext] = None, ) -> None: feature_names: List[str] = [] self._feature_splits: List[int] = [] for sharding in self._sharding_type_to_sharding.values(): self._input_dists.append(sharding.create_input_dist()) - feature_names.extend(sharding.id_list_feature_names()) - self._feature_splits.append(len(sharding.id_list_feature_names())) + feature_names.extend(sharding.feature_names()) + self._feature_splits.append(len(sharding.feature_names())) self._features_order: List[int] = [] for f in feature_names: self._features_order.append(input_feature_names.index(f)) @@ -455,8 +1119,12 @@ def _create_input_dist( self.register_buffer( "_features_order_tensor", torch.tensor(self._features_order, device=self._device, dtype=torch.int32), + persistent=False, ) + if self._use_index_dedup: + self._create_hash_size_info(feature_names, ctx) + def _create_lookups(self) -> None: for sharding in self._sharding_type_to_sharding.values(): self._lookups.append(sharding.create_lookup()) @@ -467,150 +1135,184 @@ def _create_output_dist( for sharding in self._sharding_type_to_sharding.values(): self._output_dists.append(sharding.create_output_dist()) - def lengths_dist( - self, ctx: EmbeddingCollectionContext, features: KeyedJaggedTensor - ) -> SparseFeaturesListIndicesAwaitable: - """ - Creates lengths all2all awaitables. - """ - if self._has_uninitialized_input_dist: - self._create_input_dist(input_feature_names=features.keys()) - self._has_uninitialized_input_dist = False - with torch.no_grad(): - if self._features_order: - features = features.permute( - self._features_order, - # pyre-ignore [6] - self._features_order_tensor, + def _dedup_indices( + self, + ctx: EmbeddingCollectionContext, + input_feature_splits: List[KeyedJaggedTensor], + ) -> List[KeyedJaggedTensor]: + with record_function("## dedup_ec_indices ##"): + features_by_shards = [] + for i, input_feature in enumerate(input_feature_splits): + hash_size_cumsum = self.get_buffer(f"_hash_size_cumsum_tensor_{i}") + hash_size_offset = self.get_buffer(f"_hash_size_offset_tensor_{i}") + ( + lengths, + offsets, + unique_indices, + reverse_indices, + ) = torch.ops.fbgemm.jagged_unique_indices( + hash_size_cumsum, + hash_size_offset, + input_feature.offsets().to(torch.int64), + input_feature.values().to(torch.int64), + ) + dedup_features = KeyedJaggedTensor( + keys=input_feature.keys(), + lengths=lengths, + offsets=offsets, + values=unique_indices, ) - features_by_shards = features.split( - self._feature_splits, - ) - - # Callback to save input splits and output splits in sharding context which - # will be reused in sequence embedding all2all. - def _save_input_output_splits_to_context( - module: nn.Module, - features: KeyedJaggedTensor, - ctx: EmbeddingCollectionContext, - indices_awaitable: Awaitable[SparseFeatures], - ) -> Awaitable[SparseFeatures]: - with torch.no_grad(): - input_splits = [] - output_splits = [] - if isinstance(indices_awaitable, SparseFeaturesIndicesAwaitable): - input_splits = ( - # pyre-fixme[16]: `Optional` has no attribute - # `_in_lengths_per_worker`. - indices_awaitable._id_list_features_awaitable._in_lengths_per_worker - ) - output_splits = ( - # pyre-fixme[16]: `Optional` has no attribute - # `_out_lengths_per_worker`. 
- indices_awaitable._id_list_features_awaitable._out_lengths_per_worker - ) - ctx.sharding_contexts.append( - SequenceShardingContext( - features_before_input_dist=features, - input_splits=input_splits, - output_splits=output_splits, - unbucketize_permute_tensor=module.unbucketize_permute_tensor - if isinstance(module, RwSparseFeaturesDist) - else None, - ) - ) - return indices_awaitable - awaitables = [] - for module, features in zip(self._input_dists, features_by_shards): - tensor_awaitable = module( - SparseFeatures( - id_list_features=features, - id_score_list_features=None, + ctx.input_features.append(input_feature) + ctx.reverse_indices.append(reverse_indices) + features_by_shards.append(dedup_features) + + return features_by_shards + + def _create_inverse_indices_permute_per_sharding( + self, inverse_indices: Tuple[List[str], torch.Tensor] + ) -> None: + if ( + len(self._embedding_names_per_sharding) == 1 + and self._embedding_names_per_sharding[0] == inverse_indices[0] + ): + return + index_per_name = {name: i for i, name in enumerate(inverse_indices[0])} + permute_per_sharding = [] + for emb_names in self._embedding_names_per_sharding: + permute = _pin_and_move( + torch.tensor( + [index_per_name[name.split("@")[0]] for name in emb_names] + ), + inverse_indices[1].device, + ) + permute_per_sharding.append(permute) + self._inverse_indices_permute_per_sharding = permute_per_sharding + + def _compute_sequence_vbe_context( + self, + ctx: EmbeddingCollectionContext, + unpadded_features: KeyedJaggedTensor, + ) -> None: + assert ( + unpadded_features.inverse_indices_or_none() is not None + ), "inverse indices must be provided from KJT if using variable batch size per feature." + + inverse_indices = unpadded_features.inverse_indices() + stride = inverse_indices[1].numel() // len(inverse_indices[0]) + if self._inverse_indices_permute_per_sharding is None: + self._create_inverse_indices_permute_per_sharding(inverse_indices) + + if self._features_order: + unpadded_features = unpadded_features.permute( + self._features_order, + # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]` but + # got `TypeUnion[Module, Tensor]`. 
+ self._features_order_tensor, + ) + + features_by_sharding = unpadded_features.split(self._feature_splits) + for i, feature in enumerate(features_by_sharding): + if self._inverse_indices_permute_per_sharding is not None: + permute = self._inverse_indices_permute_per_sharding[i] + permuted_indices = torch.index_select(inverse_indices[1], 0, permute) + else: + permuted_indices = inverse_indices[1] + stride_per_key = _pin_and_move( + torch.tensor(feature.stride_per_key()), feature.device() + ) + offsets = _to_offsets(stride_per_key)[:-1].unsqueeze(-1) + recat = (permuted_indices + offsets).flatten().int() + + if self._need_indices: + reindexed_lengths, reindexed_values, _ = ( + torch.ops.fbgemm.permute_1D_sparse_data( + recat, + feature.lengths(), + feature.values(), ) ) - tensor_awaitable.callbacks.append( - functools.partial( - _save_input_output_splits_to_context, - module, - features, - ctx, - ) + else: + reindexed_lengths = torch.index_select(feature.lengths(), 0, recat) + reindexed_values = None + + reindexed_lengths = reindexed_lengths.view(-1, stride) + reindexed_length_per_key = torch.sum(reindexed_lengths, dim=1).tolist() + + ctx.seq_vbe_ctx.append( + SequenceVBEContext( + recat=recat, + unpadded_lengths=feature.lengths(), + reindexed_lengths=reindexed_lengths, + reindexed_length_per_key=reindexed_length_per_key, + reindexed_values=reindexed_values, ) - awaitables.append(tensor_awaitable) - return SparseFeaturesListIndicesAwaitable(awaitables) + ) # pyre-ignore [14] def input_dist( self, ctx: EmbeddingCollectionContext, - features: KeyedJaggedTensor, - ) -> Awaitable[SparseFeaturesList]: + features: TypeUnion[KeyedJaggedTensor, TensorDict], + ) -> Awaitable[Awaitable[KJTList]]: + need_permute: bool = True + if isinstance(features, TensorDict): + feature_keys = list(features.keys()) # pyre-ignore[6] + if self._features_order: + feature_keys = [feature_keys[i] for i in self._features_order] + need_permute = False + features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6] if self._has_uninitialized_input_dist: - self._create_input_dist(input_feature_names=features.keys()) + self._create_input_dist(input_feature_names=features.keys(), ctx=ctx) self._has_uninitialized_input_dist = False with torch.no_grad(): - if self._features_order: + unpadded_features = None + if features.variable_stride_per_key(): + unpadded_features = features + features = pad_vbe_kjt_lengths(unpadded_features) + + if need_permute and self._features_order: features = features.permute( self._features_order, - # pyre-ignore [6] + # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]` + # but got `TypeUnion[Module, Tensor]`. 
self._features_order_tensor, ) - features_by_shards = features.split( - self._feature_splits, - ) - # save input splits and output splits in sharding context which - # will be reused in sequence embedding all2all + features_by_shards = features.split(self._feature_splits) + if self._use_index_dedup: + features_by_shards = self._dedup_indices(ctx, features_by_shards) + awaitables = [] - for module, features in zip(self._input_dists, features_by_shards): - lengths_awaitable = module( - SparseFeatures( - id_list_features=features, - id_score_list_features=None, - ) - ) - indices_awaitable = lengths_awaitable.wait() # finish lengths all2all - input_splits = [] - output_splits = [] - batch_size_per_rank = [] - sparse_features_recat = None - if isinstance(indices_awaitable, SparseFeaturesIndicesAwaitable): - assert indices_awaitable._id_list_features_awaitable is not None - input_splits = ( - # pyre-fixme[16] - indices_awaitable._id_list_features_awaitable._in_lengths_per_worker - ) - output_splits = ( - # pyre-fixme[16] - indices_awaitable._id_list_features_awaitable._out_lengths_per_worker - ) - batch_size_per_rank = ( - # pyre-fixme[16] - indices_awaitable._id_list_features_awaitable._batch_size_per_rank - ) - # Pass input_dist recat so that we do not need double calculate recat in Sequence embedding all2all to save H2D - sparse_features_recat = ( - # pyre-fixme[16] - indices_awaitable._id_list_features_awaitable._recat - ) + for input_dist, features, sharding_type in zip( + self._input_dists, features_by_shards, self._sharding_type_to_sharding + ): + with maybe_annotate_embedding_event( + EmbeddingEvent.KJT_SPLITS_DIST, self._module_fqn, sharding_type + ): + awaitables.append(input_dist(features)) ctx.sharding_contexts.append( SequenceShardingContext( features_before_input_dist=features, - sparse_features_recat=sparse_features_recat, - input_splits=input_splits, - output_splits=output_splits, - unbucketize_permute_tensor=module.unbucketize_permute_tensor - if isinstance(module, RwSparseFeaturesDist) - else None, - batch_size_per_rank=batch_size_per_rank, + unbucketize_permute_tensor=( + input_dist.unbucketize_permute_tensor + if isinstance(input_dist, RwSparseFeaturesDist) + else None + ), ) ) - awaitables.append(indices_awaitable) - return SparseFeaturesListAwaitable(awaitables) + if unpadded_features is not None: + self._compute_sequence_vbe_context(ctx, unpadded_features) + + return KJTListSplitsAwaitable( + awaitables, + ctx, + self._module_fqn, + list(self._sharding_type_to_sharding.keys()), + ) def compute( - self, ctx: EmbeddingCollectionContext, dist_input: SparseFeaturesList + self, ctx: EmbeddingCollectionContext, dist_input: KJTList ) -> List[torch.Tensor]: ret: List[torch.Tensor] = [] for lookup, features, sharding_ctx, sharding_type in zip( @@ -619,10 +1321,8 @@ def compute( ctx.sharding_contexts, self._sharding_type_to_sharding, ): - sharding_ctx.lengths_after_input_dist = ( - features.id_list_features.lengths().view( - -1, features.id_list_features.stride() - ) + sharding_ctx.lengths_after_input_dist = features.lengths().view( + -1, features.stride() ) embedding_dim = self._embedding_dim_for_sharding_type(sharding_type) ret.append(lookup(features).view(-1, embedding_dim)) @@ -640,6 +1340,8 @@ def output_dist( ): awaitables_per_sharding.append(odist(embeddings, sharding_ctx)) features_before_all2all_per_sharding.append( + # pyre-fixme[6]: For 1st argument expected `KeyedJaggedTensor` but + # got `Optional[KeyedJaggedTensor]`. 
sharding_ctx.features_before_input_dist ) return EmbeddingCollectionAwaitable( @@ -648,10 +1350,11 @@ def output_dist( embedding_names_per_sharding=self._embedding_names_per_sharding, need_indices=self._need_indices, features_to_permute_indices=self._features_to_permute_indices, + ctx=ctx, ) def compute_and_output_dist( - self, ctx: EmbeddingCollectionContext, input: SparseFeaturesList + self, ctx: EmbeddingCollectionContext, input: KJTList ) -> LazyAwaitable[Dict[str, JaggedTensor]]: awaitables_per_sharding: List[Awaitable[torch.Tensor]] = [] features_before_all2all_per_sharding: List[KeyedJaggedTensor] = [] @@ -662,16 +1365,26 @@ def compute_and_output_dist( ctx.sharding_contexts, self._sharding_type_to_sharding, ): - sharding_ctx.lengths_after_input_dist = ( - features.id_list_features.lengths().view( - -1, features.id_list_features.stride() - ) + sharding_ctx.lengths_after_input_dist = features.lengths().view( + -1, features.stride() ) embedding_dim = self._embedding_dim_for_sharding_type(sharding_type) - awaitables_per_sharding.append( - odist(lookup(features).view(-1, embedding_dim), sharding_ctx) - ) + + with maybe_annotate_embedding_event( + EmbeddingEvent.LOOKUP, self._module_fqn, sharding_type + ): + embs = lookup(features) + + with maybe_annotate_embedding_event( + EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type + ): + awaitables_per_sharding.append( + odist(embs.view(-1, embedding_dim), sharding_ctx) + ) + features_before_all2all_per_sharding.append( + # pyre-fixme[6]: For 1st argument expected `KeyedJaggedTensor` but + # got `Optional[KeyedJaggedTensor]`. sharding_ctx.features_before_input_dist ) return EmbeddingCollectionAwaitable( @@ -680,6 +1393,9 @@ def compute_and_output_dist( embedding_names_per_sharding=self._embedding_names_per_sharding, need_indices=self._need_indices, features_to_permute_indices=self._features_to_permute_indices, + ctx=ctx, + module_fqn=self._module_fqn, + sharding_types=list(self._sharding_type_to_sharding.keys()), ) def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int: @@ -689,86 +1405,6 @@ def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int: else self._embedding_dim ) - # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently. - def state_dict( - self, - destination: Optional[Dict[str, Any]] = None, - prefix: str = "", - keep_vars: bool = False, - ) -> Dict[str, Any]: - if destination is None: - destination = OrderedDict() - # pyre-ignore [16] - destination._metadata = OrderedDict() - for lookup in self._lookups: - lookup.state_dict(destination, prefix + "embeddings.", keep_vars) - return destination - - def named_modules( - self, - memo: Optional[Set[nn.Module]] = None, - prefix: str = "", - remove_duplicate: bool = True, - ) -> Iterator[Tuple[str, nn.Module]]: - yield from [(prefix, self)] - - def named_parameters( - self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True - ) -> Iterator[Tuple[str, nn.Parameter]]: - for lookup in self._lookups: - yield from lookup.named_parameters( - append_prefix(prefix, "embeddings"), recurse - ) - - def named_buffers( - self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True - ) -> Iterator[Tuple[str, torch.Tensor]]: - for lookup in self._lookups: - yield from lookup.named_buffers( - append_prefix(prefix, "embeddings"), recurse - ) - - # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module` - # inconsistently. 
- def load_state_dict( - self, - state_dict: "OrderedDict[str, torch.Tensor]", - strict: bool = True, - ) -> _IncompatibleKeys: - missing_keys = [] - unexpected_keys = [] - for lookup in self._lookups: - missing, unexpected = lookup.load_state_dict( - filter_state_dict(state_dict, "embeddings"), - strict, - ) - missing_keys.extend(missing) - unexpected_keys.extend(unexpected) - return _IncompatibleKeys( - missing_keys=missing_keys, unexpected_keys=unexpected_keys - ) - - def sparse_grad_parameter_names( - self, - destination: Optional[List[str]] = None, - prefix: str = "", - ) -> List[str]: - destination = [] if destination is None else destination - for lookup in self._lookups: - lookup.sparse_grad_parameter_names( - destination, append_prefix(prefix, "embeddings") - ) - return destination - - def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: - for lookup, sharding_type in zip( - self._lookups, self._sharding_type_to_sharding.keys() - ): - if sharding_type == ShardingType.DATA_PARALLEL.value: - continue - for name, _ in lookup.named_parameters(append_prefix(prefix, "embeddings")): - yield name - @property def fused_optimizer(self) -> KeyedOptimizer: return self._optim @@ -778,9 +1414,14 @@ def create_context(self) -> EmbeddingCollectionContext: class EmbeddingCollectionSharder(BaseEmbeddingSharder[EmbeddingCollection]): - """ - This implementation uses non-fused EmbeddingCollection - """ + def __init__( + self, + fused_params: Optional[Dict[str, Any]] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + use_index_dedup: bool = False, + ) -> None: + super().__init__(fused_params, qcomm_codecs_registry) + self._use_index_dedup = use_index_dedup def shard( self, @@ -788,6 +1429,7 @@ def shard( params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, ) -> ShardedEmbeddingCollection: return ShardedEmbeddingCollection( module, @@ -796,7 +1438,8 @@ def shard( self.fused_params, device, qcomm_codecs_registry=self.qcomm_codecs_registry, - variable_batch_size=self._variable_batch_size, + use_index_dedup=self._use_index_dedup, + module_fqn=module_fqn, ) def shardable_parameters( diff --git a/torchrec/distributed/embedding_dim_bucketer.py b/torchrec/distributed/embedding_dim_bucketer.py new file mode 100644 index 000000000..ef2f58b15 --- /dev/null +++ b/torchrec/distributed/embedding_dim_bucketer.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from enum import Enum, unique +from typing import Dict, List + +from torchrec.distributed.embedding_types import ShardedEmbeddingTable +from torchrec.modules.embedding_configs import DATA_TYPE_NUM_BITS, DataType + + +@unique +class EmbDimBucketerPolicy(Enum): + """ + Config to specify how to bucketize embedding tables based on dimensions. + + single_bucket: All embedding tables are put into a single bucket. + all_buckets: All the embedding tables with the same dim size are put in the same bucket. + cacheline_buckets: All the embedding tables with the same dim cacheline size are put in the same bucket. 
+ """ + + SINGLE_BUCKET = "single_bucket" + ALL_BUCKETS = "all_buckets" + CACHELINE_BUCKETS = "cacheline_buckets" + + +class EmbDimBucketer: + """ + Buckets embedding dimensions into different groups based on their sizes. This is intended to be leveraged + once planning is done, and at the sharding stage, per rank. + + The rationale to use bucketization is + + - When UVM_CACHING is used: FBGEMM Table Batched Embedding Operator supports a software managed cache for the embeddings placed on UVM (Host memory). + However, the cache uses maximum embedding dim of all the tables batched in the operator as its unit of allocation. This results in waisted HBM memory, + and higher miss rate, hence lower performance. Bucketizing can address this issue, allowing for higher effective cache size and better performace. + + - When all tables are placed on HBM: When tables with widely different embedding dimension are batched together, the register allocation in GPU will + be mainly decided by the table with largest embedding dimension. This can lead to worse performance due to lower number of threads and lower occupancy. + + Note that Column wise sharding also to some extent addresses this problem, but has its own limitations. + + + Generally, we expect the CACHELINE_BUCKETS policy perform better than ALL_BUCKETS, as it addresses the main issues and limits the number of buckets. + + + Args: + embedding_tables (List[ShardedEmbeddingTable]): list of sharded embedding + cfg (EmbDimBucketerPolicy): Bucketing policy + + returns: + emb_dim_buckets (Dict[int, int]): Mapping from embedding dim to bucket id + + + Example: + emb_dim_bucketer = EmbDimBucketer(embedding_tables, EmbDimBucketerPolicy.SINGLE_BUCKET) + ... + bucket = emb_dim_bucketer.get_bucket(embedding_tables[0], embedding_tables[0].data_type) # bucket table 0 is assigned to. 
+ """ + + def __init__( + self, embedding_tables: List[ShardedEmbeddingTable], cfg: EmbDimBucketerPolicy + ) -> None: + self.embedding_dim_buckets: Dict[int, int] = {} + self.num_buckets = 1 + self.cacheline = 128 + if cfg == EmbDimBucketerPolicy.CACHELINE_BUCKETS: + self.emb_dim_buckets: Dict[int, int] = self.cacheline_emb_buckets( + embedding_tables + ) + elif cfg == EmbDimBucketerPolicy.ALL_BUCKETS: + self.emb_dim_buckets: Dict[int, int] = self.all_emb_buckets( + embedding_tables + ) + elif cfg == EmbDimBucketerPolicy.SINGLE_BUCKET: + self.emb_dim_buckets: Dict[int, int] = self.single_emb_bucket( + embedding_tables + ) + else: + AssertionError(f"Invalid bucketization config {cfg}") + + def bucket_count(self) -> int: + return self.num_buckets + + def get_bucket(self, embedding_dim: int, dtype: DataType) -> int: + if self.num_buckets == 1: + return 0 + else: + return self.bucket(embedding_dim, dtype) + + def single_emb_bucket( + self, + embedding_tables: List[ShardedEmbeddingTable], + ) -> Dict[int, int]: + buckets: Dict[int, int] = {} + bucket_id = 0 + + for table in embedding_tables: + dim_in_bytes = self.dim_in_bytes(table.local_cols, table.data_type) + buckets[dim_in_bytes] = bucket_id + + self.num_buckets = 1 + + return buckets + + def all_emb_buckets( + self, + embedding_tables: List[ShardedEmbeddingTable], + ) -> Dict[int, int]: + buckets: Dict[int, int] = {} + bucket_id = -1 + + for table in embedding_tables: + dim_in_bytes = self.dim_in_bytes(table.local_cols, table.data_type) + if dim_in_bytes not in buckets.keys(): + bucket_id += 1 + buckets[dim_in_bytes] = bucket_id + + self.num_buckets = bucket_id + 1 # id starts from 0 + + return buckets + + def cacheline_emb_buckets( + self, + embedding_tables: List[ShardedEmbeddingTable], + ) -> Dict[int, int]: + buckets: Dict[int, int] = {} + cl_buckets: Dict[int, int] = {} + bucket_id = -1 + + for table in embedding_tables: + dim_in_bytes = self.dim_in_bytes(table.local_cols, table.data_type) + cl_dim = dim_in_bytes // self.cacheline + if cl_dim not in cl_buckets.keys(): + bucket_id += 1 + cl_buckets[cl_dim] = bucket_id + + if dim_in_bytes not in buckets.keys(): + buckets[dim_in_bytes] = cl_buckets[cl_dim] + + self.num_buckets = bucket_id + 1 # id starts from 0 + + return buckets + + def bucket(self, dim: int, dtype: DataType) -> int: + return self.emb_dim_buckets[self.dim_in_bytes(dim, dtype)] + + def dim_in_bytes(self, dim: int, dtype: DataType) -> int: + return dim * DATA_TYPE_NUM_BITS[dtype] // 8 diff --git a/torchrec/distributed/embedding_kernel.py b/torchrec/distributed/embedding_kernel.py index fda82b0e8..f3bb60619 100644 --- a/torchrec/distributed/embedding_kernel.py +++ b/torchrec/distributed/embedding_kernel.py @@ -5,19 +5,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import abc import logging from collections import defaultdict, OrderedDict -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist +from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import ( + PartiallyMaterializedTensor, +) from torch import nn +from torch.distributed._tensor import DTensor from torchrec.distributed.embedding_types import ( + DTensorMetadata, EmbeddingComputeKernel, GroupedEmbeddingConfig, ShardedEmbeddingTable, ) +from torchrec.distributed.shards_wrapper import LocalShardsWrapper from torchrec.distributed.types import Shard, ShardedTensor, ShardedTensorMetadata from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @@ -54,6 +62,8 @@ def get_state_dict( nn.ModuleList, List[Union[nn.Module, torch.Tensor]], List[torch.Tensor], + List[Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]], + List[PartiallyMaterializedTensor], ], pg: Optional[dist.ProcessGroup] = None, destination: Optional[Dict[str, Any]] = None, @@ -70,44 +80,98 @@ def get_state_dict( """ key_to_local_shards: Dict[str, List[Shard]] = defaultdict(list) key_to_global_metadata: Dict[str, ShardedTensorMetadata] = {} + key_to_dtensor_metadata: Dict[str, DTensorMetadata] = {} + # pyre-ignore[33] + key_to_local_tensor_shards: Dict[str, List[Any]] = defaultdict(list) def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str: return prefix + f"{embedding_table.name}.weight" for embedding_table, param in zip(embedding_tables, params): key = get_key_from_embedding_table(embedding_table) - assert embedding_table.local_rows == param.size(0) - if embedding_table.compute_kernel not in [ + is_quant = embedding_table.compute_kernel in [ EmbeddingComputeKernel.QUANT, EmbeddingComputeKernel.QUANT_UVM, EmbeddingComputeKernel.QUANT_UVM_CACHING, - ]: - assert embedding_table.local_cols == param.size(1) - # for inference there is no pg, all tensors are local - if embedding_table.global_metadata is not None and pg is not None: + ] + qscale = None + qbias = None + if is_quant: + # For QUANT* param is Tuple[torch.Tensor, Optional[torch.Tensor]] where first argument is the weight table, the second is optional quantization extra information, depending on quantization type. e.g. for fbgemm rowwise quantization this is scale and shift for each row. 
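+            # Unpacked below as (weight, qscale, qbias): param[0] is the weight table, while
+            # param[1] and param[2] carry the optional quantization metadata (e.g. per-row scale
+            # and shift for fbgemm row-wise quantization) and are None when not produced.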
+ assert isinstance(param, tuple) + qscale = param[1] + qbias = param[2] + param = param[0] + + assert embedding_table.local_rows == param.size( # pyre-ignore[16] + 0 + ), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}" # pyre-ignore[16] + + if qscale is not None: + assert embedding_table.local_cols == param.size(1) # pyre-ignore[16] + + if embedding_table.dtensor_metadata is not None and pg is not None: + # DTensor path + key_to_dtensor_metadata[key] = embedding_table.dtensor_metadata + key_to_local_tensor_shards[key].append( + [ + param, + embedding_table.local_metadata.shard_offsets, # pyre-ignore[16] + ] + ) + elif embedding_table.global_metadata is not None and pg is not None: # set additional field of sharded tensor based on local tensor properties - embedding_table.global_metadata.tensor_properties.dtype = param.dtype + embedding_table.global_metadata.tensor_properties.dtype = ( + param.dtype # pyre-ignore[16] + ) embedding_table.global_metadata.tensor_properties.requires_grad = ( - param.requires_grad + param.requires_grad # pyre-ignore[16] ) key_to_global_metadata[key] = embedding_table.global_metadata key_to_local_shards[key].append( + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Union[Module, Tensor]`. + # pyre-fixme[6]: For 2nd argument expected `ShardMetadata` but got + # `Optional[ShardMetadata]`. Shard(param, embedding_table.local_metadata) ) else: destination[key] = param + if qscale is not None: + destination[f"{key}_qscale"] = qscale + if qbias is not None: + destination[f"{key}_qbias"] = qbias if pg is not None: # Populate the remaining destinations that have a global metadata for key in key_to_local_shards: global_metadata = key_to_global_metadata[key] - destination[ - key - ] = ShardedTensor._init_from_local_shards_and_global_metadata( - local_shards=key_to_local_shards[key], - sharded_tensor_metadata=global_metadata, - process_group=pg, + destination[key] = ( + ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=key_to_local_shards[key], + sharded_tensor_metadata=global_metadata, + process_group=pg, + ) + ) + # DTensor path + for key in key_to_local_tensor_shards: + dtensor_metadata = key_to_dtensor_metadata[key] + destination[key] = DTensor.from_local( + local_tensor=LocalShardsWrapper( + local_shards=[ + tensor_shards[0] + for tensor_shards in key_to_local_tensor_shards[key] + ], + local_offsets=[ + tensor_shards[1] + for tensor_shards in key_to_local_tensor_shards[key] + ], + ), + device_mesh=dtensor_metadata.mesh, + placements=dtensor_metadata.placements, + shape=torch.Size(dtensor_metadata.size), # pyre-ignore[6] + stride=dtensor_metadata.stride, + run_check=False, ) - return destination diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index 2c40f751e..f868ada1f 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import logging from abc import ABC from collections import OrderedDict @@ -12,7 +14,20 @@ import torch import torch.distributed as dist +from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( + IntNBitTableBatchedEmbeddingBagsCodegen, +) +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + SplitTableBatchedEmbeddingBagsCodegen, +) +from fbgemm_gpu.tbe.ssd.training import SSDTableBatchedEmbeddingBags +from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import ( + PartiallyMaterializedTensor, +) from torch import nn + +from torch.autograd.function import FunctionCtx +from torch.distributed._tensor import DTensor from torch.nn.modules.module import _IncompatibleKeys from torchrec.distributed.batched_embedding_kernel import ( BaseBatchedEmbedding, @@ -21,6 +36,12 @@ BatchedDenseEmbeddingBag, BatchedFusedEmbedding, BatchedFusedEmbeddingBag, + KeyValueEmbedding, + KeyValueEmbeddingBag, +) +from torchrec.distributed.comm_ops import get_gradient_division +from torchrec.distributed.composable.table_batched_embedding_slice import ( + TableBatchedEmbeddingSlice, ) from torchrec.distributed.embedding_kernel import BaseEmbedding from torchrec.distributed.embedding_types import ( @@ -28,21 +49,40 @@ BaseGroupedFeatureProcessor, EmbeddingComputeKernel, GroupedEmbeddingConfig, - SparseFeatures, - SparseFeaturesList, + InputDistOutputs, ) +from torchrec.distributed.fused_params import ( + get_tbes_to_register_from_iterable, + TBEToRegisterMixIn, +) +from torchrec.distributed.global_settings import get_propogate_device from torchrec.distributed.quant_embedding_kernel import ( QuantBatchedEmbedding, QuantBatchedEmbeddingBag, ) -from torchrec.distributed.types import ShardedTensor +from torchrec.distributed.types import rank_device, ShardedTensor, ShardingType +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor logger: logging.Logger = logging.getLogger(__name__) +@torch.fx.wrap +def fx_wrap_tensor_view2d(x: torch.Tensor, dim0: int, dim1: int) -> torch.Tensor: + return x.view(dim0, dim1) + + +@torch.fx.wrap +def dummy_tensor( + sparse_features: KeyedJaggedTensor, dtype: torch.dtype +) -> torch.Tensor: + return torch.empty([0], dtype=dtype, device=sparse_features.device()).view( + sparse_features.stride(), 0 + ) + + def _load_state_dict( emb_modules: "nn.ModuleList", - state_dict: "OrderedDict[str, Union[torch.Tensor, ShardedTensor]]", + state_dict: "OrderedDict[str, Union[torch.Tensor, ShardedTensor, DTensor]]", ) -> Tuple[List[str], List[str]]: missing_keys = [] unexpected_keys = list(state_dict.keys()) @@ -68,6 +108,23 @@ def _load_state_dict( ) dst_local_shard.tensor.detach().copy_(src_local_shard.tensor) + elif isinstance(dst_param, DTensor): + assert isinstance(src_param, DTensor) + assert len( + # pyre-ignore[16] + dst_param.to_local().local_chunks + ) == len(src_param.to_local().local_chunks) + for i, (dst_local_shard, src_local_shard) in enumerate( + zip( + dst_param.to_local().local_shards(), # pyre-ignore[16] + src_param.to_local().local_shards(), + ) + ): + assert ( + dst_param.to_local().local_chunks[i] + == src_param.to_local().local_chunks[i] + ) + dst_local_shard.detach().copy_(src_local_shard) else: assert isinstance(src_param, torch.Tensor) and isinstance( dst_param, torch.Tensor @@ -79,41 +136,60 @@ def _load_state_dict( return missing_keys, unexpected_keys -class GroupedEmbeddingsLookup(BaseEmbeddingLookup[SparseFeatures, torch.Tensor]): +@torch.fx.wrap +def embeddings_cat_empty_rank_handle( + embeddings: 
List[torch.Tensor], + dummy_embs_tensor: torch.Tensor, + dim: int = 0, +) -> torch.Tensor: + if len(embeddings) == 0: + # a hack for empty ranks + return dummy_embs_tensor + elif len(embeddings) == 1: + return embeddings[0] + else: + return torch.cat(embeddings, dim=dim) + + +@torch.fx.wrap +def embeddings_cat_empty_rank_handle_inference( + embeddings: List[torch.Tensor], + dim: int = 0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + if len(embeddings) == 0: + # return a dummy empty tensor when grouped_configs is empty + dev: Optional[torch.device] = ( + torch.device(device) if device is not None else None + ) + return torch.empty([0], dtype=dtype, device=dev) + elif len(embeddings) == 1: + return embeddings[0] + else: + return torch.cat(embeddings, dim=dim) + + +class GroupedEmbeddingsLookup(BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor]): + """ + Lookup modules for Sequence embeddings (i.e Embeddings) + """ + def __init__( self, grouped_configs: List[GroupedEmbeddingConfig], pg: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, ) -> None: - def _create_lookup( - config: GroupedEmbeddingConfig, - ) -> BaseEmbedding: - if config.compute_kernel == EmbeddingComputeKernel.DENSE: - return BatchedDenseEmbedding( - config=config, - pg=pg, - device=device, - ) - elif config.compute_kernel == EmbeddingComputeKernel.FUSED: - return BatchedFusedEmbedding( - config=config, - pg=pg, - device=device, - ) - else: - raise ValueError( - f"Compute kernel not supported {config.compute_kernel}" - ) - super().__init__() self._emb_modules: nn.ModuleList = nn.ModuleList() + self._need_prefetch: bool = False for config in grouped_configs: - self._emb_modules.append(_create_lookup(config)) + self._emb_modules.append(self._create_embedding_kernel(config, pg, device)) - self._id_list_feature_splits: List[int] = [] + self._feature_splits: List[int] = [] for config in grouped_configs: - self._id_list_feature_splits.append(config.num_features()) + self._feature_splits.append(config.num_features()) # return a dummy empty tensor when grouped_configs is empty self.register_buffer( @@ -128,25 +204,87 @@ def _create_lookup( self.grouped_configs = grouped_configs + def _create_embedding_kernel( + self, + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + ) -> BaseEmbedding: + for table in config.embedding_tables: + if ( + table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING + or table.compute_kernel == EmbeddingComputeKernel.KEY_VALUE + ): + self._need_prefetch = True + if config.compute_kernel == EmbeddingComputeKernel.DENSE: + return BatchedDenseEmbedding( + config=config, + pg=pg, + device=device, + ) + elif config.compute_kernel == EmbeddingComputeKernel.FUSED: + return BatchedFusedEmbedding( + config=config, + pg=pg, + device=device, + ) + elif config.compute_kernel in { + EmbeddingComputeKernel.KEY_VALUE, + }: + return KeyValueEmbedding( + config=config, + pg=pg, + device=device, + ) + else: + raise ValueError(f"Compute kernel not supported {config.compute_kernel}") + + def prefetch( + self, + sparse_features: KeyedJaggedTensor, + forward_stream: Optional[torch.cuda.Stream] = None, + ) -> None: + if not self._need_prefetch: + return + if len(self._emb_modules) > 0: + assert sparse_features is not None + features_by_group = sparse_features.split( + self._feature_splits, + ) + for emb_op, features in zip(self._emb_modules, features_by_group): + if ( + isinstance( + 
emb_op.emb_module, + ( + SplitTableBatchedEmbeddingBagsCodegen, + SSDTableBatchedEmbeddingBags, + ), + ) + and not emb_op.emb_module.prefetch_pipeline + ): + logging.error( + f"Invalid setting on {type(emb_op.emb_module)} modules. prefetch_pipeline must be set to True.\n" + "If you don’t turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n" + ) + if hasattr(emb_op.emb_module, "prefetch"): + emb_op.emb_module.prefetch( + indices=features.values(), + offsets=features.offsets(), + forward_stream=forward_stream, + ) + def forward( self, - sparse_features: SparseFeatures, + sparse_features: KeyedJaggedTensor, ) -> torch.Tensor: - assert sparse_features.id_list_features is not None embeddings: List[torch.Tensor] = [] - id_list_features_by_group = sparse_features.id_list_features.split( - self._id_list_feature_splits, + features_by_group = sparse_features.split( + self._feature_splits, ) - for emb_op, features in zip(self._emb_modules, id_list_features_by_group): + for emb_op, features in zip(self._emb_modules, features_by_group): embeddings.append(emb_op(features).view(-1)) - if len(embeddings) == 0: - # a hack for empty ranks - return self._dummy_embs_tensor - elif len(embeddings) == 1: - return embeddings[0] - else: - return torch.cat(embeddings) + return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor) # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently. def state_dict( @@ -195,55 +333,94 @@ def named_buffers( for emb_module in self._emb_modules: yield from emb_module.named_buffers(prefix, recurse) + def named_parameters_by_table( + self, + ) -> Iterator[Tuple[str, TableBatchedEmbeddingSlice]]: + """ + Like named_parameters(), but yields table_name and embedding_weights which are wrapped in TableBatchedEmbeddingSlice. + For a single table with multiple shards (i.e CW) these are combined into one table/weight. + Used in composability. + """ + for embedding_kernel in self._emb_modules: + for ( + table_name, + tbe_slice, + ) in embedding_kernel.named_parameters_by_table(): + yield (table_name, tbe_slice) + + def get_named_split_embedding_weights_snapshot( + self, + ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]: + """ + Return an iterator over embedding tables, yielding both the table name as well as the embedding + table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid + RocksDB snapshot to support windowed access. + """ + for emb_module in self._emb_modules: + if isinstance(emb_module, KeyValueEmbedding): + yield from emb_module.get_named_split_embedding_weights_snapshot() + + def flush(self) -> None: + for emb_module in self._emb_modules: + emb_module.flush() + + def purge(self) -> None: + for emb_module in self._emb_modules: + emb_module.purge() + + +class CommOpGradientScaling(torch.autograd.Function): + @staticmethod + # pyre-ignore + def forward( + ctx: FunctionCtx, input_tensor: torch.Tensor, scale_gradient_factor: int + ) -> torch.Tensor: + # pyre-ignore + ctx.scale_gradient_factor = scale_gradient_factor + return input_tensor + + @staticmethod + # pyre-ignore[14]: `forward` overrides method defined in `Function` inconsistently. + def backward( + ctx: FunctionCtx, grad_output: torch.Tensor + ) -> Tuple[torch.Tensor, None]: + # When gradient division is on, we scale down the gradient by world size + # at alltoall backward for model parallelism. 
However weights + # is controlled by DDP so it already has gradient division, so we scale + # the gradient back up + # pyre-ignore[16]: `FunctionCtx` has no attribute `scale_gradient_factor` + grad_output.mul_(ctx.scale_gradient_factor) + return grad_output, None + + +class GroupedPooledEmbeddingsLookup( + BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor] +): + """ + Lookup modules for Pooled embeddings (i.e EmbeddingBags) + """ -class GroupedPooledEmbeddingsLookup(BaseEmbeddingLookup[SparseFeatures, torch.Tensor]): def __init__( self, grouped_configs: List[GroupedEmbeddingConfig], - grouped_score_configs: List[GroupedEmbeddingConfig], device: Optional[torch.device] = None, pg: Optional[dist.ProcessGroup] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + scale_weight_gradients: bool = True, + sharding_type: Optional[ShardingType] = None, ) -> None: - def _create_lookup( - config: GroupedEmbeddingConfig, - device: Optional[torch.device] = None, - ) -> BaseEmbedding: - if config.compute_kernel == EmbeddingComputeKernel.DENSE: - return BatchedDenseEmbeddingBag( - config=config, - pg=pg, - device=device, - ) - elif config.compute_kernel == EmbeddingComputeKernel.FUSED: - return BatchedFusedEmbeddingBag( - config=config, - pg=pg, - device=device, - ) - else: - raise ValueError( - f"Compute kernel not supported {config.compute_kernel}" - ) - super().__init__() self._emb_modules: nn.ModuleList = nn.ModuleList() for config in grouped_configs: - self._emb_modules.append(_create_lookup(config, device)) - - self._score_emb_modules: nn.ModuleList = nn.ModuleList() - for config in grouped_score_configs: - self._score_emb_modules.append(_create_lookup(config, device)) + self._emb_modules.append( + self._create_embedding_kernel(config, device, pg, sharding_type) + ) - self._id_list_feature_splits: List[int] = [] + self._feature_splits: List[int] = [] for config in grouped_configs: - self._id_list_feature_splits.append(config.num_features()) - self._id_score_list_feature_splits: List[int] = [] - for config in grouped_score_configs: - self._id_score_list_feature_splits.append(config.num_features()) + self._feature_splits.append(config.num_features()) - # return a dummy empty tensor - # when grouped_configs and grouped_score_configs are empty + # return a dummy empty tensor when grouped_configs is empty self.register_buffer( "_dummy_embs_tensor", torch.empty( @@ -255,47 +432,122 @@ def _create_lookup( ) self.grouped_configs = grouped_configs - self.grouped_score_configs = grouped_score_configs self._feature_processor = feature_processor - def forward( - self, - sparse_features: SparseFeatures, - ) -> torch.Tensor: - assert ( - sparse_features.id_list_features is not None - or sparse_features.id_score_list_features is not None + self._world_size: int = dist.get_world_size(pg) + self._scale_gradient_factor: int = ( + self._world_size + if scale_weight_gradients and get_gradient_division() + else 1 ) - embeddings: List[torch.Tensor] = [] + + def _create_embedding_kernel( + self, + config: GroupedEmbeddingConfig, + device: Optional[torch.device], + pg: Optional[dist.ProcessGroup], + sharding_type: Optional[ShardingType], + ) -> BaseEmbedding: + if config.compute_kernel == EmbeddingComputeKernel.DENSE: + return BatchedDenseEmbeddingBag( + config=config, + pg=pg, + device=device, + sharding_type=sharding_type, + ) + elif config.compute_kernel == EmbeddingComputeKernel.FUSED: + return BatchedFusedEmbeddingBag( + config=config, + pg=pg, + device=device, + sharding_type=sharding_type, + ) 
+ elif config.compute_kernel in { + EmbeddingComputeKernel.KEY_VALUE, + }: + return KeyValueEmbeddingBag( + config=config, + pg=pg, + device=device, + sharding_type=sharding_type, + ) + else: + raise ValueError(f"Compute kernel not supported {config.compute_kernel}") + + def prefetch( + self, + sparse_features: KeyedJaggedTensor, + forward_stream: Optional[torch.cuda.Stream] = None, + ) -> None: + def _need_prefetch(config: GroupedEmbeddingConfig) -> bool: + for table in config.embedding_tables: + if ( + table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING + or table.compute_kernel == EmbeddingComputeKernel.KEY_VALUE + ): + return True + return False + if len(self._emb_modules) > 0: - assert sparse_features.id_list_features is not None - id_list_features_by_group = sparse_features.id_list_features.split( - self._id_list_feature_splits, + assert sparse_features is not None + features_by_group = sparse_features.split( + self._feature_splits, ) - for config, emb_op, features in zip( - self.grouped_configs, self._emb_modules, id_list_features_by_group - ): - # keep this to avoid break ads code using feature_processor, for ebc - # the has_feature_processor will always be false. Remove this block when - # finishing the migration + for emb_op, features in zip(self._emb_modules, features_by_group): + if not _need_prefetch(emb_op.config): + continue if ( - config.has_feature_processor - and self._feature_processor is not None - and isinstance(self._feature_processor, BaseGroupedFeatureProcessor) + isinstance( + emb_op.emb_module, + ( + SplitTableBatchedEmbeddingBagsCodegen, + SSDTableBatchedEmbeddingBags, + ), + ) + and not emb_op.emb_module.prefetch_pipeline ): - features = self._feature_processor(features) - embeddings.append(emb_op(features)) - if len(self._score_emb_modules) > 0: - assert sparse_features.id_score_list_features is not None - id_score_list_features_by_group = ( - sparse_features.id_score_list_features.split( - self._id_score_list_feature_splits, - ) + logging.error( + f"Invalid setting on {type(emb_op.emb_module)} modules. 
prefetch_pipeline must be set to True.\n" + "If you don't turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n" + ) + if hasattr(emb_op.emb_module, "prefetch"): + emb_op.emb_module.prefetch( + indices=features.values(), + offsets=features.offsets(), + forward_stream=forward_stream, + batch_size_per_feature_per_rank=( + features.stride_per_key_per_rank() + if features.variable_stride_per_key() + else None + ), + ) + + def _merge_variable_batch_embeddings( + self, embeddings: List[torch.Tensor], splits: List[List[int]] + ) -> List[torch.Tensor]: + assert len(embeddings) > 1 and len(splits) > 1 + split_embs = [e.split(s) for e, s in zip(embeddings, splits)] + combined_embs = [ + emb + for rank in range(self._world_size) + for n, embs in zip(self._feature_splits, split_embs) + for emb in embs[n * rank : n * rank + n] + ] + return [torch.cat(combined_embs)] + + def forward( + self, + sparse_features: KeyedJaggedTensor, + ) -> torch.Tensor: + embeddings: List[torch.Tensor] = [] + vbe_splits = [] + if len(self._emb_modules) > 0: + assert sparse_features is not None + features_by_group = sparse_features.split( + self._feature_splits, ) for config, emb_op, features in zip( - self.grouped_score_configs, - self._score_emb_modules, - id_score_list_features_by_group, + self.grouped_configs, self._emb_modules, features_by_group ): if ( config.has_feature_processor @@ -303,21 +555,43 @@ def forward( and isinstance(self._feature_processor, BaseGroupedFeatureProcessor) ): features = self._feature_processor(features) + + if config.is_weighted: + features._weights = CommOpGradientScaling.apply( + features._weights, self._scale_gradient_factor + ) + embeddings.append(emb_op(features)) - if len(embeddings) == 0: - # a hack for empty ranks - batch_size: int = ( - sparse_features.id_list_features.stride() - if sparse_features.id_list_features is not None - # pyre-fixme[16]: `Optional` has no attribute `stride`. - else sparse_features.id_score_list_features.stride() + if features.variable_stride_per_key() and len(self._emb_modules) > 1: + stride_per_rank_per_key = list( + zip(*features.stride_per_key_per_rank()) + ) + vbe_splits.append( + [ + stride * dim + for stride_per_rank in stride_per_rank_per_key + for stride, dim in zip( + stride_per_rank, config.embedding_dims() + ) + ] + ) + + if sparse_features.variable_stride_per_key() and len(embeddings) > 1: + embeddings = self._merge_variable_batch_embeddings(embeddings, vbe_splits) + + dummy_embedding = ( + self._dummy_embs_tensor + if sparse_features.variable_stride_per_key() + else fx_wrap_tensor_view2d( + self._dummy_embs_tensor, sparse_features.stride(), 0 ) - return self._dummy_embs_tensor.view(batch_size, 0) - elif len(embeddings) == 1: - return embeddings[0] - else: - return torch.cat(embeddings, dim=1) + ) + return embeddings_cat_empty_rank_handle( + embeddings, + dummy_embedding, + dim=1, + ) # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently. 
def state_dict( @@ -333,8 +607,6 @@ def state_dict( for emb_module in self._emb_modules: emb_module.state_dict(destination, prefix, keep_vars) - for emb_module in self._score_emb_modules: - emb_module.state_dict(destination, prefix, keep_vars) return destination @@ -345,9 +617,8 @@ def load_state_dict( state_dict: "OrderedDict[str, Union[ShardedTensor, torch.Tensor]]", strict: bool = True, ) -> _IncompatibleKeys: - m1, u1 = _load_state_dict(self._emb_modules, state_dict) - m2, u2 = _load_state_dict(self._score_emb_modules, state_dict) - return _IncompatibleKeys(missing_keys=m1 + m2, unexpected_keys=u1 + u2) + m, u = _load_state_dict(self._emb_modules, state_dict) + return _IncompatibleKeys(missing_keys=m, unexpected_keys=u) def named_parameters( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True @@ -358,8 +629,6 @@ def named_parameters( ) for emb_module in self._emb_modules: yield from emb_module.named_parameters(prefix, recurse) - for emb_module in self._score_emb_modules: - yield from emb_module.named_parameters(prefix, recurse) def named_buffers( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True @@ -370,12 +639,45 @@ def named_buffers( ) for emb_module in self._emb_modules: yield from emb_module.named_buffers(prefix, recurse) - for emb_module in self._score_emb_modules: - yield from emb_module.named_buffers(prefix, recurse) + + def named_parameters_by_table( + self, + ) -> Iterator[Tuple[str, TableBatchedEmbeddingSlice]]: + """ + Like named_parameters(), but yields table_name and embedding_weights which are wrapped in TableBatchedEmbeddingSlice. + For a single table with multiple shards (i.e CW) these are combined into one table/weight. + Used in composability. + """ + for embedding_kernel in self._emb_modules: + for ( + table_name, + tbe_slice, + ) in embedding_kernel.named_parameters_by_table(): + yield (table_name, tbe_slice) + + def get_named_split_embedding_weights_snapshot( + self, + ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]: + """ + Return an iterator over embedding tables, yielding both the table name as well as the embedding + table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid + RocksDB snapshot to support windowed access. 
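+
+        A minimal usage sketch (illustrative; "lookup" is assumed to be a
+        GroupedPooledEmbeddingsLookup whose kernels include KeyValueEmbeddingBag):
+
+            weights_by_table = {
+                table_name: pmt
+                for table_name, pmt in lookup.get_named_split_embedding_weights_snapshot()
+            }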
+ """ + for emb_module in self._emb_modules: + if isinstance(emb_module, KeyValueEmbeddingBag): + yield from emb_module.get_named_split_embedding_weights_snapshot() + + def flush(self) -> None: + for emb_module in self._emb_modules: + emb_module.flush() + + def purge(self) -> None: + for emb_module in self._emb_modules: + emb_module.purge() class MetaInferGroupedEmbeddingsLookup( - BaseEmbeddingLookup[SparseFeatures, torch.Tensor] + BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn ): """ meta embedding lookup module for inference since inference lookup has references @@ -388,58 +690,81 @@ def __init__( grouped_configs: List[GroupedEmbeddingConfig], device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, + shard_index: Optional[int] = None, ) -> None: + # TODO rename to _create_embedding_kernel def _create_lookup( config: GroupedEmbeddingConfig, device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, - ) -> BaseBatchedEmbedding: + shard_index: Optional[int] = None, + ) -> BaseBatchedEmbedding[ + Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] + ]: return QuantBatchedEmbedding( config=config, device=device, fused_params=fused_params, + shard_index=shard_index, ) super().__init__() self._emb_modules: nn.ModuleList = nn.ModuleList() for config in grouped_configs: - self._emb_modules.append(_create_lookup(config, device, fused_params)) + self._emb_modules.append( + _create_lookup(config, device, fused_params, shard_index) + ) - self._id_list_feature_splits: List[int] = [ + self._feature_splits: List[int] = [ config.num_features() for config in grouped_configs ] - # return a dummy empty tensor when grouped_configs is empty - self.register_buffer( - "_dummy_embs_tensor", - torch.empty( - [0], - dtype=torch.float32, - device=device, - ), + self.grouped_configs = grouped_configs + self.device: Optional[str] = str(device) if device is not None else None + self.output_dtype: torch.dtype = ( + fused_params["output_dtype"].as_dtype() + if fused_params and "output_dtype" in fused_params + else torch.float16 ) - self.grouped_configs = grouped_configs + def get_tbes_to_register( + self, + ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]: + return get_tbes_to_register_from_iterable(self._emb_modules) + + def embeddings_cat_empty_rank_handle_inference( + self, + embeddings: List[torch.Tensor], + dim: int = 0, + ) -> torch.Tensor: + if len(self.grouped_configs) == 0: + # return a dummy empty tensor when grouped_configs is empty + dev: Optional[torch.device] = ( + torch.device(self.device) if self.device is not None else None + ) + return torch.empty([0], dtype=self.output_dtype, device=dev) + elif len(self.grouped_configs) == 1: + return embeddings[0] + else: + return torch.cat(embeddings, dim=dim) def forward( self, - sparse_features: SparseFeatures, + sparse_features: KeyedJaggedTensor, ) -> torch.Tensor: - assert sparse_features.id_list_features is not None embeddings: List[torch.Tensor] = [] - id_list_features_by_group = sparse_features.id_list_features.split( - self._id_list_feature_splits, + features_by_group = ( + [sparse_features] + if len(self._feature_splits) == 1 + else sparse_features.split( + self._feature_splits, + ) ) - for emb_op, features in zip(self._emb_modules, id_list_features_by_group): - embeddings.append(emb_op(features).view(-1)) + for i in range(len(self._emb_modules)): + # 2d embedding by nature + 
embeddings.append(self._emb_modules[i].forward(features_by_group[i])) - if len(embeddings) == 0: - # a hack for empty ranks - return self._dummy_embs_tensor - elif len(embeddings) == 1: - return embeddings[0] - else: - return torch.cat(embeddings) + return self.embeddings_cat_empty_rank_handle_inference(embeddings) # pyre-ignore [14] def state_dict( @@ -488,9 +813,17 @@ def named_buffers( for emb_module in self._emb_modules: yield from emb_module.named_buffers(prefix, recurse) + def flush(self) -> None: + # not implemented + pass + + def purge(self) -> None: + # not implemented + pass + class MetaInferGroupedPooledEmbeddingsLookup( - BaseEmbeddingLookup[SparseFeatures, torch.Tensor] + BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn ): """ meta embedding bag lookup module for inference since inference lookup has references @@ -501,102 +834,104 @@ class MetaInferGroupedPooledEmbeddingsLookup( def __init__( self, grouped_configs: List[GroupedEmbeddingConfig], - grouped_score_configs: List[GroupedEmbeddingConfig], device: Optional[torch.device] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None, fused_params: Optional[Dict[str, Any]] = None, + shard_index: Optional[int] = None, ) -> None: + # TODO rename to _create_embedding_kernel def _create_lookup( config: GroupedEmbeddingConfig, device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, - ) -> BaseBatchedEmbeddingBag: + shard_index: Optional[int] = None, + ) -> BaseBatchedEmbeddingBag[ + Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] + ]: return QuantBatchedEmbeddingBag( config=config, device=device, fused_params=fused_params, + shard_index=shard_index, ) super().__init__() self._emb_modules: nn.ModuleList = nn.ModuleList() for config in grouped_configs: - self._emb_modules.append(_create_lookup(config, device, fused_params)) - - self._score_emb_modules: nn.ModuleList = nn.ModuleList() - for config in grouped_score_configs: - self._score_emb_modules.append(_create_lookup(config, device)) + self._emb_modules.append( + _create_lookup(config, device, fused_params, shard_index) + ) - self._id_list_feature_splits: List[int] = [ + self._feature_splits: List[int] = [ config.num_features() for config in grouped_configs ] - self._id_score_list_feature_splits: List[int] = [ - config.num_features() for config in grouped_score_configs - ] - - # return a dummy empty tensor - # when grouped_configs and grouped_score_configs are empty - self.register_buffer( - "_dummy_embs_tensor", - torch.empty( - [0], - dtype=torch.float32, - device=device, - ), - ) self.grouped_configs = grouped_configs - self.grouped_score_configs = grouped_score_configs self._feature_processor = feature_processor + self.device: Optional[str] = str(device) if device is not None else None + self.output_dtype: torch.dtype = ( + fused_params["output_dtype"].as_dtype() + if fused_params and "output_dtype" in fused_params + else torch.float16 + ) + + def get_tbes_to_register( + self, + ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]: + return get_tbes_to_register_from_iterable(self._emb_modules) + + def embeddings_cat_empty_rank_handle_inference( + self, + embeddings: List[torch.Tensor], + dim: int = 0, + ) -> torch.Tensor: + if len(self.grouped_configs) == 0: + # return a dummy empty tensor when grouped_configs is empty + dev: Optional[torch.device] = ( + torch.device(self.device) if self.device is not None else None + ) + return torch.empty([0], 
dtype=self.output_dtype, device=dev) + elif len(self.grouped_configs) == 1: + return embeddings[0] + else: + return torch.cat(embeddings, dim=dim) def forward( self, - sparse_features: SparseFeatures, + sparse_features: KeyedJaggedTensor, ) -> torch.Tensor: - assert ( - sparse_features.id_list_features is not None - or sparse_features.id_score_list_features is not None - ) - embeddings: List[torch.Tensor] = [] - if len(self._emb_modules) > 0: - assert sparse_features.id_list_features is not None - id_list_features_by_group = sparse_features.id_list_features.split( - self._id_list_feature_splits, + if len(self.grouped_configs) == 0: + # return a dummy empty tensor when grouped_configs is empty + return dummy_tensor( + sparse_features, + self.output_dtype, ) - for config, emb_op, features in zip( - self.grouped_configs, self._emb_modules, id_list_features_by_group - ): - if ( - config.has_feature_processor - and self._feature_processor is not None - and isinstance(self._feature_processor, BaseGroupedFeatureProcessor) - ): - features = self._feature_processor(features) - embeddings.append(emb_op(features)) - if len(self._score_emb_modules) > 0: - assert sparse_features.id_score_list_features is not None - id_score_list_features_by_group = ( - sparse_features.id_score_list_features.split( - self._id_score_list_feature_splits, - ) + + embeddings: List[torch.Tensor] = [] + features_by_group = ( + [sparse_features] + if len(self._feature_splits) == 1 + else sparse_features.split( + self._feature_splits, ) - for emb_op, features in zip( - self._score_emb_modules, id_score_list_features_by_group + ) + # syntax for torchscript + for i, (config, emb_op) in enumerate( + zip(self.grouped_configs, self._emb_modules) + ): + features = features_by_group[i] + if ( + config.has_feature_processor + and self._feature_processor is not None + and isinstance(self._feature_processor, BaseGroupedFeatureProcessor) ): - embeddings.append(emb_op(features)) + features = self._feature_processor(features) + embeddings.append(emb_op.forward(features)) - if len(embeddings) == 0: - # a hack for empty ranks - batch_size: int = ( - sparse_features.id_list_features.stride() - if sparse_features.id_list_features is not None - # pyre-fixme[16]: `Optional` has no attribute `stride`. 
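# --- Illustrative sketch (not part of this patch) -----------------------------
# How the grouped pooled inference lookup treats its input: one KeyedJaggedTensor
# is split by the per-group feature counts (_feature_splits), each group produces
# a pooled tensor of shape [batch, group_dim], and the results are concatenated
# along dim=1 (or the dummy empty tensor is returned when there are no groups).
# The per-group lookup below is a zero-filled stand-in for QuantBatchedEmbeddingBag;
# the feature names and group_dims are made up.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

batch_size = 2
kjt = KeyedJaggedTensor(
    keys=["f0", "f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([1, 1, 1, 1, 1, 1]),  # 3 features x 2 samples, one id each
)
feature_splits = [2, 1]  # group 0 owns f0/f1, group 1 owns f2
group_dims = [8, 4]

features_by_group = kjt.split(feature_splits) if len(feature_splits) > 1 else [kjt]
embeddings = [
    torch.zeros(batch_size, dim)  # stand-in for emb_op.forward(features)
    for _features, dim in zip(features_by_group, group_dims)
]
pooled = embeddings[0] if len(embeddings) == 1 else torch.cat(embeddings, dim=1)
assert pooled.shape == (batch_size, sum(group_dims))
# -------------------------------------------------------------------------------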
- else sparse_features.id_score_list_features.stride() - ) - return self._dummy_embs_tensor.view(batch_size, 0) - elif len(embeddings) == 1: - return embeddings[0] - else: - return torch.cat(embeddings, dim=1) + return self.embeddings_cat_empty_rank_handle_inference( + embeddings, + dim=1, + ) # pyre-ignore [14] def state_dict( @@ -612,8 +947,6 @@ def state_dict( for emb_module in self._emb_modules: emb_module.state_dict(destination, prefix, keep_vars) - for emb_module in self._score_emb_modules: - emb_module.state_dict(destination, prefix, keep_vars) return destination @@ -624,9 +957,8 @@ def load_state_dict( state_dict: "OrderedDict[str, Union[ShardedTensor, torch.Tensor]]", strict: bool = True, ) -> _IncompatibleKeys: - m1, u1 = _load_state_dict(self._emb_modules, state_dict) - m2, u2 = _load_state_dict(self._score_emb_modules, state_dict) - return _IncompatibleKeys(missing_keys=m1 + m2, unexpected_keys=u1 + u2) + m, u = _load_state_dict(self._emb_modules, state_dict) + return _IncompatibleKeys(missing_keys=m, unexpected_keys=u) def named_parameters( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True @@ -637,8 +969,6 @@ def named_parameters( ) for emb_module in self._emb_modules: yield from emb_module.named_parameters(prefix, recurse) - for emb_module in self._score_emb_modules: - yield from emb_module.named_parameters(prefix, recurse) def named_buffers( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True @@ -649,26 +979,30 @@ def named_buffers( ) for emb_module in self._emb_modules: yield from emb_module.named_buffers(prefix, recurse) - for emb_module in self._score_emb_modules: - yield from emb_module.named_buffers(prefix, recurse) + + def flush(self) -> None: + # not implemented + pass + + def purge(self) -> None: + # not implemented + pass class InferGroupedLookupMixin(ABC): def forward( self, - sparse_features: SparseFeaturesList, + input_dist_outputs: InputDistOutputs, ) -> List[torch.Tensor]: embeddings: List[torch.Tensor] = [] - for sparse_features_rank, embedding_lookup in zip( - sparse_features, + sparse_features = input_dist_outputs.features + # syntax for torchscript + for i, embedding_lookup in enumerate( # pyre-fixme[16] self._embedding_lookups_per_rank, ): - assert ( - sparse_features_rank.id_list_features is not None - or sparse_features_rank.id_score_list_features is not None - ) - embeddings.append(embedding_lookup(sparse_features_rank)) + sparse_features_rank = sparse_features[i] + embeddings.append(embedding_lookup.forward(sparse_features_rank)) return embeddings def state_dict( @@ -721,48 +1055,115 @@ def named_buffers( class InferGroupedPooledEmbeddingsLookup( InferGroupedLookupMixin, - BaseEmbeddingLookup[SparseFeaturesList, List[torch.Tensor]], + BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]], + TBEToRegisterMixIn, ): def __init__( self, grouped_configs_per_rank: List[List[GroupedEmbeddingConfig]], - grouped_score_configs_per_rank: List[List[GroupedEmbeddingConfig]], world_size: int, fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None, ) -> None: super().__init__() self._embedding_lookups_per_rank: List[ MetaInferGroupedPooledEmbeddingsLookup ] = [] + if get_propogate_device(): + device_type: str = ( + "cpu" if device is None else device.type + ) # TODO: replace hardcoded cpu with DEFAULT_DEVICE_TYPE in torchrec.distributed.types when torch package issue resolved + else: + device_type = ( + 
"meta" if device is not None and device.type == "meta" else "cuda" + ) + + self._is_empty_rank: List[bool] = [] for rank in range(world_size): - self._embedding_lookups_per_rank.append( - # TODO add position weighted module support - MetaInferGroupedPooledEmbeddingsLookup( - grouped_configs=grouped_configs_per_rank[rank], - grouped_score_configs=grouped_score_configs_per_rank[rank], - device=torch.device("cuda", rank), - fused_params=fused_params, - ) + empty_rank = len(grouped_configs_per_rank[rank]) == 0 + # Propagate shard index to get the correct runtime_device based on shard metadata + # in case of heterogenous sharding of a single table across different device types + shard_index = ( + rank if isinstance(device_type_from_sharding_infos, tuple) else None ) + self._is_empty_rank.append(empty_rank) + if not empty_rank: + self._embedding_lookups_per_rank.append( + # TODO add position weighted module support + MetaInferGroupedPooledEmbeddingsLookup( + grouped_configs=grouped_configs_per_rank[rank], + device=rank_device(device_type, rank), + fused_params=fused_params, + shard_index=shard_index, + ) + ) + + def get_tbes_to_register( + self, + ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]: + return get_tbes_to_register_from_iterable(self._embedding_lookups_per_rank) + + def forward( + self, + input_dist_outputs: InputDistOutputs, + ) -> List[torch.Tensor]: + embeddings: List[torch.Tensor] = [] + sparse_features = [ + input_dist_outputs.features[i] + for i, is_empty in enumerate(self._is_empty_rank) + if not is_empty + ] + # syntax for torchscript + for i, embedding_lookup in enumerate( + self._embedding_lookups_per_rank, + ): + sparse_features_rank = sparse_features[i] + embeddings.append(embedding_lookup.forward(sparse_features_rank)) + return embeddings class InferGroupedEmbeddingsLookup( InferGroupedLookupMixin, - BaseEmbeddingLookup[SparseFeaturesList, List[torch.Tensor]], + BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]], + TBEToRegisterMixIn, ): def __init__( self, grouped_configs_per_rank: List[List[GroupedEmbeddingConfig]], world_size: int, fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None, ) -> None: super().__init__() self._embedding_lookups_per_rank: List[MetaInferGroupedEmbeddingsLookup] = [] + + if get_propogate_device(): + device_type: str = ( + "cpu" if device is None else device.type + ) # TODO: replace hardcoded cpu with DEFAULT_DEVICE_TYPE in torchrec.distributed.types when torch package issue resolved + else: + device_type = ( + "meta" if device is not None and device.type == "meta" else "cuda" + ) for rank in range(world_size): + # propagate shard index to get the correct runtime_device based on shard metadata + # in case of heterogenous sharding of a single table acorss different device types + shard_index = ( + rank if isinstance(device_type_from_sharding_infos, tuple) else None + ) + device = rank_device(device_type, rank) self._embedding_lookups_per_rank.append( MetaInferGroupedEmbeddingsLookup( grouped_configs=grouped_configs_per_rank[rank], - device=torch.device("cuda", rank), + device=rank_device(device_type, rank), fused_params=fused_params, + shard_index=shard_index, ) ) + + def get_tbes_to_register( + self, + ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]: + return get_tbes_to_register_from_iterable(self._embedding_lookups_per_rank) diff --git 
a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py index 7f89024e7..98fa2d15f 100644 --- a/torchrec/distributed/embedding_sharding.py +++ b/torchrec/distributed/embedding_sharding.py @@ -5,161 +5,253 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import abc -from dataclasses import dataclass, field -from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar +import copy +from collections import defaultdict +from dataclasses import dataclass +from itertools import filterfalse +from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union import torch -import torch.distributed as dist -from torch import nn +from torch import distributed as dist, nn from torchrec.distributed.dist_data import ( - KJTAllToAll, - KJTAllToAllIndicesAwaitable, - KJTOneToAll, + KJTAllToAllTensorsAwaitable, + SplitsAllToAllAwaitable, +) +from torchrec.distributed.embedding_dim_bucketer import ( + EmbDimBucketer, + EmbDimBucketerPolicy, ) from torchrec.distributed.embedding_types import ( BaseEmbeddingLookup, BaseGroupedFeatureProcessor, EmbeddingComputeKernel, + FeatureShardingMixIn, GroupedEmbeddingConfig, - ListOfSparseFeaturesList, + KJTList, + ListOfKJTList, ShardedEmbeddingTable, - SparseFeatures, - SparseFeaturesList, ) from torchrec.distributed.types import ( Awaitable, - FeatureShardingMixIn, - NoWait, + EmbeddingEvent, ParameterSharding, QuantizedCommCodecs, ShardMetadata, ) -from torchrec.modules.embedding_configs import ( - DataType, - EmbeddingTableConfig, - PoolingType, -) +from torchrec.distributed.utils import maybe_annotate_embedding_event +from torchrec.fx.utils import assert_fx_safe +from torchrec.modules.embedding_configs import EmbeddingTableConfig from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.streamable import Multistreamable -class SparseFeaturesIndicesAwaitable(Awaitable[SparseFeatures]): - """ - Awaitable of sparse features redistributed with AlltoAll collective. +torch.fx.wrap("len") - Args: - id_list_features_awaitable (Optional[Awaitable[KeyedJaggedTensor]]): awaitable - of sharded id list features. - id_score_list_features_awaitable (Optional[Awaitable[KeyedJaggedTensor]]): - awaitable of sharded id score list features. - """ +CACHE_LOAD_FACTOR_STR: str = "cache_load_factor" +USE_ONE_TBE_PER_TABLE: str = "use_one_tbe_per_table" - def __init__( - self, - id_list_features_awaitable: Optional[Awaitable[KeyedJaggedTensor]], - id_score_list_features_awaitable: Optional[Awaitable[KeyedJaggedTensor]], - ) -> None: - super().__init__() - self._id_list_features_awaitable = id_list_features_awaitable - self._id_score_list_features_awaitable = id_score_list_features_awaitable - def _wait_impl(self) -> SparseFeatures: - """ - Syncs sparse features after AlltoAll. +# torch.Tensor.to can not be fx symbolic traced as it does not go through __torch_dispatch__ => fx.wrap it +@torch.fx.wrap +def _fx_wrap_tensor_to_device_dtype( + t: torch.Tensor, tensor_device_dtype: torch.Tensor +) -> torch.Tensor: + return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype) - Returns: - SparseFeatures: synced sparse features. 
- """ - return SparseFeatures( - id_list_features=self._id_list_features_awaitable.wait() - if self._id_list_features_awaitable is not None - else None, - id_score_list_features=self._id_score_list_features_awaitable.wait() - if self._id_score_list_features_awaitable is not None - else None, +@torch.fx.wrap +def _fx_wrap_optional_tensor_to_device_dtype( + t: Optional[torch.Tensor], tensor_device_dtype: torch.Tensor +) -> Optional[torch.Tensor]: + if t is None: + return None + return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype) + + +@torch.fx.wrap +def _fx_wrap_batch_size_per_feature(kjt: KeyedJaggedTensor) -> Optional[torch.Tensor]: + return ( + torch.tensor( + kjt.stride_per_key(), device=kjt.device(), dtype=kjt.lengths().dtype ) + if kjt.variable_stride_per_key() + else None + ) -class SparseFeaturesLengthsAwaitable(Awaitable[SparseFeaturesIndicesAwaitable]): - """ - Awaitable of sparse features indices distribution. +@torch.fx.wrap +def _fx_wrap_max_B(kjt: KeyedJaggedTensor) -> int: + return max(kjt.stride_per_key()) if kjt.variable_stride_per_key() else -1 - Args: - id_list_features_awaitable (Optional[Awaitable[KJTAllToAllIndicesAwaitable]]): - awaitable of sharded id list features indices AlltoAll. Waiting on this - value will trigger indices AlltoAll (waiting again will yield final AlltoAll - results). - id_score_list_features_awaitable - (Optional[Awaitable[KJTAllToAllIndicesAwaitable]]): - awaitable of sharded id score list features indices AlltoAll. Waiting on - this value will trigger indices AlltoAll (waiting again will yield the final - AlltoAll results). - """ - def __init__( - self, - id_list_features_awaitable: Optional[Awaitable[KJTAllToAllIndicesAwaitable]], - id_score_list_features_awaitable: Optional[ - Awaitable[KJTAllToAllIndicesAwaitable] - ], - ) -> None: - super().__init__() - self._id_list_features_awaitable = id_list_features_awaitable - self._id_score_list_features_awaitable = id_score_list_features_awaitable +@torch.fx.wrap +def _fx_wrap_stride(kjt: KeyedJaggedTensor) -> Optional[int]: + return None if kjt.variable_stride_per_key() else kjt.stride() - def _wait_impl(self) -> SparseFeaturesIndicesAwaitable: - """ - Gets lengths of AlltoAll results, instantiates `SparseFeaturesIndicesAwaitable` for - indices AlltoAll. - Returns: - SparseFeaturesIndicesAwaitable. 
- """ - return SparseFeaturesIndicesAwaitable( - id_list_features_awaitable=self._id_list_features_awaitable.wait() - if self._id_list_features_awaitable is not None - else None, - id_score_list_features_awaitable=self._id_score_list_features_awaitable.wait() - if self._id_score_list_features_awaitable is not None - else None, - ) +@torch.fx.wrap +def _fx_wrap_stride_per_key_per_rank( + kjt: KeyedJaggedTensor, num_buckets: int +) -> Optional[List[List[int]]]: + return ( + kjt.stride_per_key_per_rank() * num_buckets + if kjt.variable_stride_per_key() + else None + ) + + +@torch.fx.wrap +def _fx_wrap_gen_list_n_times(ls: List[str], n: int) -> List[str]: + # Syntax for dynamo (instead of generator kjt.keys() * num_buckets) + ret: List[str] = [] + for _ in range(n): + ret.extend(ls) + return ret + + +@torch.fx.wrap +def _fx_wrap_gen_keys(ls: List[str], n: int) -> List[str]: + # Syntax for dynamo (instead of generator kjt.keys() * num_buckets) + return ls * n + + +@torch.fx.wrap +def _fx_wrap_opt_to_nonopt_tensor(optional: Optional[torch.Tensor]) -> torch.Tensor: + assert optional is not None, "Expected optional to be non-None Tensor" + return optional + + +@torch.fx.wrap +def _fx_wrap_seq_block_bucketize_sparse_features_inference( + kjt: KeyedJaggedTensor, + num_buckets: int, + block_sizes: torch.Tensor, + bucketize_pos: bool = False, + block_bucketize_pos: Optional[List[torch.Tensor]] = None, + total_num_blocks: Optional[torch.Tensor] = None, + keep_original_indices: bool = False, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + torch.Tensor, +]: + ( + bucketized_lengths, + bucketized_indices, + bucketized_weights, + pos, + unbucketize_permute, + bucket_mapping, + ) = torch.ops.fbgemm.block_bucketize_sparse_features_inference( + kjt.lengths().view(-1), + kjt.values(), + bucketize_pos=bucketize_pos, + sequence=True, + block_sizes=block_sizes, + total_num_blocks=total_num_blocks, + my_size=num_buckets, + weights=kjt.weights_or_none(), + max_B=_fx_wrap_max_B(kjt), + block_bucketize_pos=block_bucketize_pos, + return_bucket_mapping=True, + keep_orig_idx=keep_original_indices, + ) + + return ( + bucketized_lengths, + bucketized_indices, + bucketized_weights, + pos, + _fx_wrap_opt_to_nonopt_tensor(unbucketize_permute), + _fx_wrap_opt_to_nonopt_tensor(bucket_mapping), + ) + + +@torch.fx.wrap +def _fx_wrap_none_seq_block_bucketize_sparse_features_inference( + kjt: KeyedJaggedTensor, + num_buckets: int, + block_sizes: torch.Tensor, + bucketize_pos: bool = False, + block_bucketize_pos: Optional[List[torch.Tensor]] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], +]: + ( + bucketized_lengths, + bucketized_indices, + bucketized_weights, + pos, + _, + _, + ) = torch.ops.fbgemm.block_bucketize_sparse_features_inference( + kjt.lengths().view(-1), + kjt.values(), + bucketize_pos=bucketize_pos, + sequence=False, + block_sizes=block_sizes, + my_size=num_buckets, + weights=kjt.weights_or_none(), + max_B=_fx_wrap_max_B(kjt), + block_bucketize_pos=block_bucketize_pos, + return_bucket_mapping=False, + ) + + return ( + bucketized_lengths, + bucketized_indices, + bucketized_weights, + pos, + ) def bucketize_kjt_before_all2all( kjt: KeyedJaggedTensor, num_buckets: int, block_sizes: torch.Tensor, + total_num_blocks: Optional[torch.Tensor] = None, output_permute: bool = False, bucketize_pos: bool = False, + block_bucketize_row_pos: Optional[List[torch.Tensor]] = None, + keep_original_indices: bool = 
False, ) -> Tuple[KeyedJaggedTensor, Optional[torch.Tensor]]: """ Bucketizes the `values` in KeyedJaggedTensor into `num_buckets` buckets, `lengths` are readjusted based on the bucketization results. Note: This function should be used only for row-wise sharding before calling - `SparseFeaturesAllToAll`. + `KJTAllToAll`. Args: num_buckets (int): number of buckets to bucketize the values into. block_sizes: (torch.Tensor): bucket sizes for the keyed dimension. + total_num_blocks: (Optional[torch.Tensor]): number of blocks per feature, useful for two-level bucketization output_permute (bool): output the memory location mapping from the unbucketized values to bucketized values or not. bucketize_pos (bool): output the changed position of the bucketized values or not. + block_bucketize_row_pos (Optional[List[torch.Tensor]]): The offsets of shard size for each feature. + keep_original_indices (bool): whether to keep the original indices or not. Returns: Tuple[KeyedJaggedTensor, Optional[torch.Tensor]]: the bucketized `KeyedJaggedTensor` and the optional permute mapping from the unbucketized values to bucketized value. """ num_features = len(kjt.keys()) - assert ( - block_sizes.numel() == num_features - ), f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received." + assert_fx_safe( + block_sizes.numel() == num_features, + f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.", + ) - # kernel expects them to be same type, cast to avoid type mismatch - block_sizes_new_type = block_sizes.type(kjt.values().type()) ( bucketized_lengths, bucketized_indices, @@ -171,20 +263,37 @@ def bucketize_kjt_before_all2all( kjt.values(), bucketize_pos=bucketize_pos, sequence=output_permute, - block_sizes=block_sizes_new_type, + block_sizes=_fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values()), + total_num_blocks=( + _fx_wrap_tensor_to_device_dtype(total_num_blocks, kjt.values()) + if total_num_blocks is not None + else None + ), my_size=num_buckets, weights=kjt.weights_or_none(), + batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt), + max_B=_fx_wrap_max_B(kjt), + block_bucketize_pos=( + [ + _fx_wrap_tensor_to_device_dtype(pos, kjt.values()) + for pos in block_bucketize_row_pos + ] + if block_bucketize_row_pos is not None + else None + ), + keep_orig_idx=keep_original_indices, ) return ( KeyedJaggedTensor( # duplicate keys will be resolved by AllToAll - keys=kjt.keys() * num_buckets, + keys=_fx_wrap_gen_list_n_times(kjt.keys(), num_buckets), values=bucketized_indices, weights=pos if bucketize_pos else bucketized_weights, lengths=bucketized_lengths.view(-1), offsets=None, - stride=kjt.stride(), + stride=_fx_wrap_stride(kjt), + stride_per_key_per_rank=_fx_wrap_stride_per_key_per_rank(kjt, num_buckets), length_per_key=None, offset_per_key=None, index_per_key=None, @@ -193,442 +302,635 @@ def bucketize_kjt_before_all2all( ) -class SparseFeaturesAllToAll(nn.Module): +def bucketize_kjt_inference( + kjt: KeyedJaggedTensor, + num_buckets: int, + block_sizes: torch.Tensor, + total_num_buckets: Optional[torch.Tensor] = None, + bucketize_pos: bool = False, + block_bucketize_row_pos: Optional[List[torch.Tensor]] = None, + is_sequence: bool = False, + keep_original_indices: bool = False, +) -> Tuple[KeyedJaggedTensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ - Redistributes sparse features to a `ProcessGroup` utilizing an AlltoAll collective. - - Args: - pg (dist.ProcessGroup): process group for AlltoAll communication. 
- id_list_features_per_rank (List[int]): number of id list features to send to - each rank. - id_score_list_features_per_rank (List[int]): number of id score list features to - send to each rank - device (Optional[torch.device]): device on which buffers will be allocated. - stagger (int): stagger value to apply to recat tensor, see `_recat` function for - more detail. - variable_batch_size (bool): variable batch size in each rank. - - Example:: - - id_list_features_per_rank = [2, 1] - id_score_list_features_per_rank = [1, 3] - sfa2a = SparseFeaturesAllToAll( - pg, - id_list_features_per_rank, - id_score_list_features_per_rank - ) - awaitable = sfa2a(rank0_input: SparseFeatures) - - # where: - # rank0_input.id_list_features is KeyedJaggedTensor holding - - # 0 1 2 - # 'A' [A.V0] None [A.V1, A.V2] - # 'B' None [B.V0] [B.V1] - # 'C' [C.V0] [C.V1] None - - # rank1_input.id_list_features is KeyedJaggedTensor holding - - # 0 1 2 - # 'A' [A.V3] [A.V4] None - # 'B' None [B.V2] [B.V3, B.V4] - # 'C' [C.V2] [C.V3] None - - # rank0_input.id_score_list_features is KeyedJaggedTensor holding + Bucketizes the `values` in KeyedJaggedTensor into `num_buckets` buckets, + `lengths` are readjusted based on the bucketization results. - # 0 1 2 - # 'A' [A.V0] None [A.V1, A.V2] - # 'B' None [B.V0] [B.V1] - # 'C' [C.V0] [C.V1] None - # 'D' None [D.V0] None + Note: This function should be used only for row-wise sharding before calling + `KJTAllToAll`. - # rank1_input.id_score_list_features is KeyedJaggedTensor holding + Args: + num_buckets (int): number of buckets to bucketize the values into. + block_sizes: (torch.Tensor): bucket sizes for the keyed dimension. + total_num_blocks: (Optional[torch.Tensor]): number of blocks per feature, useful for two-level bucketization + bucketize_pos (bool): output the changed position of the bucketized values or + not. + block_bucketize_row_pos (Optional[List[torch.Tensor]]): The offsets of shard size for each feature. + is_sequence (bool): whether the input is a sequence feature or not. - # 0 1 2 - # 'A' [A.V3] [A.V4] None - # 'B' None [B.V2] [B.V3, B.V4] - # 'C' [C.V2] [C.V3] None - # 'D' [D.V1] [D.V2] [D.V3, D.V4] + Returns: + Tuple[KeyedJaggedTensor, Optional[torch.Tensor]]: the bucketized `KeyedJaggedTensor` and the optional permute mapping from the unbucketized values to bucketized value. 
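# --- Illustrative sketch (not part of this patch) -----------------------------
# What block bucketization does conceptually for row-wise sharding: each id goes
# to bucket id // block_size, lengths are recomputed per bucket, and (unless
# keep_original_indices / keep_orig_idx is set) ids are remapped to bucket-local
# offsets. The real kernel is
# torch.ops.fbgemm.block_bucketize_sparse_features_inference; this is a plain
# torch walkthrough for one feature of one sample, with made-up numbers.
import torch

values = torch.tensor([0, 5, 12, 7, 19])  # global row ids
block_size = 10                           # rows owned by each bucket/rank
num_buckets = 2

bucket_of = values // block_size                                   # [0, 0, 1, 0, 1]
bucketized_values = torch.cat([values[bucket_of == b] for b in range(num_buckets)])
bucketized_lengths = torch.bincount(bucket_of, minlength=num_buckets)
local_values = bucketized_values % block_size                      # bucket-local ids

print(bucketized_values.tolist())   # [0, 5, 7, 12, 19]
print(bucketized_lengths.tolist())  # [3, 2] -> 3 ids to bucket 0, 2 to bucket 1
print(local_values.tolist())        # [0, 5, 7, 2, 9]
# -------------------------------------------------------------------------------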
+ """ - rank0_output: SparseFeatures = awaitable.wait() + num_features = len(kjt.keys()) + assert_fx_safe( + block_sizes.numel() == num_features, + f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.", + ) + block_sizes_new_type = _fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values()) + total_num_buckets_new_type = _fx_wrap_optional_tensor_to_device_dtype( + total_num_buckets, kjt.values() + ) + unbucketize_permute = None + bucket_mapping = None + if is_sequence: + ( + bucketized_lengths, + bucketized_indices, + bucketized_weights, + pos, + unbucketize_permute, + bucket_mapping, + ) = _fx_wrap_seq_block_bucketize_sparse_features_inference( + kjt, + num_buckets=num_buckets, + block_sizes=block_sizes_new_type, + total_num_blocks=total_num_buckets_new_type, + bucketize_pos=bucketize_pos, + block_bucketize_pos=block_bucketize_row_pos, + keep_original_indices=keep_original_indices, + ) + else: + ( + bucketized_lengths, + bucketized_indices, + bucketized_weights, + pos, + ) = _fx_wrap_none_seq_block_bucketize_sparse_features_inference( + kjt, + num_buckets=num_buckets, + block_sizes=block_sizes_new_type, + bucketize_pos=bucketize_pos, + block_bucketize_pos=block_bucketize_row_pos, + ) - # rank0_output.id_list_features is KeyedJaggedTensor holding + return ( + KeyedJaggedTensor( + keys=_fx_wrap_gen_keys(kjt.keys(), num_buckets), + values=bucketized_indices, + weights=pos if bucketize_pos else bucketized_weights, + lengths=bucketized_lengths.view(-1), + ), + unbucketize_permute, + bucket_mapping, + ) - # 0 1 2 3 4 5 - # 'A' [A.V0] None [A.V1, A.V2] [A.V3] [A.V4] None - # 'B' None [B.V0] [B.V1] None [B.V2] [B.V3, B.V4] - # rank1_output.id_list_features is KeyedJaggedTensor holding - # 0 1 2 3 4 5 - # 'C' [C.V0] [C.V1] None [C.V2] [C.V3] None +def _get_weighted_avg_cache_load_factor( + embedding_tables: List[ShardedEmbeddingTable], +) -> Optional[float]: + """ + Calculate the weighted average cache load factor of all tables. The cache + load factors are weighted by the hash size of each table. + """ + cache_load_factor_sum: float = 0.0 + weight: int = 0 + + for table in embedding_tables: + if ( + table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING + and table.fused_params + and CACHE_LOAD_FACTOR_STR in table.fused_params + ): + cache_load_factor_sum += ( + table.fused_params[CACHE_LOAD_FACTOR_STR] * table.num_embeddings + ) + weight += table.num_embeddings - # rank0_output.id_score_list_features is KeyedJaggedTensor holding + # if no fused_uvm_caching tables, return default cache load factor + if weight == 0: + return None - # 0 1 2 3 4 5 - # 'A' [A.V0] None [A.V1, A.V2] [A.V3] [A.V4] None + return cache_load_factor_sum / weight - # rank1_output.id_score_list_features is KeyedJaggedTensor holding - # 0 1 2 3 4 5 - # 'B' None [B.V0] [B.V1] None [B.V2] [B.V3, B.V4] - # 'C' [C.V0] [C.V1] None [C.V2] [C.V3] None - # 'D None [D.V0] None [D.V1] [D.V2] [D.V3, D.V4] +def _get_grouping_fused_params( + fused_params: Optional[Dict[str, Any]], + name: str, +) -> Optional[Dict[str, Any]]: """ + Only shallow copy the fused params we need for grouping tables into TBEs. In + particular, we do not copy cache_load_factor. 
+ """ + grouping_fused_params: Optional[Dict[str, Any]] = copy.copy(fused_params) - def __init__( - self, - pg: dist.ProcessGroup, - id_list_features_per_rank: List[int], - id_score_list_features_per_rank: List[int], - device: Optional[torch.device] = None, - stagger: int = 1, - variable_batch_size: bool = False, - ) -> None: - super().__init__() - self._id_list_features_all2all: KJTAllToAll = KJTAllToAll( - pg=pg, - splits=id_list_features_per_rank, - device=device, - stagger=stagger, - variable_batch_size=variable_batch_size, - ) - self._id_score_list_features_all2all: KJTAllToAll = KJTAllToAll( - pg=pg, - splits=id_score_list_features_per_rank, - device=device, - stagger=stagger, - variable_batch_size=variable_batch_size, - ) - - def forward( - self, - sparse_features: SparseFeatures, - ) -> Awaitable[SparseFeaturesIndicesAwaitable]: - """ - Sends sparse features to relevant ProcessGroup ranks. Instantiates lengths - AlltoAll. - First wait will get lengths AlltoAll results, then issues indices AlltoAll. - Second wait will get indices AlltoAll results. + if not grouping_fused_params: + return grouping_fused_params - Args: - sparse_features (SparseFeatures): sparse features to redistribute. + if CACHE_LOAD_FACTOR_STR in grouping_fused_params: + del grouping_fused_params[CACHE_LOAD_FACTOR_STR] - Returns: - Awaitable[SparseFeatures]: awaitable of SparseFeatures. - """ + if grouping_fused_params.get(USE_ONE_TBE_PER_TABLE, False): + # Replace with unique value to force it into singleton group. + # Name is used as unique value so we won't group multiple shard belonging + # to the same embedding table separately. + grouping_fused_params[USE_ONE_TBE_PER_TABLE] = name - return SparseFeaturesLengthsAwaitable( - id_list_features_awaitable=self._id_list_features_all2all.forward( - sparse_features.id_list_features - ) - if sparse_features.id_list_features is not None - else None, - id_score_list_features_awaitable=self._id_score_list_features_all2all.forward( - sparse_features.id_score_list_features - ) - if sparse_features.id_score_list_features is not None - else None, - ) + return grouping_fused_params -class SparseFeaturesOneToAll(nn.Module): +def _get_compute_kernel_type( + compute_kernel: EmbeddingComputeKernel, +) -> EmbeddingComputeKernel: """ - Redistributes sparse features to all devices. - - Args: - id_list_features_per_rank (List[int]): number of id list features to send to - each rank. - id_score_list_features_per_rank (List[int]): number of id score list features to - send to each rank - world_size (int): number of devices in the topology. + Return the compute kernel type for the given compute kernel. """ + compute_kernel_type = compute_kernel + if compute_kernel_type in [ + EmbeddingComputeKernel.FUSED_UVM, + EmbeddingComputeKernel.FUSED_UVM_CACHING, + ]: + compute_kernel_type = EmbeddingComputeKernel.FUSED + elif compute_kernel_type in [ + EmbeddingComputeKernel.QUANT_UVM, + EmbeddingComputeKernel.QUANT_UVM_CACHING, + ]: + compute_kernel_type = EmbeddingComputeKernel.QUANT + return compute_kernel_type + + +def _prefetch_and_cached( + table: ShardedEmbeddingTable, +) -> bool: + """ + Return if this embedding use hbm as cache. In this case we might want to use + bucketizer to group by dimension for memory efficiency. 
+ """ + if table.compute_kernel in { + EmbeddingComputeKernel.KEY_VALUE, + }: + return True - def __init__( - self, - id_list_features_per_rank: List[int], - id_score_list_features_per_rank: List[int], - world_size: int, - ) -> None: - super().__init__() - self._world_size = world_size - self._id_list_features_one2all: KJTOneToAll = KJTOneToAll( - id_list_features_per_rank, - world_size, - ) - self._id_score_list_features_one2all: KJTOneToAll = KJTOneToAll( - id_score_list_features_per_rank, world_size - ) - - def forward( - self, - sparse_features: SparseFeatures, - ) -> Awaitable[SparseFeaturesList]: - """ - Performs OnetoAll operation on sparse features. - - Args: - sparse_features (SparseFeatures): sparse features to redistribute. + return ( + table.compute_kernel + in [ + EmbeddingComputeKernel.FUSED_UVM_CACHING, + EmbeddingComputeKernel.QUANT_UVM_CACHING, + ] + and table.fused_params is not None + and "prefetch_pipeline" in table.fused_params + and table.fused_params["prefetch_pipeline"] + ) - Returns: - Awaitable[SparseFeatures]: awaitable of SparseFeatures. - """ - return NoWait( - SparseFeaturesList( - [ - SparseFeatures( - id_list_features=id_list_features, - id_score_list_features=id_score_list_features, - ) - for id_list_features, id_score_list_features in zip( - self._id_list_features_one2all.forward( - sparse_features.id_list_features - ).wait() - if sparse_features.id_list_features is not None - else [None] * self._world_size, - self._id_score_list_features_one2all.forward( - sparse_features.id_score_list_features - ).wait() - if sparse_features.id_score_list_features is not None - else [None] * self._world_size, - ) - ] - ) - ) +def _all_tables_are_quant_kernel( + tables: List[ShardedEmbeddingTable], +) -> bool: + """ + Return if all tables have quant compute kernel. + """ + return all(table.compute_kernel == EmbeddingComputeKernel.QUANT for table in tables) -# group tables by DataType, PoolingType, Weighted, and EmbeddingComputeKernel. +# group tables by `DataType`, `PoolingType`, and `EmbeddingComputeKernel`. def group_tables( tables_per_rank: List[List[ShardedEmbeddingTable]], -) -> Tuple[List[List[GroupedEmbeddingConfig]], List[List[GroupedEmbeddingConfig]]]: +) -> List[List[GroupedEmbeddingConfig]]: """ - Groups tables by `DataType`, `PoolingType`, `Weighted`, and `EmbeddingComputeKernel`. + Groups tables by `DataType`, `PoolingType`, and `EmbeddingComputeKernel`. Args: - tables_per_rank (List[List[ShardedEmbeddingTable]]): list of sharding embedding - tables per rank. + tables_per_rank (List[List[ShardedEmbeddingTable]]): list of sharded embedding + tables per rank with consistent weightedness. Returns: - Tuple[List[List[GroupedEmbeddingConfig]], List[List[GroupedEmbeddingConfig]]]: per rank list of GroupedEmbeddingConfig for unscored and scored features. + List[List[GroupedEmbeddingConfig]]: per rank list of GroupedEmbeddingConfig for features. """ def _group_tables_per_rank( embedding_tables: List[ShardedEmbeddingTable], - ) -> Tuple[List[GroupedEmbeddingConfig], List[GroupedEmbeddingConfig]]: + ) -> List[GroupedEmbeddingConfig]: grouped_embedding_configs: List[GroupedEmbeddingConfig] = [] - score_grouped_embedding_configs: List[GroupedEmbeddingConfig] = [] - # add fused params: - fused_params_groups = [] + # We use different dim-bucketing policy for different cases. 
+ # If prefetch is off, all table (regardless of cache status or dimension) will be grouped together (SINGLE_BUCKET) + # If prefetch is on, + # Cached vs noncached tables will be separated, even if they have the same dimension + # For two cached tables, if they have different dimension they shall be separated, otherwise they'll be grouped (ALL_BUCKETS) + # For two noncached tables, they'll be grouped regardless of dimension (SINGLE_BUCKET) + prefetch_cached_dim_bucketer = EmbDimBucketer( + list(filter(_prefetch_and_cached, embedding_tables)), + EmbDimBucketerPolicy.ALL_BUCKETS, + ) + non_prefetch_cached_dim_bucketer = EmbDimBucketer( + list(filterfalse(_prefetch_and_cached, embedding_tables)), + EmbDimBucketerPolicy.SINGLE_BUCKET, + ) + + # all embedding tables have the same weight status + is_weighted = ( + embedding_tables[0].is_weighted if len(embedding_tables) > 0 else False + ) + + # Collect groups + groups = defaultdict(list) + grouping_keys = [] + # Assumes all compute kernels within tables are the same + is_inference = _all_tables_are_quant_kernel(embedding_tables) for table in embedding_tables: - if table.fused_params is None: - table.fused_params = {} - if table.fused_params not in fused_params_groups: - fused_params_groups.append(table.fused_params) - - compute_kernels = [ - EmbeddingComputeKernel.DENSE, - EmbeddingComputeKernel.FUSED, - EmbeddingComputeKernel.QUANT, - ] + bucketer = ( + prefetch_cached_dim_bucketer + if _prefetch_and_cached(table) + else non_prefetch_cached_dim_bucketer + ) + group_fused_params = ( + _get_grouping_fused_params(table.fused_params, table.name) or {} + ) + grouping_key = ( + table.data_type if not is_inference else None, + table.pooling, + table.has_feature_processor, + tuple(sorted(group_fused_params.items())), + _get_compute_kernel_type(table.compute_kernel), + # TODO: Unit test to check if table.data_type affects table grouping + bucketer.get_bucket( + table.local_cols, + table.data_type, + ), + _prefetch_and_cached(table), + ) + # micromanage the order of we traverse the groups to ensure backwards compatibility + if grouping_key not in groups: + grouping_keys.append(grouping_key) + groups[grouping_key].append(table) + + for grouping_key in grouping_keys: + ( + data_type, + pooling, + has_feature_processor, + fused_params_tuple, + compute_kernel_type, + _, + _, + ) = grouping_key + grouped_tables = groups[grouping_key] + # remove non-native fused params + per_tbe_fused_params = { + k: v + for k, v in fused_params_tuple + if k not in ["_batch_key", USE_ONE_TBE_PER_TABLE] + } + cache_load_factor = _get_weighted_avg_cache_load_factor(grouped_tables) + if cache_load_factor is not None: + per_tbe_fused_params[CACHE_LOAD_FACTOR_STR] = cache_load_factor + + grouped_embedding_configs.append( + GroupedEmbeddingConfig( + data_type=data_type, + pooling=pooling, + is_weighted=is_weighted, + has_feature_processor=has_feature_processor, + compute_kernel=compute_kernel_type, + embedding_tables=grouped_tables, + fused_params=per_tbe_fused_params, + ) + ) + return grouped_embedding_configs - for data_type in DataType: - for pooling in PoolingType: - for is_weighted in [True, False]: - # remove this when finishing migration - for has_feature_processor in [False, True]: - for fused_params_group in fused_params_groups: - for compute_kernel in compute_kernels: - grouped_tables: List[ShardedEmbeddingTable] = [] - grouped_score_tables: List[ShardedEmbeddingTable] = [] - for table in embedding_tables: - compute_kernel_type = table.compute_kernel - if 
table.compute_kernel in [ - EmbeddingComputeKernel.FUSED_UVM, - EmbeddingComputeKernel.FUSED_UVM_CACHING, - ]: - compute_kernel_type = ( - EmbeddingComputeKernel.FUSED - ) - elif table.compute_kernel in [ - EmbeddingComputeKernel.QUANT_UVM, - EmbeddingComputeKernel.QUANT_UVM_CACHING, - ]: - compute_kernel_type = ( - EmbeddingComputeKernel.QUANT - ) - if ( - table.data_type == data_type - and table.pooling == pooling - and table.is_weighted == is_weighted - and table.has_feature_processor - == has_feature_processor - and compute_kernel_type == compute_kernel - and table.fused_params == fused_params_group - ): - if table.is_weighted: - grouped_score_tables.append(table) - else: - grouped_tables.append(table) - - if fused_params_group is None: - fused_params_group = {} - - if grouped_tables: - grouped_embedding_configs.append( - GroupedEmbeddingConfig( - data_type=data_type, - pooling=pooling, - is_weighted=is_weighted, - has_feature_processor=has_feature_processor, - compute_kernel=compute_kernel, - embedding_tables=grouped_tables, - fused_params={ - k: v - for k, v in fused_params_group.items() - if k - not in [ - "_batch_key" - ] # drop '_batch_key' not a native fused param - }, - ) - ) - if grouped_score_tables: - score_grouped_embedding_configs.append( - GroupedEmbeddingConfig( - data_type=data_type, - pooling=pooling, - is_weighted=is_weighted, - has_feature_processor=has_feature_processor, - compute_kernel=compute_kernel, - embedding_tables=grouped_score_tables, - fused_params={ - k: v - for k, v in fused_params_group.items() - if k - not in [ - "_batch_key" - ] # drop '_batch_key', not a native fused param - }, - ) - ) - return grouped_embedding_configs, score_grouped_embedding_configs + table_weightedness = [ + table.is_weighted for tables in tables_per_rank for table in tables + ] + assert all(table_weightedness) or not any(table_weightedness) grouped_embedding_configs_by_rank: List[List[GroupedEmbeddingConfig]] = [] - score_grouped_embedding_configs_by_rank: List[List[GroupedEmbeddingConfig]] = [] for tables in tables_per_rank: - ( - grouped_embedding_configs, - score_grouped_embedding_configs, - ) = _group_tables_per_rank(tables) + grouped_embedding_configs = _group_tables_per_rank(tables) grouped_embedding_configs_by_rank.append(grouped_embedding_configs) - score_grouped_embedding_configs_by_rank.append(score_grouped_embedding_configs) - return ( - grouped_embedding_configs_by_rank, - score_grouped_embedding_configs_by_rank, - ) + return grouped_embedding_configs_by_rank + + +C = TypeVar("C", bound=Multistreamable) +T = TypeVar("T") -class SparseFeaturesListAwaitable(Awaitable[SparseFeaturesList]): + +class KJTListAwaitable(Awaitable[KJTList]): """ - Awaitable of SparseFeaturesList. + Awaitable of KJTList. Args: - awaitables (List[Awaitable[SparseFeatures]]): list of `Awaitable` of sparse + awaitables (List[Awaitable[KeyedJaggedTensor]]): list of `Awaitable` of sparse features. + ctx (C): sharding context to save the batch size info from the KJT for the + embedding AlltoAll. """ def __init__( self, - awaitables: List[Awaitable[SparseFeatures]], + awaitables: List[Awaitable[KeyedJaggedTensor]], + ctx: C, ) -> None: super().__init__() self.awaitables = awaitables + self.ctx = ctx - def _wait_impl(self) -> SparseFeaturesList: + def _wait_impl(self) -> KJTList: """ - Syncs sparse features in `SparseFeaturesList`. + Syncs KJTs in `KJTList`. Returns: - SparseFeaturesList: synced `SparseFeaturesList`. + KJTList: synced `KJTList`. 
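# --- Illustrative sketch (not part of this patch) -----------------------------
# The grouping step above in a nutshell: tables that agree on the grouping key
# (data type, pooling, feature processor, grouping fused params, normalized
# compute kernel, dim bucket, prefetch/cached flag) land in the same
# GroupedEmbeddingConfig, i.e. the same TBE. Stand-in records are plain dicts;
# the real code uses ShardedEmbeddingTable and EmbDimBucketer.
from collections import defaultdict

tables = [
    {"name": "t0", "data_type": "FP16", "pooling": "SUM", "kernel": "FUSED", "dim_bucket": 0},
    {"name": "t1", "data_type": "FP16", "pooling": "SUM", "kernel": "FUSED", "dim_bucket": 0},
    {"name": "t2", "data_type": "INT8", "pooling": "SUM", "kernel": "QUANT", "dim_bucket": 0},
]

groups = defaultdict(list)
grouping_keys = []  # first-seen order is preserved, as in the real code
for t in tables:
    key = (t["data_type"], t["pooling"], t["kernel"], t["dim_bucket"])
    if key not in groups:
        grouping_keys.append(key)
    groups[key].append(t["name"])

print([groups[k] for k in grouping_keys])  # [['t0', 't1'], ['t2']]
# -------------------------------------------------------------------------------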
""" - return SparseFeaturesList([w.wait() for w in self.awaitables]) + # Syntax: no list comprehension usage for dynamo + kjts = [] + for w in self.awaitables: + kjts.append(w.wait()) + + _set_sharding_context_post_a2a(kjts, self.ctx) + return KJTList(kjts) + + +def _set_sharding_context_post_a2a( + kjts: List[KeyedJaggedTensor], + ctx: C, +) -> None: + for kjt, sharding_context in zip(kjts, getattr(ctx, "sharding_contexts", [])): + if ( + hasattr(sharding_context, "batch_size_per_rank_per_feature") + and kjt.variable_stride_per_key() + and kjt.stride_per_key_per_rank() + ): + strides = kjt.stride_per_key_per_rank() + sharding_context.batch_size_per_rank_per_feature = [ + [strides[i][j] for i in range(len(strides))] + for j in range(len(strides[0])) + ] + + +def _set_sharding_context_intra_a2a( + tensors_awaitables: List[Awaitable[KeyedJaggedTensor]], + ctx: C, +) -> None: + for awaitable, sharding_context in zip( + tensors_awaitables, + getattr(ctx, "sharding_contexts", []), + ): + if isinstance(awaitable, KJTAllToAllTensorsAwaitable): + if hasattr(sharding_context, "input_splits"): + sharding_context.input_splits = awaitable._input_splits["values"] + if hasattr(sharding_context, "output_splits"): + sharding_context.output_splits = awaitable._output_splits["values"] + if hasattr(sharding_context, "sparse_features_recat"): + sharding_context.sparse_features_recat = awaitable._recat + if ( + hasattr(sharding_context, "batch_size_per_rank") + and awaitable._stride_per_rank is not None + ): + sharding_context.batch_size_per_rank = awaitable._stride_per_rank + + +def _split(flat_list: List[T], splits: List[int]) -> List[List[T]]: + return [ + flat_list[sum(splits[:i]) : sum(splits[:i]) + n] for i, n in enumerate(splits) + ] + + +class KJTListSplitsAwaitable(Awaitable[Awaitable[KJTList]], Generic[C]): + """ + Awaitable of Awaitable of KJTList. + + Args: + awaitables (List[Awaitable[Awaitable[KeyedJaggedTensor]]]): result from calling + forward on `KJTAllToAll` with sparse features to redistribute. + ctx (C): sharding context to save the metadata from the input dist to for the + embedding AlltoAll. + """ + def __init__( + self, + awaitables: List[Awaitable[Awaitable[KeyedJaggedTensor]]], + ctx: C, + module_fqn: Optional[str] = None, + sharding_types: Optional[List[str]] = None, + ) -> None: + super().__init__() + self.awaitables = awaitables + self.ctx = ctx + self._module_fqn = module_fqn + self._sharding_types = sharding_types -class SparseFeaturesListIndicesAwaitable(Awaitable[List[Awaitable[SparseFeatures]]]): + def _wait_impl(self) -> KJTListAwaitable: + """ + Calls first wait on the awaitable of awaitable of sparse features and updates + the context with metadata from the tensors awaitable. + + The first wait gets the result of splits AlltoAll and returns the tensors + awaitable. + + Returns: + KJTListAwaitable: awaitables for tensors of the sparse features. 
+ """ + tensors_awaitables = [] + + for i, w in enumerate(self.awaitables): + with maybe_annotate_embedding_event( + EmbeddingEvent.OUTPUT_DIST_WAIT, + self._module_fqn, + self._sharding_types[i] if self._sharding_types else None, + ): + tensors_awaitables.append(w.wait()) + + _set_sharding_context_intra_a2a(tensors_awaitables, self.ctx) + return KJTListAwaitable(tensors_awaitables, self.ctx) + + +@dataclass +class KJTSplitsAllToAllMeta: + pg: dist.ProcessGroup + _input: KeyedJaggedTensor + splits: List[int] + splits_tensors: List[torch.Tensor] + input_splits: List[List[int]] + input_tensors: List[torch.Tensor] + labels: List[str] + keys: List[str] + device: torch.device + stagger: int + + +class FusedKJTListSplitsAwaitable(Awaitable[List[KJTListAwaitable]]): + def __init__( + self, + requests: List[KJTListSplitsAwaitable[C]], + contexts: List[C], + pg: Optional[dist.ProcessGroup], + ) -> None: + super().__init__() + self._contexts = contexts + self._awaitables: List[ + Union[KJTSplitsAllToAllMeta, Awaitable[Awaitable[KeyedJaggedTensor]]] + ] = [awaitable for request in requests for awaitable in request.awaitables] + self._output_lengths: List[int] = [ + len(request.awaitables) for request in requests + ] + self._lengths: List[int] = [ + ( + len(awaitable.splits_tensors) + if isinstance(awaitable, KJTSplitsAllToAllMeta) + else 0 + ) + for awaitable in self._awaitables + ] + splits_tensors = [ + splits_tensor + for awaitable in self._awaitables + for splits_tensor in ( + awaitable.splits_tensors + if isinstance(awaitable, KJTSplitsAllToAllMeta) + else [] + ) + ] + self._splits_awaitable: Optional[SplitsAllToAllAwaitable] = ( + SplitsAllToAllAwaitable( + input_tensors=splits_tensors, + pg=pg, + ) + if splits_tensors and pg is not None + else None + ) + + def _wait_impl(self) -> List[KJTListAwaitable]: + if self._splits_awaitable: + splits_list = self._splits_awaitable.wait() + splits_per_awaitable = _split(splits_list, self._lengths) + else: + splits_per_awaitable = [[] for _ in range(len(self._lengths))] + tensors_awaitables = [] + for splits, awaitable in zip(splits_per_awaitable, self._awaitables): + if not splits: # NoWait + assert isinstance(awaitable, Awaitable) + tensors_awaitables.append(awaitable.wait()) + continue + assert isinstance(awaitable, KJTSplitsAllToAllMeta) + if awaitable._input.variable_stride_per_key(): + output_splits = splits + stride_per_rank = None + else: + output_splits = splits[:-1] + stride_per_rank = splits[-1] + tensors_awaitables.append( + KJTAllToAllTensorsAwaitable( + pg=awaitable.pg, + input=awaitable._input, + splits=awaitable.splits, + input_splits=awaitable.input_splits, + output_splits=output_splits, + input_tensors=awaitable.input_tensors, + labels=awaitable.labels, + keys=awaitable.keys, + device=awaitable.device, + stagger=awaitable.stagger, + stride_per_rank=stride_per_rank, + ) + ) + output = [] + awaitables_per_output = _split(tensors_awaitables, self._output_lengths) + for awaitables, ctx in zip(awaitables_per_output, self._contexts): + _set_sharding_context_intra_a2a(awaitables, ctx) + output.append(KJTListAwaitable(awaitables, ctx)) + return output + + +class ListOfKJTListAwaitable(Awaitable[ListOfKJTList]): """ - Handles the first wait for a list of two-layer awaitables of `SparseFeatures`. - Wait on this module will get lengths AlltoAll results for each `SparseFeatures`, and - instantiate its indices AlltoAll. + This module handles the tables-wise sharding input features distribution for + inference. 
Args: - awaitables (List[Awaitable[Awaitable[SparseFeatures]]]): list of `Awaitable` of - `Awaitable` sparse features. + awaitables (List[Awaitable[KJTList]]): list of `Awaitable` of `KJTList`. """ def __init__( self, - awaitables: List[Awaitable[Awaitable[SparseFeatures]]], + awaitables: List[Awaitable[KJTList]], ) -> None: super().__init__() self.awaitables = awaitables - def _wait_impl(self) -> List[Awaitable[SparseFeatures]]: + def _wait_impl(self) -> ListOfKJTList: """ - Syncs sparse features in SparseFeaturesList. + Syncs sparse features in list of KJTList. Returns: - List[Awaitable[SparseFeatures]] - """ + ListOfKJTList: synced `ListOfKJTList`. - return [m.wait() for m in self.awaitables] + """ + return ListOfKJTList([w.wait() for w in self.awaitables]) -class ListOfSparseFeaturesListAwaitable(Awaitable[ListOfSparseFeaturesList]): +class ListOfKJTListSplitsAwaitable(Awaitable[Awaitable[ListOfKJTList]]): """ - This module handles the tables-wise sharding input features distribution for inference. - For inference, we currently do not separate lengths from indices. + Awaitable of Awaitable of ListOfKJTList. Args: - awaitables (List[Awaitable[SparseFeaturesList]]): list of `Awaitable` of - `SparseFeaturesList`. + awaitables (List[Awaitable[Awaitable[KJTList]]]): list of `Awaitable` + of `Awaitable` of sparse features list. """ def __init__( self, - awaitables: List[Awaitable[SparseFeaturesList]], + awaitables: List[Awaitable[Awaitable[KJTList]]], ) -> None: super().__init__() self.awaitables = awaitables - def _wait_impl(self) -> ListOfSparseFeaturesList: + def _wait_impl(self) -> Awaitable[ListOfKJTList]: """ - Syncs sparse features in List of SparseFeaturesList. + Calls first wait on the awaitable of awaitable of ListOfKJTList. Returns: - ListOfSparseFeaturesList: synced `ListOfSparseFeaturesList`. + Awaitable[ListOfKJTList]: awaitable of `ListOfKJTList`. 
""" - return ListOfSparseFeaturesList([w.wait() for w in self.awaitables]) + return ListOfKJTListAwaitable([w.wait() for w in self.awaitables]) -C = TypeVar("C", bound=Multistreamable) F = TypeVar("F", bound=Multistreamable) T = TypeVar("T") W = TypeVar("W") -class NullShardingContext(Multistreamable): - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - pass - - -@dataclass class EmbeddingShardingContext(Multistreamable): - batch_size_per_rank: List[int] = field(default_factory=list) + # Torch Dynamo does not support default_factory=list: + # https://github.com/pytorch/pytorch/issues/120108 + # TODO(ivankobzarev) Make this a dataclass once supported - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + def __init__( + self, + batch_size_per_rank: Optional[List[int]] = None, + batch_size_per_rank_per_feature: Optional[List[List[int]]] = None, + batch_size_per_feature_pre_a2a: Optional[List[int]] = None, + variable_batch_per_feature: bool = False, + ) -> None: + super().__init__() + self.batch_size_per_rank: List[int] = ( + batch_size_per_rank if batch_size_per_rank is not None else [] + ) + self.batch_size_per_rank_per_feature: List[List[int]] = ( + batch_size_per_rank_per_feature + if batch_size_per_rank_per_feature is not None + else [] + ) + self.batch_size_per_feature_pre_a2a: List[int] = ( + batch_size_per_feature_pre_a2a + if batch_size_per_feature_pre_a2a is not None + else [] + ) + self.variable_batch_per_feature: bool = variable_batch_per_feature + + def record_stream(self, stream: torch.Stream) -> None: pass @@ -640,8 +942,8 @@ class BaseSparseFeaturesDist(abc.ABC, nn.Module, Generic[F]): @abc.abstractmethod def forward( self, - sparse_features: SparseFeatures, - ) -> Awaitable[Awaitable[F]]: + sparse_features: KeyedJaggedTensor, + ) -> Union[Awaitable[Awaitable[F]], F]: pass @@ -655,7 +957,7 @@ def forward( self, local_embs: T, sharding_ctx: Optional[C] = None, - ) -> Awaitable[W]: + ) -> Union[Awaitable[W], W]: pass @@ -666,9 +968,9 @@ class EmbeddingSharding(abc.ABC, Generic[C, F, T, W], FeatureShardingMixIn): """ def __init__( - self, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None + self, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, ) -> None: - self._qcomm_codecs_registry = qcomm_codecs_registry @property @@ -714,6 +1016,15 @@ def embedding_names(self) -> List[str]: def embedding_names_per_rank(self) -> List[List[str]]: pass + def embedding_tables(self) -> List[ShardedEmbeddingTable]: + raise NotImplementedError + + def uncombined_embedding_dims(self) -> List[int]: + return self.embedding_dims() + + def uncombined_embedding_names(self) -> List[str]: + return self.embedding_names() + @dataclass class EmbeddingShardingInfo: diff --git a/torchrec/distributed/embedding_tower_sharding.py b/torchrec/distributed/embedding_tower_sharding.py index 1abe77075..2fdca6b5c 100644 --- a/torchrec/distributed/embedding_tower_sharding.py +++ b/torchrec/distributed/embedding_tower_sharding.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + from collections import OrderedDict from dataclasses import dataclass, field from typing import Any, cast, Dict, Iterator, List, Optional, Set, Tuple, Type, TypeVar @@ -15,18 +17,16 @@ from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.comm import intra_and_cross_node_pg from torchrec.distributed.dist_data import ( + KJTAllToAll, PooledEmbeddingsAllToAll, PooledEmbeddingsAwaitable, ) from torchrec.distributed.embedding import EmbeddingCollectionSharder -from torchrec.distributed.embedding_sharding import ( - SparseFeaturesAllToAll, - SparseFeaturesListAwaitable, -) +from torchrec.distributed.embedding_sharding import KJTListSplitsAwaitable from torchrec.distributed.embedding_types import ( BaseEmbeddingSharder, - SparseFeatures, - SparseFeaturesList, + KJTList, + ShardedEmbeddingModule, ) from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.types import ( @@ -37,7 +37,6 @@ NullShardedModuleContext, ParameterSharding, QuantizedCommCodecs, - ShardedModule, ShardingEnv, ShardingType, ) @@ -70,10 +69,6 @@ def _replace_sharding_with_intra_node( raise ValueError(f"Sharding type not supported {value.sharding_type}") if value.ranks: value.ranks = [rank % local_size for rank in value.ranks] - if value.sharding_spec: - # pyre-ignore [6, 16] - for (shard, rank) in zip(value.sharding_spec.shards, value.ranks): - shard.placement._rank = rank class TowerLazyAwaitable(LazyAwaitable[torch.Tensor]): @@ -92,14 +87,14 @@ def _wait_impl(self) -> torch.Tensor: class EmbeddingTowerCollectionContext(Multistreamable): embedding_contexts: List[NullShardedModuleContext] = field(default_factory=list) - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + def record_stream(self, stream: torch.Stream) -> None: for ctx in self.embedding_contexts: ctx.record_stream(stream) class ShardedEmbeddingTower( - ShardedModule[ - SparseFeaturesList, + ShardedEmbeddingModule[ + KJTList, torch.Tensor, torch.Tensor, NullShardedModuleContext, @@ -145,6 +140,7 @@ def __init__( self._wkjt_feature_names: List[str] = wkjt_features self._has_uninitialized_input_dist: bool = True self._cross_dist: nn.Module = nn.Module() + self._weighted_cross_dist: nn.Module = nn.Module() self._kjt_features_order: List[int] = [] self._wkjt_features_order: List[int] = [] self._has_kjt_features_permute: bool = False @@ -172,7 +168,7 @@ def __init__( # Hierarchical DDP self.interaction = DistributedDataParallel( module=module.interaction.to(self._device), - device_ids=[self._device], + device_ids=[self._device] if self._device is not None else None, process_group=self._intra_pg, gradient_as_bucket_view=True, broadcast_buffers=False, @@ -199,6 +195,7 @@ def _create_input_dist( torch.tensor( self._kjt_features_order, device=self._device, dtype=torch.int32 ), + persistent=False, ) if self._wkjt_feature_names != wkjt_feature_names: @@ -210,6 +207,7 @@ def _create_input_dist( torch.tensor( self._wkjt_features_order, device=self._device, dtype=torch.int32 ), + persistent=False, ) node_count = dist.get_world_size(self._cross_pg) @@ -221,13 +219,17 @@ def _create_input_dist( len(self._wkjt_feature_names) if node == self._tower_node else 0 for node in range(node_count) ] - self._cross_dist = SparseFeaturesAllToAll( + self._cross_dist = KJTAllToAll( # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got # `Optional[ProcessGroup]`. 
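# --- Illustrative sketch (not part of this patch) -----------------------------
# How a tower routes its features in the cross-node AlltoAll configured here: all
# of the tower's features are sent to the tower's node and zero features go to
# every other node, so these per-node counts become the splits of the now-separate
# unweighted and weighted KJTAllToAll modules. Numbers are made up.
node_count = 4
tower_node = 2
kjt_feature_names = ["f0", "f1", "f2"]  # this tower's unweighted features
wkjt_feature_names = ["w0"]             # this tower's weighted features

kjt_features_per_node = [
    len(kjt_feature_names) if node == tower_node else 0 for node in range(node_count)
]
wkjt_features_per_node = [
    len(wkjt_feature_names) if node == tower_node else 0 for node in range(node_count)
]
print(kjt_features_per_node)   # [0, 0, 3, 0]
print(wkjt_features_per_node)  # [0, 0, 1, 0]
# -------------------------------------------------------------------------------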
self._cross_pg, kjt_features_per_node, + ) + self._weighted_cross_dist = KJTAllToAll( + # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + self._cross_pg, wkjt_features_per_node, - self._device, ) # pyre-ignore[14] @@ -236,7 +238,7 @@ def input_dist( ctx: NullShardedModuleContext, features: KeyedJaggedTensor, optional_features: Optional[KeyedJaggedTensor] = None, - ) -> Awaitable[SparseFeaturesList]: + ) -> Awaitable[Awaitable[KJTList]]: # optional_features are populated only if both kjt and weighted kjt present in tower if self._wkjt_feature_names and self._kjt_feature_names: @@ -271,30 +273,27 @@ def input_dist( self._wkjt_features_order, self._wkjt_features_order_tensor, ) - tensor_awaitable = self._cross_dist( - SparseFeatures( - id_list_features=kjt_features, - id_score_list_features=wkjt_features, - ) - ) - return SparseFeaturesListAwaitable([tensor_awaitable.wait()]) + + awaitables = [] + if kjt_features is not None: + awaitables.append(self._cross_dist(kjt_features)) + if wkjt_features is not None: + awaitables.append(self._weighted_cross_dist(wkjt_features)) + + return KJTListSplitsAwaitable(awaitables, ctx) def compute( - self, ctx: NullShardedModuleContext, dist_input: SparseFeaturesList + self, ctx: NullShardedModuleContext, dist_input: KJTList ) -> torch.Tensor: - kjt_features = dist_input[0].id_list_features - wkjt_features = dist_input[0].id_score_list_features - if self._active_device: - if kjt_features and wkjt_features: + if len(dist_input) == 2: + kjt_features = dist_input[0] + wkjt_features = dist_input[1] # pyre-ignore [29] embeddings = self.embedding(kjt_features, wkjt_features) - elif wkjt_features: - # pyre-ignore [29] - embeddings = self.embedding(wkjt_features) else: # pyre-ignore [29] - embeddings = self.embedding(kjt_features) + embeddings = self.embedding(dist_input[0]) # pyre-ignore [29] output = self.interaction(embeddings) else: @@ -339,11 +338,13 @@ def _create_output_dist( # `List[Union[bool, float, int]]`. dim_sum_per_rank=dim_sum_per_rank, device=self._device, - codecs=self.qcomm_codecs_registry.get( - CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None - ) - if self.qcomm_codecs_registry - else None, + codecs=( + self.qcomm_codecs_registry.get( + CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None + ) + if self.qcomm_codecs_registry + else None + ), ) def output_dist( @@ -378,7 +379,7 @@ def state_dict( @property def fused_optimizer(self) -> KeyedOptimizer: if self.embedding: - # pyre-ignore [7] + # pyre-fixme[7]: Expected `KeyedOptimizer` but got `Union[Module, Tensor]`. 
return self.embedding.fused_optimizer else: return CombinedOptimizer([]) @@ -412,19 +413,6 @@ def named_buffers( ) yield from () - def sparse_grad_parameter_names( - self, - destination: Optional[List[str]] = None, - prefix: str = "", - ) -> List[str]: - destination = [] if destination is None else destination - if self._active_device: - # pyre-ignore[16] - self.embedding.sparse_grad_parameter_names( - destination, append_prefix(prefix, "embedding") - ) - return destination - def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: if self._active_device: # pyre-ignore[16] @@ -452,8 +440,8 @@ def create_context(self) -> NullShardedModuleContext: class ShardedEmbeddingTowerCollection( - ShardedModule[ - SparseFeaturesList, + ShardedEmbeddingModule[ + KJTList, torch.Tensor, torch.Tensor, EmbeddingTowerCollectionContext, @@ -496,6 +484,7 @@ def __init__( self.interactions: nn.ModuleDict = nn.ModuleDict() self.input_dist_params: List[Tuple[bool, bool]] = [] self._cross_dist: nn.Module = nn.Module() + self._weighted_cross_dist: nn.Module = nn.Module() # groups parameter sharding into physical towers tables_per_pt: List[Set[str]] = [ @@ -574,10 +563,29 @@ def __init__( pg=self._intra_pg, ) for i, tower in local_towers: + table_names = {} + if isinstance(tower.embedding, EmbeddingBagCollection): + table_names = { + table.name for table in tower.embedding.embedding_bag_configs() + } + elif isinstance(tower.embedding, EmbeddingCollection): + table_names = { + table.name for table in tower.embedding.embedding_configs() + } + elif hasattr(tower.embedding, "tables"): + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + table_names = {table.name for table in tower.embedding.tables()} + else: + # Use all tables if unable to determine from tower.embedding + table_names = set(table_name_to_parameter_sharding.keys()) # pyre-ignore [16] self.embeddings[i] = tower_sharder.embedding_sharder(tower).shard( tower.embedding, - table_name_to_parameter_sharding, + { + table: param + for table, param in table_name_to_parameter_sharding.items() + if table in table_names + }, intra_env, device, ) @@ -585,7 +593,7 @@ def __init__( # Hierarchical DDP self.interactions[i] = DistributedDataParallel( module=tower.interaction.to(self._device), - device_ids=[self._device], + device_ids=[self._device] if self._device is not None else None, process_group=self._intra_pg, gradient_as_bucket_view=True, broadcast_buffers=False, @@ -625,13 +633,18 @@ def _create_input_dist( self._wkjt_features_order, device=self._device, dtype=torch.int32 ), ) - self._cross_dist = SparseFeaturesAllToAll( + + self._cross_dist = KJTAllToAll( # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got # `Optional[ProcessGroup]`. self._cross_pg, self._kjt_num_features_per_pt, + ) + self._weighted_cross_dist = KJTAllToAll( + # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. 
+ self._cross_pg, self._wkjt_num_features_per_pt, - self._device, ) # pyre-ignore [14] @@ -640,7 +653,7 @@ def input_dist( ctx: EmbeddingTowerCollectionContext, kjt_features: Optional[KeyedJaggedTensor] = None, wkjt_features: Optional[KeyedJaggedTensor] = None, - ) -> Awaitable[SparseFeaturesList]: + ) -> Awaitable[Awaitable[KJTList]]: if self._has_uninitialized_input_dist: # pyre-ignore [16] stride = kjt_features.stride() if kjt_features else wkjt_features.stride() @@ -661,58 +674,29 @@ def input_dist( self._wkjt_features_order, cast(torch.Tensor, self._wkjt_features_order_tensor), ) - sparse_features_awaitable = self._cross_dist( - SparseFeatures( - id_list_features=kjt_features, - id_score_list_features=wkjt_features, - ) - ) - - sparse_features = sparse_features_awaitable.wait().wait() + awaitables = [] + if kjt_features is not None: + awaitables.append(self._cross_dist(kjt_features)) + if wkjt_features is not None: + awaitables.append(self._weighted_cross_dist(wkjt_features)) + return KJTListSplitsAwaitable(awaitables, ctx) - input_dists = [] + def compute( + self, ctx: EmbeddingTowerCollectionContext, dist_input: KJTList + ) -> torch.Tensor: + if self.embeddings: + embeddings = [] for embedding, input_dist_params in zip( self.embeddings.values(), self.input_dist_params ): - - embedding_ctx = embedding.create_context() - ctx.embedding_contexts.append(embedding_ctx) kjt_param, wkjt_param = input_dist_params if kjt_param and wkjt_param: - input_dists.append( - embedding.input_dist( - embedding_ctx, - sparse_features.id_list_features, - sparse_features.id_score_list_features, - ) - ) - elif kjt_param: - input_dists.append( - embedding.input_dist( - embedding_ctx, sparse_features.id_list_features - ) - ) + assert len(dist_input) == 2 + embeddings.append(embedding(dist_input[0], dist_input[1])) + elif wkjt_param and len(dist_input) == 2: + embeddings.append(embedding(dist_input[1])) else: - input_dists.append( - embedding.input_dist( - embedding_ctx, sparse_features.id_score_list_features - ) - ) - return SparseFeaturesListAwaitable(input_dists) - - def compute( - self, ctx: EmbeddingTowerCollectionContext, dist_input: SparseFeaturesList - ) -> torch.Tensor: - - if self.embeddings: - embeddings = [ - embedding.compute_and_output_dist(embedding_ctx, embedding_input) - for embedding_ctx, embedding, embedding_input in zip( - ctx.embedding_contexts, - self.embeddings.values(), - dist_input, - ) - ] + embeddings.append(embedding(dist_input[0])) output = torch.cat( [ interaction(embedding) @@ -723,7 +707,6 @@ def compute( ], dim=1, ) - else: output = torch.empty( [self._cross_pg_global_batch_size, 0], @@ -764,11 +747,13 @@ def _create_output_dist(self, output: torch.Tensor) -> None: # pyre-ignore dim_sum_per_rank=dim_sum_per_rank, device=self._device, - codecs=self.qcomm_codecs_registry.get( - CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None - ) - if self.qcomm_codecs_registry - else None, + codecs=( + self.qcomm_codecs_registry.get( + CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None + ) + if self.qcomm_codecs_registry + else None + ), ) def output_dist( @@ -845,18 +830,6 @@ def named_buffers( ) ) - def sparse_grad_parameter_names( - self, - destination: Optional[List[str]] = None, - prefix: str = "", - ) -> List[str]: - destination = [] if destination is None else destination - for i, embedding in self.embeddings.items(): - embedding.sparse_grad_parameter_names( - destination, append_prefix(prefix, f"towers.{i}.embedding") - ) - return destination - def sharded_parameter_names(self, prefix: str = 
"") -> Iterator[str]: for i, embedding in self.embeddings.items(): yield from ( @@ -888,6 +861,7 @@ def shard( params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, ) -> ShardedEmbeddingTower: kjt_features, wkjt_features = self.embedding_feature_names(module) @@ -988,6 +962,7 @@ def shard( params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, ) -> ShardedEmbeddingTowerCollection: return ShardedEmbeddingTowerCollection( diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index 5d01c3eff..e8da4a6da 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -5,18 +5,39 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import abc +import copy from dataclasses import dataclass from enum import Enum, unique -from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar +from typing import ( + Any, + Dict, + Generic, + Iterator, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import torch -from fbgemm_gpu.split_table_batched_embeddings_ops import EmbeddingLocation +from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation from torch import fx, nn +from torch.distributed._tensor import DeviceMesh +from torch.distributed._tensor.placement_types import Placement +from torch.nn.modules.module import _addindent +from torch.nn.parallel import DistributedDataParallel + from torchrec.distributed.types import ( + get_tensor_size_bytes, ModuleSharder, ParameterStorage, QuantizedCommCodecs, + ShardedModule, ShardedTensorMetadata, ShardingType, ShardMetadata, @@ -41,6 +62,11 @@ class OptimType(Enum): ADAGRAD = "ADAGRAD" ROWWISE_ADAGRAD = "ROWWISE_ADAGRAD" SHAMPOO = "SHAMPOO" + SHAMPOO_V2 = "SHAMPOO_V2" + LION = "LION" + ADAMW = "ADAMW" + SHAMPOO_V2_MRS = "SHAMPOO_V2_MRS" + SHAMPOO_MRS = "SHAMPOO_MRS" @unique @@ -52,6 +78,8 @@ class EmbeddingComputeKernel(Enum): QUANT = "quant" QUANT_UVM = "quant_uvm" QUANT_UVM_CACHING = "quant_uvm_caching" + KEY_VALUE = "key_value" + CUSTOMIZED_KERNEL = "customized_kernel" def compute_kernel_to_embedding_location( @@ -61,6 +89,7 @@ def compute_kernel_to_embedding_location( EmbeddingComputeKernel.DENSE, EmbeddingComputeKernel.FUSED, EmbeddingComputeKernel.QUANT, + EmbeddingComputeKernel.KEY_VALUE, # use hbm for cache ]: return EmbeddingLocation.DEVICE elif compute_kernel in [ @@ -77,82 +106,85 @@ def compute_kernel_to_embedding_location( raise ValueError(f"Invalid EmbeddingComputeKernel {compute_kernel}") -@dataclass -class SparseFeatures(Multistreamable): - id_list_features: Optional[KeyedJaggedTensor] = None - id_score_list_features: Optional[KeyedJaggedTensor] = None - - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - if self.id_list_features is not None: - self.id_list_features.record_stream(stream) - if self.id_score_list_features is not None: - self.id_score_list_features.record_stream(stream) - - def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> fx.node.Argument: - return tracer.create_node( - "call_function", - SparseFeatures, - args=( - tracer.create_arg(self.id_list_features), - tracer.create_arg(self.id_score_list_features), - ), - kwargs={}, - ) - - -class SparseFeaturesList(Multistreamable): - def __init__(self, features: List[SparseFeatures]) -> None: 
+class KJTList(Multistreamable): + def __init__(self, features: List[KeyedJaggedTensor]) -> None: self.features = features def __len__(self) -> int: return len(self.features) - def __setitem__(self, key: int, item: SparseFeatures) -> None: + def __setitem__(self, key: int, item: KeyedJaggedTensor) -> None: self.features[key] = item - def __getitem__(self, key: int) -> SparseFeatures: + def __getitem__(self, key: int) -> KeyedJaggedTensor: return self.features[key] - def __iter__(self) -> Iterator[SparseFeatures]: + @torch.jit._drop + def __iter__(self) -> Iterator[KeyedJaggedTensor]: return iter(self.features) + @torch.jit._drop def record_stream(self, stream: torch.cuda.streams.Stream) -> None: for feature in self.features: feature.record_stream(stream) + @torch.jit._drop def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> fx.node.Argument: return tracer.create_node( "call_function", - SparseFeaturesList, + KJTList, args=(tracer.create_arg(self.features),), kwargs={}, ) -class ListOfSparseFeaturesList(Multistreamable): - def __init__(self, features: List[SparseFeaturesList]) -> None: +@dataclass +class InputDistOutputs(Multistreamable): + features: KJTList + unbucketize_permute_tensor: Optional[torch.Tensor] = ( + None # only used in RW sharding + ) + bucket_mapping_tensor: Optional[torch.Tensor] = None # only used in RW sharding + bucketized_length: Optional[torch.Tensor] = None # only used in RW sharding + + def record_stream(self, stream: torch.Stream) -> None: + for feature in self.features: + feature.record_stream(stream) + if self.unbucketize_permute_tensor is not None: + self.unbucketize_permute_tensor.record_stream(stream) + if self.bucket_mapping_tensor is not None: + self.bucket_mapping_tensor.record_stream(stream) + if self.bucketized_length is not None: + self.bucketized_length.record_stream(stream) + + +class ListOfKJTList(Multistreamable): + def __init__(self, features: List[KJTList]) -> None: self.features_list = features def __len__(self) -> int: return len(self.features_list) - def __setitem__(self, key: int, item: SparseFeaturesList) -> None: + def __setitem__(self, key: int, item: KJTList) -> None: self.features_list[key] = item - def __getitem__(self, key: int) -> SparseFeaturesList: + def __getitem__(self, key: int) -> KJTList: return self.features_list[key] - def __iter__(self) -> Iterator[SparseFeaturesList]: + @torch.jit._drop + def __iter__(self) -> Iterator[KJTList]: return iter(self.features_list) - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + @torch.jit._drop + def record_stream(self, stream: torch.Stream) -> None: for feature in self.features_list: feature.record_stream(stream) + @torch.jit._drop def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> fx.node.Argument: return tracer.create_node( "call_function", - ListOfSparseFeaturesList, + ListOfKJTList, args=(tracer.create_arg(self.features_list),), kwargs={}, ) @@ -164,10 +196,19 @@ class ShardedConfig: local_cols: int = 0 +@dataclass +class DTensorMetadata: + mesh: Optional[DeviceMesh] = None + placements: Optional[Tuple[Placement, ...]] = None + size: Optional[Tuple[int, ...]] = None + stride: Optional[Tuple[int, ...]] = None + + @dataclass class ShardedMetaConfig(ShardedConfig): local_metadata: Optional[ShardMetadata] = None global_metadata: Optional[ShardedTensorMetadata] = None + dtensor_metadata: Optional[DTensorMetadata] = None @dataclass @@ -212,6 +253,12 @@ def dim_sum(self) -> int: dim_sum += table.num_features() * table.local_cols return dim_sum + def 
table_names(self) -> List[str]: + table_names = [] + for table in self.embedding_tables: + table_names.append(table.name) + return table_names + def feature_names(self) -> List[str]: feature_names = [] for table in self.embedding_tables: @@ -255,11 +302,116 @@ def forward( ) -> T: pass - def sparse_grad_parameter_names( - self, destination: Optional[List[str]] = None, prefix: str = "" - ) -> List[str]: - destination = [] if destination is None else destination - return destination + +class FeatureShardingMixIn: + """ + Feature Sharding Interface to provide sharding-aware feature metadata. + """ + + def feature_names(self) -> List[str]: + raise NotImplementedError + + def feature_names_per_rank(self) -> List[List[str]]: + raise NotImplementedError + + def features_per_rank(self) -> List[int]: + raise NotImplementedError + + +class ModuleShardingMixIn: + """ + The interface to access a sharded module's sharding scheme. + """ + + @property + def shardings(self) -> Dict[str, FeatureShardingMixIn]: + raise NotImplementedError + + +Out = TypeVar("Out") +CompIn = TypeVar("CompIn") +DistOut = TypeVar("DistOut") +ShrdCtx = TypeVar("ShrdCtx", bound=Multistreamable) + + +class ShardedEmbeddingModule( + ShardedModule[CompIn, DistOut, Out, ShrdCtx], + ModuleShardingMixIn, +): + """ + All model-parallel embedding modules implement this interface. + Inputs and outputs are data-parallel. + + Args:: + qcomm_codecs_registry (Optional[Dict[str, QuantizedCommCodecs]]) : Mapping of CommOp name to QuantizedCommCodecs + """ + + @abc.abstractmethod + def __init__( + self, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None + ) -> None: + super().__init__(qcomm_codecs_registry) + + self._input_dists: List[nn.Module] = [] + self._lookups: List[nn.Module] = [] + self._output_dists: List[nn.Module] = [] + + def prefetch( + self, + dist_input: KJTList, + forward_stream: Optional[Union[torch.cuda.Stream, torch.mtia.Stream]] = None, + ctx: Optional[ShrdCtx] = None, + ) -> None: + """ + Prefetch input features for each lookup module. + """ + + for feature, emb_lookup in zip(dist_input, self._lookups): + while isinstance(emb_lookup, DistributedDataParallel): + emb_lookup = emb_lookup.module + emb_lookup.prefetch(sparse_features=feature, forward_stream=forward_stream) + + def extra_repr(self) -> str: + """ + Pretty prints representation of the module's lookup modules, input_dists and output_dists + """ + + def loop(key: str, modules: List[nn.Module]) -> List[str]: + child_lines = [] + if len(modules) > 0: + child_lines.append("(" + key + "): ") + for module in modules: + mod_str = repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append(mod_str) + return child_lines + + rep = [] + rep.extend(loop("lookups", self._lookups)) + rep.extend(loop("_input_dists", self._input_dists)) + rep.extend(loop("_output_dists", self._output_dists)) + + return "\n ".join(rep) + + def train(self, mode: bool = True): # pyre-ignore[3] + r"""Set the module in training mode.""" + super().train(mode) + + # adding additional handling for lookups + for lookup in self._lookups: + lookup.train(mode) + + return self + + @property + def unsharded_module_type(self) -> Type[nn.Module]: + """ + As this is the generic ShardedEmbeddingModule class, simply + return the generic nn.Module type. In the inherited classes of + ShardedEmbeddingModule, the specific unsharded module type will + be returned. 
+ """ + return nn.Module M = TypeVar("M", bound=nn.Module) @@ -270,15 +422,24 @@ def __init__( self, fused_params: Optional[Dict[str, Any]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, - variable_batch_size: bool = False, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) - self._variable_batch_size = variable_batch_size # TODO remove after decoupling self._fused_params = fused_params def sharding_types(self, compute_device_type: str) -> List[str]: + + if compute_device_type in {"mtia"}: + return [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.GRID_SHARD.value, + ] + types = [ ShardingType.DATA_PARALLEL.value, ShardingType.TABLE_WISE.value, @@ -289,6 +450,7 @@ def sharding_types(self, compute_device_type: str) -> List[str]: types += [ ShardingType.ROW_WISE.value, ShardingType.TABLE_ROW_WISE.value, + ShardingType.GRID_SHARD.value, ] return types @@ -299,9 +461,7 @@ def compute_kernels( sharding_type: str, compute_device_type: str, ) -> List[str]: - ret = [ - EmbeddingComputeKernel.DENSE.value, - ] + ret: List[str] = [] if sharding_type != ShardingType.DATA_PARALLEL.value: ret += [ EmbeddingComputeKernel.FUSED.value, @@ -310,17 +470,19 @@ def compute_kernels( ret += [ EmbeddingComputeKernel.FUSED_UVM.value, EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.KEY_VALUE.value, ] + else: + # TODO re-enable model parallel and dense + ret += [ + EmbeddingComputeKernel.DENSE.value, + ] return ret @property def fused_params(self) -> Optional[Dict[str, Any]]: return self._fused_params - @property - def variable_batch_size(self) -> bool: - return self._variable_batch_size - def storage_usage( self, tensor: torch.Tensor, compute_device_type: str, compute_kernel: str ) -> Dict[str, int]: @@ -328,7 +490,7 @@ def storage_usage( List of system resources and corresponding usage given a compute device and compute kernel """ - tensor_bytes = tensor.element_size() * tensor.nelement() + tensor_bytes = get_tensor_size_bytes(tensor) if compute_kernel in { EmbeddingComputeKernel.FUSED_UVM.value, EmbeddingComputeKernel.FUSED_UVM_CACHING.value, @@ -336,11 +498,15 @@ def storage_usage( assert compute_device_type in {"cuda"} return {ParameterStorage.DDR.value: tensor_bytes} else: - assert compute_device_type in {"cuda", "cpu"} - storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR} + assert compute_device_type in {"cuda", "cpu", "mtia"} + storage_map = { + "cuda": ParameterStorage.HBM, + "cpu": ParameterStorage.DDR, + # TODO: Update it later. Setting for MTIA is same as CPU's for now. 
+ "mtia": ParameterStorage.DDR, + } return { - storage_map[compute_device_type].value: tensor.element_size() - * tensor.nelement() + storage_map[compute_device_type].value: get_tensor_size_bytes(tensor) } @@ -356,12 +522,6 @@ def forward( ) -> KeyedJaggedTensor: pass - def sparse_grad_parameter_names( - self, destination: Optional[List[str]] = None, prefix: str = "" - ) -> List[str]: - destination = [] if destination is None else destination - return destination - class BaseQuantEmbeddingSharder(ModuleSharder[M]): def __init__( @@ -370,7 +530,9 @@ def __init__( shardable_params: Optional[List[str]] = None, ) -> None: super().__init__() - self._fused_params = fused_params + self._fused_params: Optional[Dict[str, Any]] = ( + copy.deepcopy(fused_params) if fused_params is not None else fused_params + ) if not shardable_params: shardable_params = [] self._shardable_params: List[str] = shardable_params @@ -378,6 +540,8 @@ def __init__( def sharding_types(self, compute_device_type: str) -> List[str]: types = [ ShardingType.TABLE_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.COLUMN_WISE.value, ] return types @@ -392,15 +556,11 @@ def shardable_parameters(self, module: M) -> Dict[str, nn.Parameter]: if self._shardable_params: assert all( - [ - table_name in self._shardable_params - for table_name in shardable_params.keys() - ] + table_name in self._shardable_params + for table_name in shardable_params.keys() ) or all( - [ - table_name not in self._shardable_params - for table_name in shardable_params.keys() - ] + table_name not in self._shardable_params + for table_name in shardable_params.keys() ), f"Cannot partially shard {type(module)}, please check sharder kwargs" shardable_params = { table_name: param @@ -434,7 +594,7 @@ def storage_usage( List of system resources and corresponding usage given a compute device and compute kernel """ - tensor_bytes = tensor.element_size() * tensor.nelement() + tensor.shape[0] * 4 + tensor_bytes = get_tensor_size_bytes(tensor) + tensor.shape[0] * 4 if compute_kernel in { EmbeddingComputeKernel.QUANT_UVM.value, EmbeddingComputeKernel.QUANT_UVM_CACHING.value, @@ -442,6 +602,11 @@ def storage_usage( assert compute_device_type in {"cuda"} return {ParameterStorage.DDR.value: tensor_bytes} else: - assert compute_device_type in {"cuda", "cpu"} - storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR} + assert compute_device_type in {"cuda", "cpu", "mtia"} + storage_map = { + "cuda": ParameterStorage.HBM, + "cpu": ParameterStorage.DDR, + # TODO: Update it later. Setting for MTIA is same as CPU's for now. + "mtia": ParameterStorage.DDR, + } return {storage_map[compute_device_type].value: tensor_bytes} diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 21973450e..ed3478ff8 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -5,60 +5,143 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import copy -from collections import OrderedDict +from collections import defaultdict, OrderedDict from dataclasses import dataclass, field -from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Type, Union +from functools import partial +from typing import ( + Any, + cast, + Dict, + Iterator, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + Union, +) import torch -from torch import nn, Tensor +from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + DenseTableBatchedEmbeddingBagsCodegen, +) +from tensordict import TensorDict +from torch import distributed as dist, nn, Tensor +from torch.autograd.profiler import record_function +from torch.distributed._shard.sharded_tensor import TensorProperties +from torch.distributed._tensor import DTensor from torch.nn.modules.module import _IncompatibleKeys +from torch.nn.parallel import DistributedDataParallel +from torchrec.distributed.comm import get_local_size from torchrec.distributed.embedding_sharding import ( EmbeddingSharding, EmbeddingShardingContext, EmbeddingShardingInfo, + KJTListSplitsAwaitable, Multistreamable, - SparseFeaturesIndicesAwaitable, - SparseFeaturesListAwaitable, ) from torchrec.distributed.embedding_types import ( BaseEmbeddingSharder, EmbeddingComputeKernel, - SparseFeatures, - SparseFeaturesList, + KJTList, + ShardedEmbeddingModule, ) from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding +from torchrec.distributed.sharding.dynamic_sharding import ( + shards_all_to_all, + update_module_sharding_plan, + update_state_dict_post_resharding, +) +from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding from torchrec.distributed.sharding.twcw_sharding import TwCwPooledEmbeddingSharding from torchrec.distributed.sharding.twrw_sharding import TwRwPooledEmbeddingSharding +from torchrec.distributed.shards_wrapper import LocalShardsWrapper from torchrec.distributed.types import ( Awaitable, + EmbeddingEvent, + EmbeddingModuleShardingPlan, EnumerableShardingSpec, LazyAwaitable, + LazyGetItemMixin, NullShardedModuleContext, ParameterSharding, QuantizedCommCodecs, - ShardedModule, ShardedTensor, ShardingEnv, + ShardingEnv2D, ShardingType, + ShardMetadata, ) from torchrec.distributed.utils import ( + add_params_from_parameter_sharding, append_prefix, - filter_state_dict, + convert_to_fbgemm_types, + create_global_tensor_shape_stride_from_metadata, + maybe_annotate_embedding_event, merge_fused_params, + none_throws, optimizer_type_to_emb_opt_type, ) -from torchrec.modules.embedding_configs import EmbeddingTableConfig, PoolingType +from torchrec.modules.embedding_configs import ( + data_type_to_dtype, + EmbeddingBagConfig, + EmbeddingTableConfig, + PoolingType, +) from torchrec.modules.embedding_modules import ( EmbeddingBagCollection, EmbeddingBagCollectionInterface, ) -from torchrec.optim.fused import FusedOptimizerModule +from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor +from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, 
KeyedTensor +from torchrec.sparse.tensor_dict import maybe_td_to_kjt + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") +except OSError: + pass + + +def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor: + return ( + tensor + if device.type == "cpu" + else tensor.pin_memory().to(device=device, non_blocking=True) + ) + + +def get_device_from_parameter_sharding( + ps: ParameterSharding, +) -> Union[str, Tuple[str, ...]]: + """ + Returns list of device type per shard if table is sharded across different + device type, else reutrns single device type for the table parameter + """ + if not isinstance(ps.sharding_spec, EnumerableShardingSpec): + raise ValueError("Expected EnumerableShardingSpec as input to the function") + + device_type_list: Tuple[str, ...] = tuple( + # pyre-fixme[16]: `Optional` has no attribute `device` + [shard.placement.device().type for shard in ps.sharding_spec.shards] + ) + if len(set(device_type_list)) == 1: + return device_type_list[0] + else: + assert ( + ps.sharding_type == "row_wise" + ), "Only row_wise sharding supports sharding across multiple device types for a table" + return device_type_list def replace_placement_with_meta_device( @@ -88,79 +171,13 @@ def replace_placement_with_meta_device( ) -def create_embedding_bag_sharding( - sharding_type: str, - sharding_infos: List[EmbeddingShardingInfo], - env: ShardingEnv, - device: Optional[torch.device] = None, - permute_embeddings: bool = False, - need_pos: bool = False, - qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, - variable_batch_size: bool = False, -) -> EmbeddingSharding[ - EmbeddingShardingContext, SparseFeatures, torch.Tensor, torch.Tensor -]: - if device is not None and device.type == "meta": - replace_placement_with_meta_device(sharding_infos) - if sharding_type == ShardingType.TABLE_WISE.value: - return TwPooledEmbeddingSharding( - sharding_infos, - env, - device, - qcomm_codecs_registry=qcomm_codecs_registry, - variable_batch_size=variable_batch_size, - ) - elif sharding_type == ShardingType.ROW_WISE.value: - return RwPooledEmbeddingSharding( - sharding_infos, - env, - device, - need_pos=need_pos, - qcomm_codecs_registry=qcomm_codecs_registry, - variable_batch_size=variable_batch_size, - ) - elif sharding_type == ShardingType.DATA_PARALLEL.value: - return DpPooledEmbeddingSharding(sharding_infos, env, device) - elif sharding_type == ShardingType.TABLE_ROW_WISE.value: - return TwRwPooledEmbeddingSharding( - sharding_infos, - env, - device, - need_pos=need_pos, - qcomm_codecs_registry=qcomm_codecs_registry, - variable_batch_size=variable_batch_size, - ) - elif sharding_type == ShardingType.COLUMN_WISE.value: - return CwPooledEmbeddingSharding( - sharding_infos, - env, - device, - permute_embeddings=permute_embeddings, - qcomm_codecs_registry=qcomm_codecs_registry, - variable_batch_size=variable_batch_size, - ) - elif sharding_type == ShardingType.TABLE_COLUMN_WISE.value: - if variable_batch_size: - raise ValueError( - f"Variable batch size not supported for sharding type {sharding_type}" - ) - return TwCwPooledEmbeddingSharding( - sharding_infos, - env, - device, - permute_embeddings=permute_embeddings, - qcomm_codecs_registry=qcomm_codecs_registry, - ) - else: - raise ValueError(f"Sharding type not supported {sharding_type}") - - -def 
create_sharding_infos_by_sharding( +def create_sharding_infos_by_sharding_device_group( module: EmbeddingBagCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], prefix: str, fused_params: Optional[Dict[str, Any]], -) -> Dict[str, List[EmbeddingShardingInfo]]: + suffix: Optional[str] = "weight", +) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]: if fused_params is None: fused_params = {} @@ -175,7 +192,9 @@ def create_sharding_infos_by_sharding( else: shared_feature[feature_name] = True - sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {} + sharding_type_device_group_to_sharding_infos: Dict[ + Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo] + ] = {} # state_dict returns parameter.Tensor, which loses parameter level attributes parameter_by_name = dict(module.named_parameters()) @@ -184,7 +203,9 @@ def create_sharding_infos_by_sharding( for config in module.embedding_bag_configs(): table_name = config.name - assert table_name in table_name_to_parameter_sharding + assert ( + table_name in table_name_to_parameter_sharding + ), f"{table_name} not in table_name_to_parameter_sharding" parameter_sharding = table_name_to_parameter_sharding[table_name] if parameter_sharding.compute_kernel not in [ kernel.value for kernel in EmbeddingComputeKernel @@ -199,22 +220,47 @@ def create_sharding_infos_by_sharding( else: embedding_names.append(feature_name) - param_name = prefix + table_name + ".weight" + param_name = prefix + table_name + if suffix is not None: + param_name = f"{param_name}.{suffix}" + assert param_name in parameter_by_name or param_name in state_dict param = parameter_by_name.get(param_name, state_dict[param_name]) - if parameter_sharding.sharding_type not in sharding_type_to_sharding_infos: - sharding_type_to_sharding_infos[parameter_sharding.sharding_type] = [] + device_group = get_device_from_parameter_sharding(parameter_sharding) + + if ( + parameter_sharding.sharding_type, + device_group, + ) not in sharding_type_device_group_to_sharding_infos: + sharding_type_device_group_to_sharding_infos[ + (parameter_sharding.sharding_type, device_group) + ] = [] + + optimizer_params = getattr(param, "_optimizer_kwargs", [{}]) + optimizer_classes = getattr(param, "_optimizer_classes", [None]) + + assert ( + len(optimizer_classes) == 1 and len(optimizer_params) == 1 + ), f"Only support 1 optimizer, given {len(optimizer_classes)} optimizer classes \ + and {len(optimizer_params)} optimizer kwargs." 
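A toy sketch of the grouping scheme introduced by create_sharding_infos_by_sharding_device_group above: sharding infos are bucketed by a (sharding_type, device_group) key rather than by sharding_type alone, so that row-wise tables placed on cuda and row-wise tables spanning cuda and cpu land in separate groups. The tuples below stand in for the real config and ParameterSharding objects.

from collections import defaultdict
from typing import Dict, List, Tuple, Union

DeviceGroup = Union[str, Tuple[str, ...]]

# (table name, sharding type, device type of each shard) -- illustrative only
tables = [
    ("t0", "table_wise", ("cuda",)),
    ("t1", "row_wise", ("cuda", "cuda")),
    ("t2", "row_wise", ("cuda", "cpu")),
]

grouped: Dict[Tuple[str, DeviceGroup], List[str]] = defaultdict(list)
for name, sharding_type, shard_devices in tables:
    if len(set(shard_devices)) == 1:
        # a single device type collapses to a scalar key
        device_group: DeviceGroup = shard_devices[0]
    else:
        # the real helper only permits mixed device types for row_wise tables
        assert sharding_type == "row_wise"
        device_group = shard_devices
    grouped[(sharding_type, device_group)].append(name)

# {('table_wise', 'cuda'): ['t0'], ('row_wise', 'cuda'): ['t1'],
#  ('row_wise', ('cuda', 'cpu')): ['t2']}
print(dict(grouped))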
- optimizer_params = getattr(param, "_optimizer_kwargs", {}) - optimizer_class = getattr(param, "_optimizer_class", None) + optimizer_class = optimizer_classes[0] + optimizer_params = optimizer_params[0] if optimizer_class: optimizer_params["optimizer"] = optimizer_type_to_emb_opt_type( optimizer_class ) - fused_params = merge_fused_params(fused_params, optimizer_params) - sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append( + per_table_fused_params = merge_fused_params(fused_params, optimizer_params) + per_table_fused_params = add_params_from_parameter_sharding( + per_table_fused_params, parameter_sharding + ) + per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params) + + sharding_type_device_group_to_sharding_infos[ + (parameter_sharding.sharding_type, device_group) + ].append( EmbeddingShardingInfo( embedding_config=EmbeddingTableConfig( num_embeddings=config.num_embeddings, @@ -228,45 +274,127 @@ def create_sharding_infos_by_sharding( embedding_names=embedding_names, weight_init_max=config.weight_init_max, weight_init_min=config.weight_init_min, + num_embeddings_post_pruning=( + getattr(config, "num_embeddings_post_pruning", None) + # TODO: Need to check if attribute exists for BC + ), ), param_sharding=parameter_sharding, param=param, - fused_params=fused_params, + fused_params=per_table_fused_params, ) ) - return sharding_type_to_sharding_infos + return sharding_type_device_group_to_sharding_infos -def _check_need_pos(module: EmbeddingBagCollectionInterface) -> bool: - for config in module.embedding_bag_configs(): - if config.need_pos: - return True - return False +def construct_output_kt( + embeddings: List[torch.Tensor], + embedding_names: List[str], + embedding_dims: List[int], +) -> KeyedTensor: + cat_embeddings: torch.Tensor + if len(embeddings) == 1: + cat_embeddings = embeddings[0] + else: + cat_embeddings = torch.cat(embeddings, dim=1) + return KeyedTensor( + keys=embedding_names, + length_per_key=embedding_dims, + values=cat_embeddings, + key_dim=1, + ) -class EmbeddingBagCollectionAwaitable(LazyAwaitable[KeyedTensor]): +class VariableBatchEmbeddingBagCollectionAwaitable( + LazyGetItemMixin[str, torch.Tensor], LazyAwaitable[KeyedTensor] +): + def __init__( + self, + awaitables: List[Awaitable[torch.Tensor]], + inverse_indices: Tuple[List[str], torch.Tensor], + inverse_indices_permute_indices: Optional[torch.Tensor], + batch_size_per_feature_pre_a2a: List[int], + uncombined_embedding_dims: List[int], + embedding_names: List[str], + embedding_dims: List[int], + permute_op: PermutePooledEmbeddings, + module_fqn: Optional[str] = None, + sharding_types: Optional[List[str]] = None, + ) -> None: + super().__init__() + self._awaitables = awaitables + self._inverse_indices = inverse_indices + self._inverse_indices_permute_indices = inverse_indices_permute_indices + self._batch_size_per_feature_pre_a2a = batch_size_per_feature_pre_a2a + self._uncombined_embedding_dims = uncombined_embedding_dims + self._embedding_names = embedding_names + self._embedding_dims = embedding_dims + self._permute_op = permute_op + self._module_fqn = module_fqn + self._sharding_types = sharding_types + + def _wait_impl(self) -> KeyedTensor: + embeddings = [] + for i, w in enumerate(self._awaitables): + with maybe_annotate_embedding_event( + EmbeddingEvent.OUTPUT_DIST_WAIT, + self._module_fqn, + self._sharding_types[i] if self._sharding_types else None, + ): + embeddings.append(w.wait()) + batch_size = self._inverse_indices[1].numel() // len(self._inverse_indices[0]) + 
permute_indices = self._inverse_indices_permute_indices + if permute_indices is not None: + indices = torch.index_select(self._inverse_indices[1], 0, permute_indices) + else: + indices = self._inverse_indices[1] + reindex_output = torch.ops.fbgemm.batch_index_select_dim0( + inputs=embeddings[0] if len(embeddings) == 1 else torch.cat(embeddings), + indices=indices.view(-1), + input_num_indices=[batch_size] * len(self._uncombined_embedding_dims), + input_rows=self._batch_size_per_feature_pre_a2a, + input_columns=self._uncombined_embedding_dims, + permute_output_dim_0_1=True, + ).view(batch_size, -1) + return construct_output_kt( + embeddings=[self._permute_op(reindex_output)], + embedding_names=self._embedding_names, + embedding_dims=self._embedding_dims, + ) + + +class EmbeddingBagCollectionAwaitable( + LazyGetItemMixin[str, Tensor], LazyAwaitable[KeyedTensor] +): def __init__( self, awaitables: List[Awaitable[torch.Tensor]], embedding_dims: List[int], embedding_names: List[str], + module_fqn: Optional[str] = None, + sharding_types: Optional[List[str]] = None, ) -> None: super().__init__() self._awaitables = awaitables self._embedding_dims = embedding_dims self._embedding_names = embedding_names + self._module_fqn = module_fqn + self._sharding_types = sharding_types def _wait_impl(self) -> KeyedTensor: - embeddings = [w.wait() for w in self._awaitables] - if len(embeddings) == 1: - embeddings = embeddings[0] - else: - embeddings = torch.cat(embeddings, dim=1) - return KeyedTensor( - keys=self._embedding_names, - length_per_key=self._embedding_dims, - values=embeddings, - key_dim=1, + embeddings = [] + for i, w in enumerate(self._awaitables): + with maybe_annotate_embedding_event( + EmbeddingEvent.OUTPUT_DIST_WAIT, + self._module_fqn, + self._sharding_types[i] if self._sharding_types else None, + ): + embeddings.append(w.wait()) + + return construct_output_kt( + embeddings=embeddings, + embedding_names=self._embedding_names, + embedding_dims=self._embedding_dims, ) @@ -275,23 +403,30 @@ class EmbeddingBagCollectionContext(Multistreamable): sharding_contexts: List[Optional[EmbeddingShardingContext]] = field( default_factory=list ) + inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None + variable_batch_per_feature: bool = False + divisor: Optional[torch.Tensor] = None - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + def record_stream(self, stream: torch.Stream) -> None: for ctx in self.sharding_contexts: if ctx: ctx.record_stream(stream) + if self.inverse_indices is not None: + self.inverse_indices[1].record_stream(stream) + if self.divisor is not None: + self.divisor.record_stream(stream) class ShardedEmbeddingBagCollection( - ShardedModule[ - SparseFeaturesList, + ShardedEmbeddingModule[ + KJTList, List[torch.Tensor], KeyedTensor, EmbeddingBagCollectionContext, ], + # TODO remove after compute_kernel X sharding decoupling FusedOptimizerModule, ): - # TODO remove after compute_kernel X sharding decoupling """ Sharded implementation of EmbeddingBagCollection. This is part of the public API to allow for manual data dist pipelining. 
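A small sketch, assuming torchrec is installed, of what construct_output_kt does for the awaitables above once each per-sharding output dist has been waited on: the pooled tensors are concatenated along the feature dimension and wrapped in a KeyedTensor so callers can slice by embedding name. The tensors and key names are fabricated for illustration.

import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

batch_size = 2
emb_a = torch.randn(batch_size, 8)  # e.g. pooled output of one sharding
emb_b = torch.randn(batch_size, 4)  # e.g. pooled output of another sharding

kt = KeyedTensor(
    keys=["f_a", "f_b"],
    length_per_key=[8, 4],
    values=torch.cat([emb_a, emb_b], dim=1),
    key_dim=1,
)
# per-key lookup slices the concatenated values back out
assert kt["f_b"].shape == (batch_size, 4)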
@@ -305,48 +440,91 @@ def __init__( fused_params: Optional[Dict[str, Any]] = None, device: Optional[torch.device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, - variable_batch_size: bool = False, + module_fqn: Optional[str] = None, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) - sharding_type_to_sharding_infos = create_sharding_infos_by_sharding( - module, - table_name_to_parameter_sharding, - "embedding_bags.", - fused_params, + self._module_fqn = module_fqn + self._embedding_bag_configs: List[EmbeddingBagConfig] = ( + module.embedding_bag_configs() ) - need_pos = _check_need_pos(module) - self._sharding_type_to_sharding: Dict[ - str, + + self._table_names: List[str] = [] + self._pooling_type_to_rs_features: Dict[str, List[str]] = defaultdict(list) + self._table_name_to_config: Dict[str, EmbeddingBagConfig] = {} + + for config in self._embedding_bag_configs: + self._table_names.append(config.name) + self._table_name_to_config[config.name] = config + + if table_name_to_parameter_sharding[config.name].sharding_type in [ + ShardingType.TABLE_ROW_WISE.value, + ShardingType.ROW_WISE.value, + ]: + self._pooling_type_to_rs_features[config.pooling.value].extend( + config.feature_names + ) + + self.module_sharding_plan: EmbeddingModuleShardingPlan = cast( + EmbeddingModuleShardingPlan, + { + table_name: parameter_sharding + for table_name, parameter_sharding in table_name_to_parameter_sharding.items() + if table_name in self._table_names + }, + ) + self._env = env + # output parameters as DTensor in state dict + self._output_dtensor: bool = env.output_dtensor + self.sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = ( + self.create_grouped_sharding_infos( + module, + table_name_to_parameter_sharding, + "embedding_bags.", + fused_params, + ) + ) + self._sharding_types: List[str] = list( + self.sharding_type_to_sharding_infos.keys() + ) + self._embedding_shardings: List[ EmbeddingSharding[ EmbeddingShardingContext, - SparseFeatures, + KeyedJaggedTensor, torch.Tensor, torch.Tensor, - ], - ] = { - sharding_type: create_embedding_bag_sharding( - sharding_type, + ] + ] = [ + self.create_embedding_bag_sharding( embedding_configs, env, device, permute_embeddings=True, - need_pos=need_pos, qcomm_codecs_registry=self.qcomm_codecs_registry, - variable_batch_size=variable_batch_size, ) - for sharding_type, embedding_configs in sharding_type_to_sharding_infos.items() - } + for embedding_configs in self.sharding_type_to_sharding_infos.values() + ] self._is_weighted: bool = module.is_weighted() self._device = device - self._input_dists = nn.ModuleList() - self._lookups: nn.ModuleList = nn.ModuleList() + self._input_dists: List[nn.Module] = [] + self._lookups: List[nn.Module] = [] self._create_lookups() - self._output_dists: nn.ModuleList = nn.ModuleList() + self._output_dists: List[nn.Module] = [] self._embedding_names: List[str] = [] self._embedding_dims: List[int] = [] self._feature_splits: List[int] = [] self._features_order: List[int] = [] + self._uncombined_embedding_names: List[str] = [] + self._uncombined_embedding_dims: List[int] = [] + self._inverse_indices_permute_indices: Optional[torch.Tensor] = None + # to support mean pooling callback hook + self._has_mean_pooling_callback: bool = ( + PoolingType.MEAN.value in self._pooling_type_to_rs_features + ) + self._dim_per_key: Optional[torch.Tensor] = None + self._kjt_key_indices: Dict[str, int] = {} + self._kjt_inverse_order: Optional[torch.Tensor] = None + 
self._kt_key_ordering: Optional[torch.Tensor] = None # to support the FP16 hook self._create_output_dist() @@ -356,36 +534,598 @@ def __init__( # Get all fused optimizers and combine them. optims = [] for lookup in self._lookups: - for _, module in lookup.named_modules(): - if isinstance(module, FusedOptimizerModule): + for _, tbe_module in lookup.named_modules(): + if isinstance(tbe_module, FusedOptimizerModule): # modify param keys to match EmbeddingBagCollection params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {} - for param_key, weight in module.fused_optimizer.params.items(): - # pyre-fixme[16]: `Mapping` has no attribute `__setitem__`. + for param_key, weight in tbe_module.fused_optimizer.params.items(): + # pyre-fixme[16]: `Mapping` has no attribute `__setitem__` params["embedding_bags." + param_key] = weight - module.fused_optimizer.params = params - optims.append(("", module.fused_optimizer)) + tbe_module.fused_optimizer.params = params + optims.append(("", tbe_module.fused_optimizer)) self._optim: CombinedOptimizer = CombinedOptimizer(optims) + for i, (sharding, lookup) in enumerate( + zip(self._embedding_shardings, self._lookups) + ): + # TODO: can move this into DpPooledEmbeddingSharding once all modules are composable + if isinstance(sharding, DpPooledEmbeddingSharding): + self._lookups[i] = DistributedDataParallel( + module=lookup, + device_ids=( + [self._device] + if self._device is not None + and (self._device.type in {"cuda", "mtia"}) + else None + ), + process_group=env.process_group, + gradient_as_bucket_view=True, + broadcast_buffers=True, + static_graph=True, + ) + + if env.process_group and dist.get_backend(env.process_group) != "fake": + self._initialize_torch_state() + + if module.device not in ["meta", "cpu"] and module.device.type not in [ + "meta", + "cpu", + ]: + self.load_state_dict(module.state_dict(), strict=False) + + @classmethod + def create_grouped_sharding_infos( + cls, + module: EmbeddingBagCollectionInterface, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + prefix: str, + fused_params: Optional[Dict[str, Any]], + suffix: Optional[str] = "weight", + ) -> Dict[str, List[EmbeddingShardingInfo]]: + """ + convert ParameterSharding (table_name_to_parameter_sharding: Dict[str, ParameterSharding]) to + EmbeddingShardingInfo that are grouped by sharding_type, and propagate the configs/parameters + """ + + if fused_params is None: + fused_params = {} + + shared_feature: Dict[str, bool] = {} + for embedding_config in module.embedding_bag_configs(): + if not embedding_config.feature_names: + embedding_config.feature_names = [embedding_config.name] + for feature_name in embedding_config.feature_names: + if feature_name not in shared_feature: + shared_feature[feature_name] = False + else: + shared_feature[feature_name] = True + + sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = ( + defaultdict(list) + ) + + # state_dict returns parameter.Tensor, which loses parameter level attributes + parameter_by_name = dict(module.named_parameters()) + # QuantEBC registers weights as buffers (since they are INT8), and so we need to grab it there + state_dict = module.state_dict() + + for config in module.embedding_bag_configs(): + table_name = config.name + assert ( + table_name in table_name_to_parameter_sharding + ), f"{table_name} not in table_name_to_parameter_sharding" + parameter_sharding = table_name_to_parameter_sharding[table_name] + if parameter_sharding.compute_kernel not in [ + kernel.value for kernel in 
EmbeddingComputeKernel + ]: + raise ValueError( + f"Compute kernel not supported {parameter_sharding.compute_kernel}" + ) + embedding_names: List[str] = [] + for feature_name in config.feature_names: + if shared_feature[feature_name]: + embedding_names.append(feature_name + "@" + config.name) + else: + embedding_names.append(feature_name) + + param_name = prefix + table_name + if suffix is not None: + param_name = f"{param_name}.{suffix}" + + assert param_name in parameter_by_name or param_name in state_dict + param = parameter_by_name.get(param_name, state_dict[param_name]) + + optimizer_params = getattr(param, "_optimizer_kwargs", [{}]) + optimizer_classes = getattr(param, "_optimizer_classes", [None]) + + assert ( + len(optimizer_classes) == 1 and len(optimizer_params) == 1 + ), f"Only support 1 optimizer, given {len(optimizer_classes)} optimizer classes \ + and {len(optimizer_params)} optimizer kwargs." + + optimizer_class = optimizer_classes[0] + optimizer_params = optimizer_params[0] + if optimizer_class: + optimizer_params["optimizer"] = optimizer_type_to_emb_opt_type( + optimizer_class + ) + + per_table_fused_params = merge_fused_params(fused_params, optimizer_params) + per_table_fused_params = add_params_from_parameter_sharding( + per_table_fused_params, parameter_sharding + ) + per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params) + + sharding_info = EmbeddingShardingInfo( + embedding_config=EmbeddingTableConfig( + num_embeddings=config.num_embeddings, + embedding_dim=config.embedding_dim, + name=config.name, + data_type=config.data_type, + feature_names=copy.deepcopy(config.feature_names), + pooling=config.pooling, + is_weighted=module.is_weighted(), + has_feature_processor=False, + embedding_names=embedding_names, + weight_init_max=config.weight_init_max, + weight_init_min=config.weight_init_min, + num_embeddings_post_pruning=( + getattr(config, "num_embeddings_post_pruning", None) + # TODO: Need to check if attribute exists for BC + ), + ), + param_sharding=parameter_sharding, + param=param, + fused_params=per_table_fused_params, + ) + sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append( + sharding_info + ) + return sharding_type_to_sharding_infos + + @classmethod + def create_embedding_bag_sharding( + cls, + sharding_infos: List[EmbeddingShardingInfo], + env: ShardingEnv, + device: Optional[torch.device] = None, + permute_embeddings: bool = False, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> EmbeddingSharding[ + EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor + ]: + """ + This is the main function to generate `EmbeddingSharding` instances based on sharding_type + so that the same sharding_type in one EBC would be fused. 
+ """ + sharding_type = sharding_infos[0].param_sharding.sharding_type + + if device is not None and device.type == "meta": + replace_placement_with_meta_device(sharding_infos) + if sharding_type == ShardingType.TABLE_WISE.value: + return TwPooledEmbeddingSharding( + sharding_infos, + env, + device, + qcomm_codecs_registry=qcomm_codecs_registry, + ) + elif sharding_type == ShardingType.ROW_WISE.value: + return RwPooledEmbeddingSharding( + sharding_infos, + env, + device, + qcomm_codecs_registry=qcomm_codecs_registry, + ) + elif sharding_type == ShardingType.DATA_PARALLEL.value: + return DpPooledEmbeddingSharding(sharding_infos, env, device) + elif sharding_type == ShardingType.TABLE_ROW_WISE.value: + return TwRwPooledEmbeddingSharding( + sharding_infos, + env, + device, + qcomm_codecs_registry=qcomm_codecs_registry, + ) + elif sharding_type == ShardingType.COLUMN_WISE.value: + return CwPooledEmbeddingSharding( + sharding_infos, + env, + device, + permute_embeddings=permute_embeddings, + qcomm_codecs_registry=qcomm_codecs_registry, + ) + elif sharding_type == ShardingType.TABLE_COLUMN_WISE.value: + return TwCwPooledEmbeddingSharding( + sharding_infos, + env, + device, + permute_embeddings=permute_embeddings, + qcomm_codecs_registry=qcomm_codecs_registry, + ) + elif sharding_type == ShardingType.GRID_SHARD.value: + return GridPooledEmbeddingSharding( + sharding_infos, + env, + device, + qcomm_codecs_registry=qcomm_codecs_registry, + ) + else: + raise ValueError(f"Sharding type not supported {sharding_type}") + + @staticmethod + def _pre_state_dict_hook( + self: "ShardedEmbeddingBagCollection", + prefix: str = "", + keep_vars: bool = False, + ) -> None: + for lookup in self._lookups: + while isinstance(lookup, DistributedDataParallel): + lookup = lookup.module + lookup.flush() + + @staticmethod + def _pre_load_state_dict_hook( + self: "ShardedEmbeddingBagCollection", + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> None: + """ + Modify the destination state_dict for model parallel + to transform from ShardedTensors/DTensors into tensors + """ + for table_name in self._model_parallel_name_to_local_shards.keys(): + key = f"{prefix}embedding_bags.{table_name}.weight" + # gather model shards from both DTensor and ShardedTensor maps + model_shards_sharded_tensor = self._model_parallel_name_to_local_shards[ + table_name + ] + model_shards_dtensor = self._model_parallel_name_to_shards_wrapper[ + table_name + ] + # If state_dict[key] is already a ShardedTensor, use its local shards + if isinstance(state_dict[key], ShardedTensor): + local_shards = state_dict[key].local_shards() + if len(local_shards) == 0: + state_dict[key] = torch.empty(0) + else: + dim = state_dict[key].metadata().shards_metadata[0].shard_sizes[1] + # CW multiple shards are merged + if len(local_shards) > 1: + state_dict[key] = torch.cat( + [s.tensor.view(-1) for s in local_shards], dim=0 + ).view(-1, dim) + else: + state_dict[key] = local_shards[0].tensor.view(-1, dim) + elif isinstance(state_dict[key], DTensor): + shards_wrapper = state_dict[key].to_local() + local_shards = shards_wrapper.local_shards() + if len(local_shards) == 0: + state_dict[key] = torch.empty(0) + else: + dim = shards_wrapper.local_sizes()[0][1] + # CW multiple shards are merged + if len(local_shards) > 1: + state_dict[key] = torch.cat( + [s.view(-1) for s in local_shards], dim=0 + ).view(-1, dim) + else: + state_dict[key] = local_shards[0].view(-1, dim) + elif isinstance(state_dict[key], torch.Tensor): + local_shards = [] + if 
model_shards_sharded_tensor: + # splice according to sharded tensor metadata + for shard in model_shards_sharded_tensor: + # Extract shard size and offsets for splicing + shard_size = shard.metadata.shard_sizes + shard_offset = shard.metadata.shard_offsets + + # Prepare tensor by splicing and placing on appropriate device + spliced_tensor = state_dict[key][ + shard_offset[0] : shard_offset[0] + shard_size[0], + shard_offset[1] : shard_offset[1] + shard_size[1], + ] + + # Append spliced tensor into local shards + local_shards.append(spliced_tensor) + elif model_shards_dtensor: + # splice according to dtensor metadata + for tensor, shard_offset in zip( + model_shards_dtensor["local_tensors"], + model_shards_dtensor["local_offsets"], + ): + shard_size = tensor.size() + spliced_tensor = state_dict[key][ + shard_offset[0] : shard_offset[0] + shard_size[0], + shard_offset[1] : shard_offset[1] + shard_size[1], + ] + local_shards.append(spliced_tensor) + state_dict[key] = ( + torch.empty(0) + if not local_shards + else torch.cat(local_shards, dim=0) + ) + else: + raise RuntimeError( + f"Unexpected state_dict key type {type(state_dict[key])} found for {key}" + ) + + for lookup in self._lookups: + while isinstance(lookup, DistributedDataParallel): + lookup = lookup.module + lookup.purge() + + def _initialize_torch_state(self, skip_registering: bool = False) -> None: # noqa + """ + This provides consistency between this class and the EmbeddingBagCollection's + nn.Module API calls (state_dict, named_modules, etc) + + Args: + skip_registering (bool): If True, skips registering state_dict hooks. This is useful + for dynamic sharding where the state_dict hooks do not need to be + reregistered when being resharded. Default is False. + + """ + self.embedding_bags: nn.ModuleDict = nn.ModuleDict() + for table_name in self._table_names: + self.embedding_bags[table_name] = nn.Module() + + self._model_parallel_name_to_local_shards = OrderedDict() + self._model_parallel_name_to_shards_wrapper = OrderedDict() + self._model_parallel_name_to_sharded_tensor = OrderedDict() + self._model_parallel_name_to_dtensor = OrderedDict() + + _model_parallel_name_to_compute_kernel: Dict[str, str] = {} + for ( + table_name, + parameter_sharding, + ) in self.module_sharding_plan.items(): + if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value: + continue + _model_parallel_name_to_compute_kernel[table_name] = ( + parameter_sharding.compute_kernel + ) + if ( + parameter_sharding.compute_kernel + == EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value + ): + # Skip state_dict handling for CUSTOMIZED_KERNEL, this should be implemented + # in child class for the CUSTOMIZED_KERNEL + continue + self._model_parallel_name_to_local_shards[table_name] = [] + self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict( + [("local_tensors", []), ("local_offsets", [])] + ) + + self._name_to_table_size = {} + for table in self._embedding_bag_configs: + self._name_to_table_size[table.name] = ( + table.num_embeddings, + table.embedding_dim, + ) + + for lookup, sharding in zip(self._lookups, self._embedding_shardings): + if isinstance(sharding, DpPooledEmbeddingSharding): + # unwrap DDP + lookup = lookup.module + else: + # save local_shards for transforming MP params to DTensor + for key, v in lookup.state_dict().items(): + table_name = key[: -len(".weight")] + if ( + _model_parallel_name_to_compute_kernel[table_name] + == EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value + ): + continue + if isinstance(v, DTensor): + 
shards_wrapper = self._model_parallel_name_to_shards_wrapper[ + table_name + ] + local_shards_wrapper = v._local_tensor + shards_wrapper["local_tensors"].extend( + # pyre-ignore[16] + local_shards_wrapper.local_shards() + ) + shards_wrapper["local_offsets"].extend( + # pyre-ignore[16] + local_shards_wrapper.local_offsets() + ) + shards_wrapper["global_size"] = v.size() + shards_wrapper["global_stride"] = v.stride() + shards_wrapper["placements"] = v.placements + elif isinstance(v, ShardedTensor): + self._model_parallel_name_to_local_shards[table_name].extend( + v.local_shards() + ) + for ( + table_name, + tbe_slice, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `named_parameters_by_table`. + ) in lookup.named_parameters_by_table(): + self.embedding_bags[table_name].register_parameter("weight", tbe_slice) + + for table_name in self._model_parallel_name_to_local_shards.keys(): + local_shards = self._model_parallel_name_to_local_shards[table_name] + shards_wrapper_map = self._model_parallel_name_to_shards_wrapper[table_name] + # for shards that don't exist on this rank, register with empty tensor + if not hasattr(self.embedding_bags[table_name], "weight"): + self.embedding_bags[table_name].register_parameter( + "weight", nn.Parameter(torch.empty(0)) + ) + if ( + _model_parallel_name_to_compute_kernel[table_name] + != EmbeddingComputeKernel.DENSE.value + ): + self.embedding_bags[table_name].weight._in_backward_optimizers = [ + EmptyFusedOptimizer() + ] + + if self._output_dtensor: + assert _model_parallel_name_to_compute_kernel[table_name] not in { + EmbeddingComputeKernel.KEY_VALUE.value + } + if shards_wrapper_map["local_tensors"]: + self._model_parallel_name_to_dtensor[table_name] = ( + DTensor.from_local( + local_tensor=LocalShardsWrapper( + local_shards=shards_wrapper_map["local_tensors"], + local_offsets=shards_wrapper_map["local_offsets"], + ), + device_mesh=self._env.device_mesh, + placements=shards_wrapper_map["placements"], + shape=shards_wrapper_map["global_size"], + stride=shards_wrapper_map["global_stride"], + run_check=False, + ) + ) + else: + shape, stride = create_global_tensor_shape_stride_from_metadata( + none_throws(self.module_sharding_plan[table_name]), + ( + self._env.node_group_size + if isinstance(self._env, ShardingEnv2D) + else get_local_size(self._env.world_size) + ), + ) + # empty shard case + self._model_parallel_name_to_dtensor[table_name] = ( + DTensor.from_local( + local_tensor=LocalShardsWrapper( + local_shards=[], + local_offsets=[], + ), + device_mesh=self._env.device_mesh, + run_check=False, + shape=shape, + stride=stride, + ) + ) + else: + # created ShardedTensors once in init, use in post_state_dict_hook + # note: at this point kvstore backed tensors don't own valid snapshots, so no read + # access is allowed on them. 
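A self-contained sketch of the splicing that _pre_load_state_dict_hook above performs when it receives a plain (unsharded) tensor: the full weight is cut into this rank's local shards using the shard offsets and sizes from the sharding metadata, then concatenated. The offsets and sizes below are made-up examples standing in for the real ShardMetadata fields.

import torch

global_weight = torch.arange(20.0).view(10, 2)  # pretend full 10 x 2 table

# e.g. this rank owns rows 0-3 and rows 8-9 (column offset stays 0 here)
local_shard_meta = [
    {"offsets": (0, 0), "sizes": (4, 2)},
    {"offsets": (8, 0), "sizes": (2, 2)},
]

local_shards = []
for meta in local_shard_meta:
    row_off, col_off = meta["offsets"]
    rows, cols = meta["sizes"]
    # slice the rank-local block out of the full weight
    local_shards.append(
        global_weight[row_off : row_off + rows, col_off : col_off + cols]
    )

# the hook then replaces the state_dict entry with the concatenated shards
spliced = torch.cat(local_shards, dim=0)
assert spliced.shape == (6, 2)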
+ + # create ShardedTensor from local shards and metadata avoding all_gather collective + sharding_spec = none_throws( + self.module_sharding_plan[table_name].sharding_spec + ) + + tensor_properties = TensorProperties( + dtype=( + data_type_to_dtype( + self._table_name_to_config[table_name].data_type + ) + ), + ) + + self._model_parallel_name_to_sharded_tensor[table_name] = ( + ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=local_shards, + sharded_tensor_metadata=sharding_spec.build_metadata( + tensor_sizes=self._name_to_table_size[table_name], + tensor_properties=tensor_properties, + ), + process_group=( + self._env.sharding_pg + if isinstance(self._env, ShardingEnv2D) + else self._env.process_group + ), + ) + ) + + def extract_sharded_kvtensors( + module: ShardedEmbeddingBagCollection, + ) -> OrderedDict[str, ShardedTensor]: + # retrieve all kvstore backed tensors + ret = OrderedDict() + for ( + table_name, + sharded_t, + ) in module._model_parallel_name_to_sharded_tensor.items(): + if _model_parallel_name_to_compute_kernel[table_name] in { + EmbeddingComputeKernel.KEY_VALUE.value + }: + ret[table_name] = sharded_t + return ret + + def post_state_dict_hook( + module: ShardedEmbeddingBagCollection, + destination: Dict[str, torch.Tensor], + prefix: str, + _local_metadata: Dict[str, Any], + ) -> None: + # Adjust dense MP + for ( + table_name, + sharded_t, + ) in module._model_parallel_name_to_sharded_tensor.items(): + destination_key = f"{prefix}embedding_bags.{table_name}.weight" + destination[destination_key] = sharded_t + for ( + table_name, + d_tensor, + ) in module._model_parallel_name_to_dtensor.items(): + destination_key = f"{prefix}embedding_bags.{table_name}.weight" + destination[destination_key] = d_tensor + + # kvstore backed tensors do not have a valid backing snapshot at this point. Fill in a valid + # snapshot for read access. + sharded_kvtensors = extract_sharded_kvtensors(module) + if len(sharded_kvtensors) == 0: + return + + sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors) + for lookup, sharding in zip(module._lookups, module._embedding_shardings): + if not isinstance(sharding, DpPooledEmbeddingSharding): + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. 
+ for key, v in lookup.get_named_split_embedding_weights_snapshot(): + assert key in sharded_kvtensors_copy + sharded_kvtensors_copy[key].local_shards()[0].tensor = v + for ( + table_name, + sharded_kvtensor, + ) in sharded_kvtensors_copy.items(): + destination_key = f"{prefix}embedding_bags.{table_name}.weight" + destination[destination_key] = sharded_kvtensor + + if not skip_registering: + self.register_state_dict_pre_hook(self._pre_state_dict_hook) + self._register_state_dict_hook(post_state_dict_hook) + self._register_load_state_dict_pre_hook( + self._pre_load_state_dict_hook, with_module=True + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self._device and self._device.type == "meta": + return + + # Initialize embedding bags weights with init_fn + for table_config in self._embedding_bag_configs: + if self.module_sharding_plan[table_config.name].compute_kernel in { + EmbeddingComputeKernel.KEY_VALUE.value, + }: + continue + assert table_config.init_fn is not None + param = self.embedding_bags[f"{table_config.name}"].weight + # pyre-ignore + table_config.init_fn(param) + + sharding_type = self.module_sharding_plan[table_config.name].sharding_type + if sharding_type == ShardingType.DATA_PARALLEL.value: + pg = self._env.process_group + with torch.no_grad(): + dist.broadcast(param.data, src=0, group=pg) + def _create_input_dist( self, input_feature_names: List[str], ) -> None: feature_names: List[str] = [] - for sharding in self._sharding_type_to_sharding.values(): + for sharding in self._embedding_shardings: self._input_dists.append(sharding.create_input_dist()) - feature_names.extend( - sharding.id_score_list_feature_names() - if self._is_weighted - else sharding.id_list_feature_names() - ) - self._feature_splits.append( - len( - sharding.id_score_list_feature_names() - if self._is_weighted - else sharding.id_list_feature_names() - ) - ) + feature_names.extend(sharding.feature_names()) + self._feature_splits.append(len(sharding.feature_names())) if feature_names == input_feature_names: self._has_features_permute = False @@ -397,79 +1137,228 @@ def _create_input_dist( torch.tensor( self._features_order, device=self._device, dtype=torch.int32 ), + persistent=False, + ) + + def _init_mean_pooling_callback( + self, + input_feature_names: List[str], + inverse_indices: Optional[Tuple[List[str], torch.Tensor]], + ) -> None: + # account for shared features + feature_names: List[str] = [ + feature_name + for sharding in self._embedding_shardings + for feature_name in sharding.feature_names() + ] + + for i, key in enumerate(feature_names): + if key not in self._kjt_key_indices: # index of first occurence + self._kjt_key_indices[key] = i + + keyed_tensor_ordering = [] + for key in self._embedding_names: + if "@" in key: + key = key.split("@")[0] + keyed_tensor_ordering.append(self._kjt_key_indices[key]) + self._kt_key_ordering = torch.tensor(keyed_tensor_ordering, device=self._device) + + if inverse_indices: + key_to_inverse_index = { + name: i for i, name in enumerate(inverse_indices[0]) + } + self._kjt_inverse_order = torch.tensor( + [key_to_inverse_index[key] for key in feature_names], + device=self._device, ) def _create_lookups( self, ) -> None: - for sharding in self._sharding_type_to_sharding.values(): + for sharding in self._embedding_shardings: self._lookups.append(sharding.create_lookup()) def _create_output_dist(self) -> None: - for sharding in self._sharding_type_to_sharding.values(): + embedding_shard_metadata: List[Optional[ShardMetadata]] = [] + for 
sharding in self._embedding_shardings: self._output_dists.append(sharding.create_output_dist(device=self._device)) self._embedding_names.extend(sharding.embedding_names()) self._embedding_dims.extend(sharding.embedding_dims()) + self._uncombined_embedding_names.extend( + sharding.uncombined_embedding_names() + ) + self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims()) + embedding_shard_metadata.extend(sharding.embedding_shard_metadata()) + self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device) + + embedding_shard_offsets: List[int] = [ + meta.shard_offsets[1] if meta is not None else 0 + for meta in embedding_shard_metadata + ] + embedding_name_order: Dict[str, int] = {} + for i, name in enumerate(self._uncombined_embedding_names): + embedding_name_order.setdefault(name, i) + + permute_indices = sorted( + range(len(self._uncombined_embedding_names)), + key=lambda i: ( + embedding_name_order[self._uncombined_embedding_names[i]], + embedding_shard_offsets[i], + ), + ) + + self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings( + self._uncombined_embedding_dims, permute_indices, self._device + ) + + def _update_output_dist(self) -> None: + """ + This function is only used in update update_shards + """ + embedding_shard_metadata: List[Optional[ShardMetadata]] = [] + # TODO: Optimize to only go through embedding shardings with new ranks + self._output_dists: List[nn.Module] = [] + self._embedding_names: List[str] = [] + self._embedding_dims: List[int] = [] + self._uncombined_embedding_names: List[str] = [] + self._uncombined_embedding_dims: List[int] = [] + for sharding in self._embedding_shardings: + # TODO: if sharding type of table completely changes, need to regenerate everything + self._embedding_names.extend(sharding.embedding_names()) + self._output_dists.append(sharding.create_output_dist(device=self._device)) + embedding_shard_metadata.extend(sharding.embedding_shard_metadata()) + self._embedding_dims.extend(sharding.embedding_dims()) + self._uncombined_embedding_names.extend( + sharding.uncombined_embedding_names() + ) + self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims()) + + embedding_shard_offsets: List[int] = [ + meta.shard_offsets[1] if meta is not None else 0 + for meta in embedding_shard_metadata + ] + embedding_name_order: Dict[str, int] = {} + for i, name in enumerate(self._uncombined_embedding_names): + embedding_name_order.setdefault(name, i) + + permute_indices = sorted( + range(len(self._uncombined_embedding_names)), + key=lambda i: ( + embedding_name_order[self._uncombined_embedding_names[i]], + embedding_shard_offsets[i], + ), + ) + + self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings( + self._uncombined_embedding_dims, permute_indices, self._device + ) + + def _create_inverse_indices_permute_indices( + self, inverse_indices: Optional[Tuple[List[str], torch.Tensor]] + ) -> None: + assert ( + inverse_indices is not None + ), "inverse indices must be provided from KJT if using variable batch size per feature." 
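
The permutation computed above in _create_output_dist/_update_output_dist reorders shard outputs so that all column-wise shards of the same table sit next to each other, ordered first by the table's first appearance and then by shard column offset; the resulting index list is what gets handed to PermutePooledEmbeddings together with the uncombined dims. A toy re-derivation (plain Python; the table names and offsets below are made up):

# Two tables, each split into column-wise shards that arrive interleaved.
uncombined_names = ["t0", "t1", "t0", "t1"]
shard_offsets = [0, 0, 64, 128]  # column offset of each shard (0 when metadata is None)

name_order = {}
for i, name in enumerate(uncombined_names):
    name_order.setdefault(name, i)  # first-occurrence order per table

permute_indices = sorted(
    range(len(uncombined_names)),
    key=lambda i: (name_order[uncombined_names[i]], shard_offsets[i]),
)
assert permute_indices == [0, 2, 1, 3]  # t0's shards (by offset), then t1's
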
+ index_per_name = {name: i for i, name in enumerate(inverse_indices[0])} + permute_indices = [ + index_per_name[name.split("@")[0]] + for name in self._uncombined_embedding_names + ] + if len(permute_indices) != len(index_per_name) or permute_indices != sorted( + permute_indices + ): + self._inverse_indices_permute_indices = _pin_and_move( + torch.tensor(permute_indices), + inverse_indices[1].device, + ) # pyre-ignore [14] def input_dist( - self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor - ) -> Awaitable[SparseFeaturesList]: + self, + ctx: EmbeddingBagCollectionContext, + features: Union[KeyedJaggedTensor, TensorDict], + ) -> Awaitable[Awaitable[KJTList]]: + """ + This is the main API called in train_pipeline where we want to do the input_dist + in advance + """ + if isinstance(features, TensorDict): + feature_keys = list(features.keys()) # pyre-ignore[6] + if len(self._features_order) > 0: + feature_keys = [feature_keys[i] for i in self._features_order] + self._has_features_permute = False # feature_keys are in order + features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6] + ctx.variable_batch_per_feature = features.variable_stride_per_key() + ctx.inverse_indices = features.inverse_indices_or_none() if self._has_uninitialized_input_dist: self._create_input_dist(features.keys()) self._has_uninitialized_input_dist = False + if ctx.variable_batch_per_feature: + self._create_inverse_indices_permute_indices(ctx.inverse_indices) + if self._has_mean_pooling_callback: + self._init_mean_pooling_callback(features.keys(), ctx.inverse_indices) with torch.no_grad(): if self._has_features_permute: features = features.permute( self._features_order, - # pyre-ignore [6] + # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]` + # but got `Union[Module, Tensor]`. 
self._features_order_tensor, ) + if self._has_mean_pooling_callback: + ctx.divisor = _create_mean_pooling_divisor( + lengths=features.lengths(), + stride=features.stride(), + keys=features.keys(), + offsets=features.offsets(), + pooling_type_to_rs_features=self._pooling_type_to_rs_features, + stride_per_key=features.stride_per_key(), + dim_per_key=self._dim_per_key, # pyre-ignore[6] + embedding_names=self._embedding_names, + embedding_dims=self._embedding_dims, + variable_batch_per_feature=ctx.variable_batch_per_feature, + kjt_inverse_order=self._kjt_inverse_order, # pyre-ignore[6] + kjt_key_indices=self._kjt_key_indices, + kt_key_ordering=self._kt_key_ordering, # pyre-ignore[6] + inverse_indices=ctx.inverse_indices, + weights=features.weights_or_none(), + ) + features_by_shards = features.split( self._feature_splits, ) awaitables = [] - for module, features_by_shard in zip(self._input_dists, features_by_shards): - lengths_awaitable = module( - SparseFeatures( - id_list_features=None - if self._is_weighted - else features_by_shard, - id_score_list_features=features_by_shard - if self._is_weighted - else None, + for input_dist, features_by_shard, sharding_type in zip( + self._input_dists, + features_by_shards, + self._sharding_types, + ): + with maybe_annotate_embedding_event( + EmbeddingEvent.KJT_SPLITS_DIST, + self._module_fqn, + sharding_type, + ): + awaitables.append(input_dist(features_by_shard)) + + ctx.sharding_contexts.append( + EmbeddingShardingContext( + batch_size_per_feature_pre_a2a=features_by_shard.stride_per_key(), + variable_batch_per_feature=features_by_shard.variable_stride_per_key(), ) ) - indices_awaitable = lengths_awaitable.wait() - if isinstance(indices_awaitable, SparseFeaturesIndicesAwaitable): - if indices_awaitable._id_list_features_awaitable is not None: - batch_size_per_rank = ( - # pyre-fixme[16] - indices_awaitable._id_list_features_awaitable._batch_size_per_rank - ) - elif ( - indices_awaitable._id_score_list_features_awaitable is not None - ): - batch_size_per_rank = ( - indices_awaitable._id_score_list_features_awaitable._batch_size_per_rank - ) - else: - batch_size_per_rank = [] - ctx.sharding_contexts.append( - EmbeddingShardingContext( - batch_size_per_rank=batch_size_per_rank, - ) - ) - else: - ctx.sharding_contexts.append(None) - awaitables.append(indices_awaitable) - return SparseFeaturesListAwaitable(awaitables) + return KJTListSplitsAwaitable( + awaitables, ctx, self._module_fqn, self._sharding_types + ) def compute( self, ctx: EmbeddingBagCollectionContext, - dist_input: SparseFeaturesList, + dist_input: KJTList, ) -> List[torch.Tensor]: + """ + this function is not used in general practice, it's only called by the base class + ShardedModule.compute_and_output_dist to do the basic function + """ return [lookup(features) for lookup, features in zip(self._lookups, dist_input)] def output_dist( @@ -477,117 +1366,241 @@ def output_dist( ctx: EmbeddingBagCollectionContext, output: List[torch.Tensor], ) -> LazyAwaitable[KeyedTensor]: - return EmbeddingBagCollectionAwaitable( - awaitables=[ - dist(embeddings, sharding_ctx) - for dist, sharding_ctx, embeddings in zip( - self._output_dists, - ctx.sharding_contexts, - output, + batch_size_per_feature_pre_a2a = [] + awaitables = [] + for dist, sharding_context, embeddings in zip( + self._output_dists, + ctx.sharding_contexts, + output, + ): + awaitables.append(dist(embeddings, sharding_context)) + if sharding_context: + batch_size_per_feature_pre_a2a.extend( + sharding_context.batch_size_per_feature_pre_a2a ) 
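
ctx.divisor above comes from _create_mean_pooling_divisor (added near the end of this file's diff): per-feature bag lengths are expanded across each feature's embedding dimension so the SUM-pooled lookup output can later be turned into a MEAN for features that requested it (features that genuinely use SUM pooling have their lengths forced to 1). A toy sketch of the core repeat_interleave step, with illustrative shapes and values:

import torch

batch_size = 2
dim_per_key = torch.tensor([3, 2])   # feature f0 -> dim 3, feature f1 -> dim 2
lengths = torch.tensor([[2.0, 4.0],  # shape [batch_size, num_features]
                        [1.0, 5.0]])

divisor = torch.repeat_interleave(
    input=lengths, repeats=dim_per_key, dim=1, output_size=int(dim_per_key.sum())
)
# divisor[0] == [2, 2, 2, 4, 4]; the real helper adds a small eps before returning,
# and _apply_mean_pooling divides the concatenated pooled embeddings by it in a callback.
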
- ], - embedding_dims=self._embedding_dims, - embedding_names=self._embedding_names, - ) + + if ctx.variable_batch_per_feature: + assert ( + ctx.inverse_indices is not None + ), "inverse indices must be provided from KJT if using variable batch size per feature." + awaitable = VariableBatchEmbeddingBagCollectionAwaitable( + awaitables=awaitables, + inverse_indices=ctx.inverse_indices, + inverse_indices_permute_indices=self._inverse_indices_permute_indices, + batch_size_per_feature_pre_a2a=batch_size_per_feature_pre_a2a, + uncombined_embedding_dims=self._uncombined_embedding_dims, + embedding_names=self._embedding_names, + embedding_dims=self._embedding_dims, + permute_op=self._permute_op, + ) + else: + awaitable = EmbeddingBagCollectionAwaitable( + awaitables=awaitables, + embedding_dims=self._embedding_dims, + embedding_names=self._embedding_names, + ) + + # register callback if there are features that need mean pooling + if self._has_mean_pooling_callback: + awaitable.callbacks.append( + partial(_apply_mean_pooling, divisor=ctx.divisor) + ) + + return awaitable def compute_and_output_dist( - self, ctx: EmbeddingBagCollectionContext, input: SparseFeaturesList + self, ctx: EmbeddingBagCollectionContext, input: KJTList ) -> LazyAwaitable[KeyedTensor]: - return EmbeddingBagCollectionAwaitable( - awaitables=[ - dist(lookup(features), sharding_ctx) - for lookup, dist, sharding_ctx, features in zip( - self._lookups, - self._output_dists, - ctx.sharding_contexts, - input, - ) - ], - embedding_dims=self._embedding_dims, - embedding_names=self._embedding_names, - ) + """ + the main API called in PipelineForward, where the shardedEBC's forward is swapped + see _rewrite_model in train_pipeline for details + """ + batch_size_per_feature_pre_a2a = [] + awaitables = [] + + # No usage of zip for dynamo + for i in range(len(self._lookups)): + lookup = self._lookups[i] + dist = self._output_dists[i] + sharding_context = ctx.sharding_contexts[i] + features = input[i] + sharding_type = self._sharding_types[i] + + with maybe_annotate_embedding_event( + EmbeddingEvent.LOOKUP, + self._module_fqn, + sharding_type, + ): + embs = lookup(features) - # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently. 
- def state_dict( - self, - destination: Optional[Dict[str, Any]] = None, - prefix: str = "", - keep_vars: bool = False, - ) -> Dict[str, Any]: - if destination is None: - destination = OrderedDict() - # pyre-ignore [16] - destination._metadata = OrderedDict() - for lookup in self._lookups: - lookup.state_dict(destination, prefix + "embedding_bags.", keep_vars) - return destination + with maybe_annotate_embedding_event( + EmbeddingEvent.OUTPUT_DIST, + self._module_fqn, + sharding_type, + ): + awaitables.append(dist(embs, sharding_context)) - def named_modules( - self, - memo: Optional[Set[nn.Module]] = None, - prefix: str = "", - remove_duplicate: bool = True, - ) -> Iterator[Tuple[str, nn.Module]]: - yield from [(prefix, self)] + if sharding_context: + batch_size_per_feature_pre_a2a.extend( + sharding_context.batch_size_per_feature_pre_a2a + ) - def named_parameters( - self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True - ) -> Iterator[Tuple[str, nn.Parameter]]: - for lookup in self._lookups: - yield from lookup.named_parameters( - append_prefix(prefix, "embedding_bags"), recurse, remove_duplicate + if ctx.variable_batch_per_feature: + assert ( + ctx.inverse_indices is not None + ), "inverse indices must be provided from KJT if using variable batch size per feature." + awaitable = VariableBatchEmbeddingBagCollectionAwaitable( + awaitables=awaitables, + inverse_indices=ctx.inverse_indices, + inverse_indices_permute_indices=self._inverse_indices_permute_indices, + batch_size_per_feature_pre_a2a=batch_size_per_feature_pre_a2a, + uncombined_embedding_dims=self._uncombined_embedding_dims, + embedding_names=self._embedding_names, + embedding_dims=self._embedding_dims, + permute_op=self._permute_op, + module_fqn=self._module_fqn, + sharding_types=self._sharding_types, + ) + else: + awaitable = EmbeddingBagCollectionAwaitable( + awaitables=awaitables, + embedding_dims=self._embedding_dims, + embedding_names=self._embedding_names, + module_fqn=self._module_fqn, + sharding_types=self._sharding_types, ) - def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: - for lookup, sharding_type in zip( - self._lookups, self._sharding_type_to_sharding.keys() - ): - if sharding_type == ShardingType.DATA_PARALLEL.value: - continue - for name, _ in lookup.named_parameters( - append_prefix(prefix, "embedding_bags") - ): - yield name - - def named_buffers( - self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True - ) -> Iterator[Tuple[str, torch.Tensor]]: - for lookup in self._lookups: - yield from lookup.named_buffers( - append_prefix(prefix, "embedding_bags"), recurse, remove_duplicate + # register callback if there are features that need mean pooling + if self._has_mean_pooling_callback: + awaitable.callbacks.append( + partial(_apply_mean_pooling, divisor=ctx.divisor) ) - # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module` - # inconsistently. 
- def load_state_dict( + return awaitable + + def update_shards( self, - state_dict: "OrderedDict[str, torch.Tensor]", - strict: bool = True, - ) -> _IncompatibleKeys: - missing_keys = [] - unexpected_keys = [] - for lookup in self._lookups: - missing, unexpected = lookup.load_state_dict( - filter_state_dict(state_dict, "embedding_bags"), - strict, - ) - missing_keys.extend(missing) - unexpected_keys.extend(unexpected) - return _IncompatibleKeys( - missing_keys=missing_keys, unexpected_keys=unexpected_keys + changed_sharding_params: Dict[str, ParameterSharding], # NOTE: only delta + env: ShardingEnv, + device: Optional[torch.device], + ) -> None: + """ + This is the main API used in sharder.reshard, currently only support redistribution + of existing shards (across different ranks, ideally from hot ranks to cold ranks) + Update shards for this module based on the changed_sharding_params. This will: + 1. Move current lookup tensors to CPU + 2. Purge lookups + 3. Call shards_all_2_all containing collective to redistribute tensors + 4. Update state_dict and other attributes to reflect new placements and shards + 5. Create new lookups, and load in updated state_dict + + Args: + changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping + table names to their new parameter sharding configs. This should only + contain shards/table names that need to be moved. + env (ShardingEnv): The sharding environment for the module. + device (Optional[torch.device]): The device to place the updated module on. + """ + + if env.output_dtensor: + raise RuntimeError("We do not yet support DTensor for resharding yet") + return + + current_state = self.state_dict() + # TODO: Save Optimizers + + saved_weights = {} + # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again + for i, lookup in enumerate(self._lookups): + for attribute, tbe_module in lookup.named_modules(): + if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen: + saved_weights[str(i) + "." 
+ attribute] = tbe_module.weights.cpu() + # Note: lookup.purge should delete tbe_module and weights + # del tbe_module.weights + # del tbe_module + # pyre-ignore + lookup.purge() + + # Deleting all lookups + self._lookups.clear() + + local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all( + module=self, + state_dict=current_state, + device=device, # pyre-ignore + changed_sharding_params=changed_sharding_params, + env=env, + extend_shard_name=self.extend_shard_name, ) - def sparse_grad_parameter_names( - self, - destination: Optional[List[str]] = None, - prefix: str = "", - ) -> List[str]: - destination = [] if destination is None else destination - for lookup in self._lookups: - lookup.sparse_grad_parameter_names( - destination, append_prefix(prefix, "embedding_bags") + current_state = update_state_dict_post_resharding( + state_dict=current_state, + ordered_shard_names_and_lengths=local_shard_names_by_src_rank, + output_tensor=local_output_tensor, + new_sharding_params=changed_sharding_params, + curr_rank=dist.get_rank(), + extend_shard_name=self.extend_shard_name, + ) + + for name, param in changed_sharding_params.items(): + self.module_sharding_plan[name] = param + # TODO: Support detecting old sharding type when sharding type is changing + for sharding_info in self.sharding_type_to_sharding_infos[ + param.sharding_type + ]: + if sharding_info.embedding_config.name == name: + sharding_info.param_sharding = param + + self._sharding_types: List[str] = list( + self.sharding_type_to_sharding_infos.keys() + ) + # TODO: Optimize to update only the changed embedding shardings + self._embedding_shardings: List[ + EmbeddingSharding[ + EmbeddingShardingContext, + KeyedJaggedTensor, + torch.Tensor, + torch.Tensor, + ] + ] = [ + self.create_embedding_bag_sharding( + embedding_configs, + env, + device, + permute_embeddings=True, + qcomm_codecs_registry=self.qcomm_codecs_registry, ) - return destination + for embedding_configs in self.sharding_type_to_sharding_infos.values() + ] + + self._create_lookups() + self._update_output_dist() + + if env.process_group and dist.get_backend(env.process_group) != "fake": + self._initialize_torch_state(skip_registering=True) + + self.load_state_dict(current_state) + + # update optimizer + optims = [] + for lookup in self._lookups: + for _, tbe_module in lookup.named_modules(): + if isinstance(tbe_module, FusedOptimizerModule): + # modify param keys to match EmbeddingBagCollection + params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {} + for ( + param_key, + weight, + ) in tbe_module.fused_optimizer.params.items(): + # pyre-fixme[16]: `Mapping` has no attribute `__setitem__` + params["embedding_bags." 
+ param_key] = weight + tbe_module.fused_optimizer.params = params + optims.append(("", tbe_module.fused_optimizer)) + + self._optim: CombinedOptimizer = CombinedOptimizer(optims) + + update_module_sharding_plan(self, changed_sharding_params) + return @property def fused_optimizer(self) -> KeyedOptimizer: @@ -596,6 +1609,14 @@ def fused_optimizer(self) -> KeyedOptimizer: def create_context(self) -> EmbeddingBagCollectionContext: return EmbeddingBagCollectionContext() + @staticmethod + def extend_shard_name(shard_name: str) -> str: + return f"embedding_bags.{shard_name}.weight" + + @property + def unsharded_module_type(self) -> Type[EmbeddingBagCollection]: + return EmbeddingBagCollection + class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]): """ @@ -608,6 +1629,7 @@ def shard( params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, ) -> ShardedEmbeddingBagCollection: return ShardedEmbeddingBagCollection( module=module, @@ -616,7 +1638,7 @@ def shard( fused_params=self.fused_params, device=device, qcomm_codecs_registry=self.qcomm_codecs_registry, - variable_batch_size=self.variable_batch_size, + module_fqn=module_fqn, ) def shardable_parameters( @@ -627,6 +1649,33 @@ def shardable_parameters( for name, param in module.embedding_bags.named_parameters() } + def reshard( + self, + sharded_module: ShardedEmbeddingBagCollection, + changed_shard_to_params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + ) -> ShardedEmbeddingBagCollection: + """ + Updates the sharded module in place based on the changed_shard_to_params + which contains the new ParameterSharding with different shard placements. + + Args: + sharded_module (ShardedEmbeddingBagCollection): The module to update + changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping + table names to their new parameter sharding configs. 
This should only + contain shards/table names that need to be moved + env (ShardingEnv): The sharding environment + device (Optional[torch.device]): The device to place the updated module on + + Returns: + ShardedEmbeddingBagCollection: The updated sharded module + """ + + if len(changed_shard_to_params) > 0: + sharded_module.update_shards(changed_shard_to_params, env, device) + return sharded_module + @property def module_type(self) -> Type[EmbeddingBagCollection]: return EmbeddingBagCollection @@ -646,7 +1695,9 @@ def _wait_impl(self) -> torch.Tensor: class ShardedEmbeddingBag( - ShardedModule[SparseFeatures, torch.Tensor, torch.Tensor, NullShardedModuleContext], + ShardedEmbeddingModule[ + KeyedJaggedTensor, torch.Tensor, torch.Tensor, NullShardedModuleContext + ], FusedOptimizerModule, ): """ @@ -694,9 +1745,8 @@ def __init__( ) self._embedding_sharding: EmbeddingSharding[ - EmbeddingShardingContext, SparseFeatures, torch.Tensor, torch.Tensor - ] = create_embedding_bag_sharding( - sharding_type=self.parameter_sharding.sharding_type, + EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor + ] = self.create_embedding_bag_sharding( sharding_infos=[ EmbeddingShardingInfo( embedding_config=embedding_table_config, @@ -726,6 +1776,25 @@ def __init__( optims.append(("", module.fused_optimizer)) self._optim: CombinedOptimizer = CombinedOptimizer(optims) + @classmethod + def create_embedding_bag_sharding( + cls, + sharding_infos: List[EmbeddingShardingInfo], + env: ShardingEnv, + device: Optional[torch.device] = None, + permute_embeddings: bool = False, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> EmbeddingSharding[ + EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor + ]: + return ShardedEmbeddingBagCollection.create_embedding_bag_sharding( + sharding_infos=sharding_infos, + env=env, + device=device, + permute_embeddings=permute_embeddings, + qcomm_codecs_registry=qcomm_codecs_registry, + ) + # pyre-ignore [14] def input_dist( self, @@ -733,7 +1802,7 @@ def input_dist( input: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None, - ) -> Awaitable[SparseFeatures]: + ) -> Awaitable[Awaitable[KeyedJaggedTensor]]: if per_sample_weights is None: per_sample_weights = torch.ones_like(input, dtype=torch.float) features = KeyedJaggedTensor( @@ -742,15 +1811,10 @@ def input_dist( offsets=offsets, weights=per_sample_weights, ) - return self._input_dist( - SparseFeatures( - id_list_features=None, - id_score_list_features=features, - ) - ).wait() + return self._input_dist(features) def compute( - self, ctx: NullShardedModuleContext, dist_input: SparseFeatures + self, ctx: NullShardedModuleContext, dist_input: KeyedJaggedTensor ) -> torch.Tensor: return self._lookup(dist_input) @@ -835,20 +1899,6 @@ def load_state_dict( missing_keys=missing_keys, unexpected_keys=unexpected_keys ) - def sparse_grad_parameter_names( - self, - destination: Optional[List[str]] = None, - prefix: str = "", - ) -> List[str]: - destination = [] if destination is None else destination - # pyre-ignore [29] - lookup_sparse_grad_parameter_names = self._lookup.sparse_grad_parameter_names( - None, "" - ) - for name in lookup_sparse_grad_parameter_names: - destination.append(name.split(".")[-1]) - return destination - @property def fused_optimizer(self) -> KeyedOptimizer: return self._optim @@ -856,6 +1906,10 @@ def fused_optimizer(self) -> KeyedOptimizer: def create_context(self) -> NullShardedModuleContext: return 
NullShardedModuleContext() + @property + def unsharded_module_type(self) -> Type[nn.EmbeddingBag]: + return nn.EmbeddingBag + class EmbeddingBagSharder(BaseEmbeddingSharder[nn.EmbeddingBag]): """ @@ -868,6 +1922,7 @@ def shard( params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, ) -> ShardedEmbeddingBag: return ShardedEmbeddingBag(module, params, env, self.fused_params, device) @@ -877,3 +1932,93 @@ def shardable_parameters(self, module: nn.EmbeddingBag) -> Dict[str, nn.Paramete @property def module_type(self) -> Type[nn.EmbeddingBag]: return nn.EmbeddingBag + + +def _create_mean_pooling_divisor( + lengths: torch.Tensor, + keys: List[str], + offsets: torch.Tensor, + stride: int, + stride_per_key: List[int], + dim_per_key: torch.Tensor, + pooling_type_to_rs_features: Dict[str, List[str]], + embedding_names: List[str], + embedding_dims: List[int], + variable_batch_per_feature: bool, + kjt_inverse_order: torch.Tensor, + kjt_key_indices: Dict[str, int], + kt_key_ordering: torch.Tensor, + inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, + weights: Optional[torch.Tensor] = None, +) -> torch.Tensor: + with record_function("## ebc create mean pooling callback ##"): + batch_size = ( + none_throws(inverse_indices)[1].size(dim=1) + if variable_batch_per_feature + else stride + ) + + if weights is not None: + # if we have weights, lengths is the sum of weights by offsets for feature + lengths = torch.ops.fbgemm.segment_sum_csr(1, offsets.int(), weights) + + if variable_batch_per_feature: + inverse_indices = none_throws(inverse_indices) + device = inverse_indices[1].device + inverse_indices_t = inverse_indices[1] + if len(keys) != len(inverse_indices[0]): + inverse_indices_t = torch.index_select( + inverse_indices[1], 0, kjt_inverse_order + ) + offsets = _to_offsets(torch.tensor(stride_per_key, device=device))[ + :-1 + ].unsqueeze(-1) + indices = (inverse_indices_t + offsets).flatten() + lengths = torch.index_select(input=lengths, dim=0, index=indices) + + # only convert the sum pooling features to be 1 lengths + lengths = lengths.clone() + for feature in pooling_type_to_rs_features[PoolingType.SUM.value]: + feature_index = kjt_key_indices[feature] + feature_index = feature_index * batch_size + lengths[feature_index : feature_index + batch_size] = 1 + + if len(embedding_names) != len(keys): + lengths = torch.index_select( + lengths.reshape(-1, batch_size), + 0, + kt_key_ordering, + ).reshape(-1) + + # transpose to align features with keyed tensor dim_per_key + lengths = lengths.reshape(-1, batch_size).T # [batch_size, num_features] + output_size = sum(embedding_dims) + + divisor = torch.repeat_interleave( + input=lengths, + repeats=dim_per_key, + dim=1, + output_size=output_size, + ) + eps = 1e-6 # used to safe guard against 0 division + divisor = divisor + eps + return divisor.detach() + + +def _apply_mean_pooling( + keyed_tensor: KeyedTensor, divisor: torch.Tensor +) -> KeyedTensor: + """ + Apply mean pooling to pooled embeddings in RW/TWRW sharding schemes. 
+ This function is applied as a callback to the awaitable + """ + with record_function("## ebc apply mean pooling ##"): + mean_pooled_values = ( + keyed_tensor.values() / divisor + ) # [batch size, num_features * embedding dim] + return KeyedTensor( + keys=keyed_tensor.keys(), + values=mean_pooled_values, + length_per_key=keyed_tensor.length_per_key(), + key_dim=1, + ) diff --git a/torchrec/distributed/fbgemm_qcomm_codec.py b/torchrec/distributed/fbgemm_qcomm_codec.py index e2adeae5c..f097ab73d 100644 --- a/torchrec/distributed/fbgemm_qcomm_codec.py +++ b/torchrec/distributed/fbgemm_qcomm_codec.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + #!/usr/bin/env python3 import copy @@ -32,6 +34,7 @@ class CommType(Enum): BF16 = "bf16" FP8 = "fp8" INT8 = "int8" + MX4 = "mx4" def __str__(self) -> str: return self.value @@ -44,6 +47,7 @@ def comm_type_to_sparse_type(comm_type: CommType) -> SparseType: CommType.BF16: SparseType.BF16, CommType.FP8: SparseType.FP8, CommType.INT8: SparseType.INT8, + CommType.MX4: SparseType.MX4, }[comm_type] @@ -62,21 +66,63 @@ class QCommsConfig: forward_loss_scale: Optional[float] = None backward_loss_scale: Optional[float] = None fp8_quantize_dim: Optional[int] = None + fp8_quantize_dim_bwd: Optional[int] = None + fp8_bwd_uses_143: Optional[bool] = False + mx4_quantize_dim: Optional[int] = None + mx4_quantize_dim_bwd: Optional[int] = None def __post_init__(self) -> None: if ( self.forward_precision != CommType.FP8 and self.backward_precision != CommType.FP8 - and self.fp8_quantize_dim is not None + and ( + self.fp8_quantize_dim is not None + or self.fp8_quantize_dim_bwd is not None + ) + ): + logger.warning( + f"fp8_quantize_dim is set to {self.fp8_quantize_dim} and fp8_quantize_dim_bwd is set to {self.fp8_quantize_dim_bwd} but no FP8 precision is found in forward or backward precisions, resetting to None" + ) + if ( + self.backward_precision == CommType.FP8 + and self.fp8_quantize_dim_bwd is None + ): + self.fp8_quantize_dim_bwd = self.fp8_quantize_dim + logger.warning( + f"No override of FP8 bwd row dim, using general FP8 row dim for backward: {self.fp8_quantize_dim_bwd} " + ) + + if ( + self.forward_precision != CommType.MX4 + and self.backward_precision != CommType.MX4 + and ( + self.mx4_quantize_dim is not None + or self.mx4_quantize_dim_bwd is not None + ) ): - raise ValueError( - f"fp8_quantize_dim is set to {self.fp8_quantize_dim} but no FP8 precision is found in forward and backward precisions" + self.mx4_quantize_dim = None + self.mx4_quantize_dim_bwd = None + logger.warning( + f"mx4_quantize_dim is set to {self.mx4_quantize_dim} and mx4_quantize_dim_bwd is set to {self.mx4_quantize_dim_bwd} but no MX4 precision is found in forward or backward precisions" + ) + if ( + self.backward_precision == CommType.MX4 + and self.mx4_quantize_dim_bwd is None + ): + self.mx4_quantize_dim_bwd = self.mx4_quantize_dim + logger.warning( + f"No override of MX4 bwd row dim, using general MX4 row dim for backward: {self.mx4_quantize_dim_bwd} " ) def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCodecs: codecs = QuantizedCommCodecs() if qcomms_config is not None: + row_dim = None + if qcomms_config.forward_precision == CommType.FP8: + row_dim = qcomms_config.fp8_quantize_dim + elif qcomms_config.forward_precision == CommType.MX4: + row_dim = qcomms_config.mx4_quantize_dim codecs.forward = cast( QuantizedCommCodec[QuantizationContext], 
FbgemmQuantizedCommCodec( @@ -85,11 +131,14 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode ), loss_scale=qcomms_config.forward_loss_scale, is_fwd=True, - row_dim=qcomms_config.fp8_quantize_dim - if qcomms_config.forward_precision == CommType.FP8 - else None, + row_dim=row_dim, ), ) + row_dim_bwd = None + if qcomms_config.backward_precision == CommType.FP8: + row_dim_bwd = qcomms_config.fp8_quantize_dim_bwd + elif qcomms_config.backward_precision == CommType.MX4: + row_dim_bwd = qcomms_config.mx4_quantize_dim_bwd codecs.backward = cast( QuantizedCommCodec[QuantizationContext], FbgemmQuantizedCommCodec( @@ -97,10 +146,11 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode qcomms_config.backward_precision ), loss_scale=qcomms_config.backward_loss_scale, - is_fwd=False, - row_dim=qcomms_config.fp8_quantize_dim - if qcomms_config.backward_precision == CommType.FP8 - else None, + is_fwd=( + True if qcomms_config.fp8_bwd_uses_143 else False + ), # if fp8_bwd_uses_143 is True, bwd will use 1-4-3 + # if fp8_bwd_uses_143 is False/None, bwd will use 1-5-2 + row_dim=row_dim_bwd, ), ) return codecs @@ -110,7 +160,7 @@ def get_qcomm_codecs_registry( qcomms_config: QCommsConfig, comm_ops: Optional[List[CommOp]] = None, device: Optional[torch.device] = None, -) -> Dict[str, QuantizedCommCodecs]: +) -> Optional[Dict[str, QuantizedCommCodecs]]: """ This method constructs QuantizedCommCodecs from a given QCommConfig. It assumes that you want to use the same QComm configs for all comm-types passed in. @@ -129,6 +179,12 @@ def get_qcomm_codecs_registry( device=torch.device("cuda")) """ + if ( + qcomms_config.forward_precision == CommType.FP32 + and qcomms_config.backward_precision == CommType.FP32 + ): + return None + if device is None: device = torch.device("cuda") @@ -143,14 +199,14 @@ def get_qcomm_codecs_registry( qcomm_config_copy = copy.deepcopy(qcomms_config) # TODO: On H100, FP8 types might be natively supported, in which case we should check for that arch type and not fallback. if comm_op == CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER: - if qcomm_config_copy.forward_precision == CommType.FP8: + if qcomm_config_copy.forward_precision in [CommType.FP8, CommType.MX4]: logger.warning( - "FP8 is not supported for reduce scatter's forward - falling back to FP16" + "FP8/MX4 is not supported for reduce scatter's forward - falling back to FP16" ) qcomm_config_copy.forward_precision = CommType.FP16 - if qcomm_config_copy.backward_precision == CommType.FP8: + if qcomm_config_copy.backward_precision in [CommType.FP8, CommType.MX4]: logger.warning( - "FP8 is not supported for reduce scatter's backward - falling back to BF16" + "FP8/MX4 is not supported for reduce scatter's backward - falling back to BF16" ) qcomm_config_copy.backward_precision = CommType.BF16 diff --git a/torchrec/distributed/fp_embeddingbag.py b/torchrec/distributed/fp_embeddingbag.py new file mode 100644 index 000000000..872fa6aa6 --- /dev/null +++ b/torchrec/distributed/fp_embeddingbag.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
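
A hedged usage sketch for the MX4 support added above, showing how a quantized-comms registry would typically be built and handed to a sharder; the values are illustrative, and materializing the codecs needs fbgemm_gpu (and a CUDA device by default):

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.fbgemm_qcomm_codec import (
    CommType,
    QCommsConfig,
    get_qcomm_codecs_registry,
)
from torchrec.distributed.types import CommOp

qcomm_codecs_registry = get_qcomm_codecs_registry(
    QCommsConfig(
        forward_precision=CommType.MX4,
        backward_precision=CommType.BF16,
        mx4_quantize_dim=32,  # row dim used when MX4-quantizing the all-to-all payload
    ),
    comm_ops=[CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL],
)
sharder = EmbeddingBagCollectionSharder(qcomm_codecs_registry=qcomm_codecs_registry)
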
+ +# pyre-strict + +from functools import partial +from typing import Any, Dict, Iterator, List, Optional, Type, Union + +import torch +from torch import nn + +from torchrec.distributed.embedding_types import ( + BaseEmbeddingSharder, + KJTList, + ShardedEmbeddingModule, +) +from torchrec.distributed.embeddingbag import ( + EmbeddingBagCollectionContext, + EmbeddingBagCollectionSharder, + ShardedEmbeddingBagCollection, +) +from torchrec.distributed.types import ( + Awaitable, + LazyAwaitable, + ParameterSharding, + QuantizedCommCodecs, + ShardingEnv, + ShardingType, +) +from torchrec.distributed.utils import append_prefix, init_parameters +from torchrec.modules.feature_processor_ import FeatureProcessorsCollection +from torchrec.modules.fp_embedding_modules import ( + apply_feature_processors_to_kjt, + FeatureProcessedEmbeddingBagCollection, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + + +def param_dp_sync(kt: KeyedTensor, no_op_tensor: torch.Tensor) -> KeyedTensor: + kt._values.add_(no_op_tensor) + return kt + + +class ShardedFeatureProcessedEmbeddingBagCollection( + ShardedEmbeddingModule[ + KJTList, List[torch.Tensor], KeyedTensor, EmbeddingBagCollectionContext + ] +): + def __init__( + self, + module: FeatureProcessedEmbeddingBagCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + ebc_sharder: EmbeddingBagCollectionSharder, + env: ShardingEnv, + device: torch.device, + module_fqn: Optional[str] = None, + ) -> None: + super().__init__() + + self._device = device + self._env = env + + self._embedding_bag_collection: ShardedEmbeddingBagCollection = ( + ebc_sharder.shard( + module._embedding_bag_collection, + table_name_to_parameter_sharding, + env=env, + device=device, + module_fqn=module_fqn, + ) + ) + + self._lookups: List[nn.Module] = self._embedding_bag_collection._lookups + + self._is_collection: bool = False + self._feature_processors: Union[nn.ModuleDict, FeatureProcessorsCollection] + if isinstance(module._feature_processors, FeatureProcessorsCollection): + self._feature_processors = module._feature_processors + self._is_collection = True + else: + self._feature_processors = torch.nn.ModuleDict( + module._feature_processors.items() + ) + self._is_collection = False + + init_parameters(self._feature_processors, device) + self._no_op_zero: torch.Tensor = torch.zeros((1,), device=self._device) + + # pyre-ignore + def input_dist( + self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor + ) -> Awaitable[Awaitable[KJTList]]: + return self._embedding_bag_collection.input_dist(ctx, features) + + def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList: + kjt_list = [] + for features in dist_input: + if self._is_collection: + kjt_list.append(self._feature_processors(features)) + else: + kjt_list.append( + apply_feature_processors_to_kjt( + features, + self._feature_processors, + ) + ) + return KJTList(kjt_list) + + def compute( + self, + ctx: EmbeddingBagCollectionContext, + dist_input: KJTList, + ) -> List[torch.Tensor]: + + fp_features = self.apply_feature_processors_to_kjt_list(dist_input) + return self._embedding_bag_collection.compute(ctx, fp_features) + + def output_dist( + self, + ctx: EmbeddingBagCollectionContext, + output: List[torch.Tensor], + ) -> LazyAwaitable[KeyedTensor]: + lazy_awaitable_kt = self._embedding_bag_collection.output_dist(ctx, output) + return self.add_fp_params_grad_sync_callback(lazy_awaitable_kt) + + def compute_and_output_dist( + self, ctx: 
EmbeddingBagCollectionContext, input: KJTList + ) -> LazyAwaitable[KeyedTensor]: + fp_features = self.apply_feature_processors_to_kjt_list(input) + lazy_awaitable_kt = self._embedding_bag_collection.compute_and_output_dist( + ctx, fp_features + ) + return self.add_fp_params_grad_sync_callback(lazy_awaitable_kt) + + def add_fp_params_grad_sync_callback( + self, lazy_awaitable_kt: LazyAwaitable[KeyedTensor] + ) -> LazyAwaitable[KeyedTensor]: + # This will ensure that all feature processor parameters participate in the + # autograd graph across all ranks. This will protect from mismatched collective + # calls order when using DistributedDataParallel over feature processors. + no_op_tensor = ( + self._no_op_zero + * torch.cat( + [x.flatten() for x in self._feature_processors.parameters()] + ).sum() + ) + lazy_awaitable_kt.callbacks.append( + partial(param_dp_sync, no_op_tensor=no_op_tensor) + ) + return lazy_awaitable_kt + + def create_context(self) -> EmbeddingBagCollectionContext: + return self._embedding_bag_collection.create_context() + + def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: + for fqn, _ in self.named_parameters(): + if "_embedding_bag_collection" in fqn: + yield append_prefix(prefix, fqn) + + +class FeatureProcessedEmbeddingBagCollectionSharder( + BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection] +): + def __init__( + self, + ebc_sharder: Optional[EmbeddingBagCollectionSharder] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + self._ebc_sharder: EmbeddingBagCollectionSharder = ( + ebc_sharder + or EmbeddingBagCollectionSharder( + qcomm_codecs_registry=self.qcomm_codecs_registry + ) + ) + + def shard( + self, + module: FeatureProcessedEmbeddingBagCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> ShardedFeatureProcessedEmbeddingBagCollection: + + if device is None: + device = torch.device("cuda") + + return ShardedFeatureProcessedEmbeddingBagCollection( + module, + params, + ebc_sharder=self._ebc_sharder, + env=env, + device=device, + module_fqn=module_fqn, + ) + + @property + def fused_params(self) -> Optional[Dict[str, Any]]: + # TODO: to be deprecate after planner get cache_load_factor from ParameterConstraints + return self._ebc_sharder.fused_params + + def shardable_parameters( + self, module: FeatureProcessedEmbeddingBagCollection + ) -> Dict[str, torch.nn.Parameter]: + return self._ebc_sharder.shardable_parameters(module._embedding_bag_collection) + + @property + def module_type(self) -> Type[FeatureProcessedEmbeddingBagCollection]: + return FeatureProcessedEmbeddingBagCollection + + def sharding_types(self, compute_device_type: str) -> List[str]: + if compute_device_type in {"mtia"}: + return [ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value] + + # No row wise because position weighted FP and RW don't play well together. 
+ types = [ + ShardingType.DATA_PARALLEL.value, + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ] + + return types diff --git a/torchrec/distributed/fused_embedding.py b/torchrec/distributed/fused_embedding.py index c92d7c612..1968e7fa0 100644 --- a/torchrec/distributed/fused_embedding.py +++ b/torchrec/distributed/fused_embedding.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Dict, Iterator, List, Optional, Type import torch @@ -94,6 +96,7 @@ def shard( params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, ) -> ShardedFusedEmbeddingCollection: return ShardedFusedEmbeddingCollection(module, params, env, device) diff --git a/torchrec/distributed/fused_embeddingbag.py b/torchrec/distributed/fused_embeddingbag.py index 31a4c31b6..43eeda323 100644 --- a/torchrec/distributed/fused_embeddingbag.py +++ b/torchrec/distributed/fused_embeddingbag.py @@ -5,7 +5,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, Iterator, List, Optional, Type +# pyre-strict + +from typing import Dict, List, Optional, Type import torch from torch import nn @@ -23,7 +25,6 @@ ShardingEnv, ShardingType, ) -from torchrec.distributed.utils import append_prefix from torchrec.modules.fused_embedding_modules import ( convert_optimizer_type_and_kwargs, FusedEmbeddingBagCollection, @@ -64,17 +65,18 @@ def __init__( ) for index, (sharding, lookup) in enumerate( - zip(self._sharding_type_to_sharding.values(), self._lookups) + zip(self._embedding_shardings, self._lookups) ): if isinstance(sharding, DpPooledEmbeddingSharding): self._lookups[index] = DistributedDataParallel( module=lookup, - device_ids=[device], + device_ids=[device] if device is not None else None, process_group=env.process_group, gradient_as_bucket_view=True, broadcast_buffers=False, static_graph=True, ) + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. self._lookups[index]._register_fused_optim( optimizer_type, **optimizer_kwargs ) @@ -83,15 +85,6 @@ def __init__( # We need to ensure that a checkpoint from DDP and a checkpoint from a # model parallel version are compatible. - def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: - # different than ShardedEmbeddingBagCollection - we consider DDP to be "sharded", so that it doesn't get wrapped up in ddp again - # semantics of this is actually "parameters that don't need to have their gradients reduced" - for lookup, _ in zip(self._lookups, self._sharding_type_to_sharding.keys()): - for name, _ in lookup.named_parameters( - append_prefix(prefix, "embedding_bags") - ): - yield name - class FusedEmbeddingBagCollectionSharder( BaseEmbeddingSharder[FusedEmbeddingBagCollection] @@ -102,6 +95,7 @@ def shard( params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, ) -> ShardedEmbeddingBagCollection: return ShardedFusedEmbeddingBagCollection( diff --git a/torchrec/distributed/fused_params.py b/torchrec/distributed/fused_params.py new file mode 100644 index 000000000..71b6b4786 --- /dev/null +++ b/torchrec/distributed/fused_params.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, Dict, Iterable, Optional + +import torch + +from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( + IntNBitTableBatchedEmbeddingBagsCodegen, +) +from torchrec.distributed.embedding_types import GroupedEmbeddingConfig +from torchrec.distributed.types import BoundsCheckMode + +FUSED_PARAM_REGISTER_TBE_BOOL: str = "__register_tbes_in_named_modules" +FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: str = ( + "__register_quant_state_dict_split_scale_bias" +) +FUSED_PARAM_TBE_ROW_ALIGNMENT: str = "__register_tbe_row_alignment" +FUSED_PARAM_BOUNDS_CHECK_MODE: str = "__register_tbe_bounds_check_mode" + +# Force lengths to offsets conversion before TBE lookup. Helps with performance +# with certain ways to split models. +FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: str = "__register_lengths_to_offsets_lookup" + + +class TBEToRegisterMixIn: + def get_tbes_to_register( + self, + ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]: + raise NotImplementedError + + +def get_tbes_to_register_from_iterable( + iterable: Iterable[torch.nn.Module], +) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]: + tbes: Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig] = {} + for m in iterable: + if isinstance(m, TBEToRegisterMixIn): + tbes.update(m.get_tbes_to_register()) + return tbes + + +def is_fused_param_register_tbe(fused_params: Optional[Dict[str, Any]]) -> bool: + return ( + fused_params + and FUSED_PARAM_REGISTER_TBE_BOOL in fused_params + and fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] + ) + + +def get_fused_param_tbe_row_alignment( + fused_params: Optional[Dict[str, Any]], +) -> Optional[int]: + if fused_params is None or FUSED_PARAM_TBE_ROW_ALIGNMENT not in fused_params: + return None + else: + return fused_params[FUSED_PARAM_TBE_ROW_ALIGNMENT] + + +def fused_param_bounds_check_mode( + fused_params: Optional[Dict[str, Any]], +) -> Optional[BoundsCheckMode]: + if fused_params is None or FUSED_PARAM_BOUNDS_CHECK_MODE not in fused_params: + return None + else: + return fused_params[FUSED_PARAM_BOUNDS_CHECK_MODE] + + +def fused_param_lengths_to_offsets_lookup( + fused_params: Optional[Dict[str, Any]], +) -> bool: + if ( + fused_params is None + or FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP not in fused_params + ): + return False + else: + return fused_params[FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP] + + +def is_fused_param_quant_state_dict_split_scale_bias( + fused_params: Optional[Dict[str, Any]], +) -> bool: + return ( + fused_params + and FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS in fused_params + and fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] + ) + + +def tbe_fused_params( + fused_params: Optional[Dict[str, Any]], +) -> Optional[Dict[str, Any]]: + if not fused_params: + return None + + fused_params_for_tbe = dict(fused_params) + if FUSED_PARAM_REGISTER_TBE_BOOL in fused_params_for_tbe: + fused_params_for_tbe.pop(FUSED_PARAM_REGISTER_TBE_BOOL) + if FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS in fused_params_for_tbe: + fused_params_for_tbe.pop(FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS) + if FUSED_PARAM_TBE_ROW_ALIGNMENT in fused_params_for_tbe: + fused_params_for_tbe.pop(FUSED_PARAM_TBE_ROW_ALIGNMENT) + if FUSED_PARAM_BOUNDS_CHECK_MODE in fused_params_for_tbe: + fused_params_for_tbe.pop(FUSED_PARAM_BOUNDS_CHECK_MODE) 
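
The FUSED_PARAM_* keys above are TorchRec-side control flags rather than TBE kwargs; tbe_fused_params (which continues just below) strips them so only genuine TBE arguments reach FBGEMM. A hedged usage sketch, assuming this new torchrec.distributed.fused_params module is importable (it pulls in fbgemm_gpu at import time):

from torchrec.distributed.fused_params import (
    FUSED_PARAM_REGISTER_TBE_BOOL,
    FUSED_PARAM_TBE_ROW_ALIGNMENT,
    get_fused_param_tbe_row_alignment,
    is_fused_param_register_tbe,
    tbe_fused_params,
)

fused_params = {
    "learning_rate": 0.01,                # ordinary TBE fused param, passed through
    FUSED_PARAM_REGISTER_TBE_BOOL: True,  # TorchRec-only flag, stripped before TBE
    FUSED_PARAM_TBE_ROW_ALIGNMENT: 16,
}

assert is_fused_param_register_tbe(fused_params)
assert get_fused_param_tbe_row_alignment(fused_params) == 16
assert tbe_fused_params(fused_params) == {"learning_rate": 0.01}
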
+ if FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP in fused_params_for_tbe: + fused_params_for_tbe.pop(FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP) + + return fused_params_for_tbe diff --git a/torchrec/distributed/global_settings.py b/torchrec/distributed/global_settings.py new file mode 100644 index 000000000..2b957965c --- /dev/null +++ b/torchrec/distributed/global_settings.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +PROPOGATE_DEVICE: bool = False + + +def set_propogate_device(val: bool) -> None: + global PROPOGATE_DEVICE + PROPOGATE_DEVICE = val + + +def get_propogate_device() -> bool: + global PROPOGATE_DEVICE + return PROPOGATE_DEVICE diff --git a/torchrec/distributed/grouped_position_weighted.py b/torchrec/distributed/grouped_position_weighted.py index f6099ab29..bd9f92248 100644 --- a/torchrec/distributed/grouped_position_weighted.py +++ b/torchrec/distributed/grouped_position_weighted.py @@ -5,8 +5,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from collections import OrderedDict -from typing import Any, Dict, Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterator, Optional, Tuple import torch import torch.nn as nn @@ -14,6 +16,7 @@ from torchrec.distributed.utils import append_prefix from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + # Will be deprecated soon, please use PositionWeightedProcessor, see the full # doc under modules/feature_processor.py class GroupedPositionWeightedModule(BaseGroupedFeatureProcessor): @@ -67,6 +70,11 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: offsets=features.offsets(), stride=features.stride(), length_per_key=features.length_per_key(), + stride_per_key_per_rank=( + features.stride_per_key_per_rank() + if features.variable_stride_per_key() + else None + ), ) def named_parameters( @@ -96,9 +104,3 @@ def state_dict( param if keep_vars else param.detach() ) return destination - - def sparse_grad_parameter_names( - self, destination: Optional[List[str]] = None, prefix: str = "" - ) -> List[str]: - destination = [] if destination is None else destination - return destination diff --git a/torchrec/distributed/infer_utils.py b/torchrec/distributed/infer_utils.py new file mode 100644 index 000000000..68dd1e567 --- /dev/null +++ b/torchrec/distributed/infer_utils.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
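
The new global_settings module above is a plain process-wide toggle; components that consult it can propagate an explicitly passed device instead of deriving one themselves (the identifier keeps the upstream spelling "propogate"). A trivial usage sketch:

from torchrec.distributed.global_settings import (
    get_propogate_device,
    set_propogate_device,
)

set_propogate_device(True)
assert get_propogate_device()
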
+ +# pyre-strict + +from typing import cast, Dict, List, Optional, Set, Tuple, Type + +import torch + +from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( + IntNBitTableBatchedEmbeddingBagsCodegen, +) +from torchrec.distributed.quant_embedding import ( + QuantEmbeddingCollection, + ShardedQuantEmbeddingCollection, +) + +from torchrec.distributed.quant_embeddingbag import ( + QuantEmbeddingBagCollection, + ShardedQuantEmbeddingBagCollection, +) +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) + + +def get_tbes_from_sharded_module( + module: torch.nn.Module, +) -> List[IntNBitTableBatchedEmbeddingBagsCodegen]: + assert type(module) in [ + ShardedQuantEmbeddingBagCollection, + ShardedQuantEmbeddingCollection, + ], "Only support ShardedQuantEmbeddingBagCollection and ShardedQuantEmbeddingCollection for get TBEs" + tbes = [] + # pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a function. + for lookup in module._lookups: + for lookup_per_rank in lookup._embedding_lookups_per_rank: + for emb_module in lookup_per_rank._emb_modules: + tbes.append(emb_module._emb_module) + return tbes + + +def get_tbe_specs_from_sharded_module( + module: torch.nn.Module, +) -> List[ + Tuple[str, int, int, str, str] +]: # # tuple of (feature_names, rows, dims, str(SparseType), str(EmbeddingLocation/placement)) + assert type(module) in [ + ShardedQuantEmbeddingBagCollection, + ShardedQuantEmbeddingCollection, + ], "Only support ShardedQuantEmbeddingBagCollection and ShardedQuantEmbeddingCollection for get TBE specs" + tbe_specs = [] + tbes = get_tbes_from_sharded_module(module) + for tbe in tbes: + for spec in tbe.embedding_specs: + tbe_specs.append( + ( + spec[0], + spec[1], + spec[2], + str(spec[3]), + str(spec[4]), + ) + ) + return tbe_specs + + +def get_path_device_tuples( + module: object, ignore_list: Optional[List[str]] = None +) -> List[Tuple[str, str]]: + path_device_tuples: List[Tuple[str, str]] = [] + visited_path: Set[str] = set() + + cur_ignore_list: List[str] = ignore_list if ignore_list else ["embedding_tables"] + + def recursive_find_device( + module: object, cur_depth: int, path: str = "", max_depth: int = 50 + ) -> None: + nonlocal path_device_tuples + nonlocal visited_path + + if cur_depth > max_depth: + return + + if path in visited_path: + return + + visited_path.add(path) + if ( + isinstance(module, (int, float, str, bool, torch.Tensor)) + or type(module).__name__ in ["method", "function", "Proxy"] + or module is None + ): + return + + device_attrs = ("device", "_device", "_device_str", "_device_type") + + for name in dir(module): + if name in cur_ignore_list: + continue + child = getattr(module, name) + if name.startswith("__"): + continue + if name in device_attrs: + device = getattr(module, name) + path_device_tuples.append((path + "." 
+ name, str(device))) + elif isinstance(child, list): + for idx, child_ in enumerate(child): + recursive_find_device( + child_, + cur_depth + 1, + f"{path}.{name}[{idx}]", + max_depth=max_depth, + ) + elif isinstance(child, dict): + for key, child_ in child.items(): + recursive_find_device( + child_, + cur_depth + 1, + f"{path}.{name}[{key}]", + max_depth=max_depth, + ) + else: + recursive_find_device( + child, cur_depth + 1, f"{path}.{name}", max_depth=max_depth + ) + + recursive_find_device(module, 0, "") + + return path_device_tuples + + +def get_all_torchrec_modules( + model: torch.nn.Module, + trec_module_class_types: Optional[List[Type[torch.nn.Module]]] = None, +) -> Dict[str, torch.nn.Module]: + """ + Get all targeted TorchRec modules in the model. + Args: + model (torch.nn.Module): The input module to search for TREC modules. + trec_module_class_types (List[Type[torch.nn.Module]], optional): List of TorchRec module types to search for. + """ + if not trec_module_class_types: + trec_module_class_types = [ + ShardedQuantEmbeddingBagCollection, + ShardedQuantEmbeddingCollection, + QuantEmbeddingBagCollection, + QuantEmbeddingCollection, + EmbeddingBagCollection, + EmbeddingCollection, + ] + trec_modules: Dict[str, torch.nn.Module] = {} + + def _recursive_get_module( + module: torch.nn.Module, + path: str, + target_module_class: List[Type[torch.nn.Module]], + ) -> None: + if type(module) in target_module_class: + trec_modules[path] = module + return + for name, c_module in module.named_children(): + child_path = f"{path}.{name}" if path else name + if type(c_module) in target_module_class: + trec_modules[child_path] = c_module + else: + _recursive_get_module(c_module, child_path, target_module_class) + + _recursive_get_module( + model, "", cast(List[Type[torch.nn.Module]], trec_module_class_types) + ) + + return trec_modules + + +def get_non_scriptable_trec_module( + model: torch.nn.Module, +) -> Dict[str, torch.nn.Module]: + """ + Get all targeted TorchRec modules in the model that are not torchscriptable before tracing. + Args: + model (torch.nn.Module): The input module to search for TREC modules. + """ + return get_all_torchrec_modules( + model, + trec_module_class_types=[ + ShardedQuantEmbeddingBagCollection, + ShardedQuantEmbeddingCollection, + ], + ) diff --git a/torchrec/distributed/itep_embeddingbag.py b/torchrec/distributed/itep_embeddingbag.py new file mode 100644 index 000000000..7250077a4 --- /dev/null +++ b/torchrec/distributed/itep_embeddingbag.py @@ -0,0 +1,573 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree.
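A sketch of how get_all_torchrec_modules from the infer_utils.py hunk above can collect TorchRec modules keyed by their module path. The function comes from the hunk; the toy model, the table configuration values, and the meta-device placement are illustrative assumptions using the standard EmbeddingBagCollection API.

import torch
from torchrec.distributed.infer_utils import get_all_torchrec_modules
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection


class ToyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # One small table, placed on the meta device so no memory is allocated.
        self.sparse = EmbeddingBagCollection(
            tables=[
                EmbeddingBagConfig(
                    name="t1",
                    embedding_dim=8,
                    num_embeddings=100,
                    feature_names=["f1"],
                )
            ],
            device=torch.device("meta"),
        )


trec_modules = get_all_torchrec_modules(ToyModel())
# Keys are module paths, e.g. {"sparse": EmbeddingBagCollection(...)}.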
+ +# pyre-strict + +from collections import defaultdict +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional, OrderedDict, Tuple, Type, Union + +import torch +from torch import nn +from torch.nn.modules.module import _IncompatibleKeys +from torch.nn.parallel import DistributedDataParallel +from torchrec.distributed.embedding import ( + EmbeddingCollectionContext, + EmbeddingCollectionSharder, + ShardedEmbeddingCollection, +) + +from torchrec.distributed.embedding_types import ( + BaseEmbeddingSharder, + KJTList, + ShardedEmbeddingModule, +) +from torchrec.distributed.embeddingbag import ( + EmbeddingBagCollectionContext, + EmbeddingBagCollectionSharder, + ShardedEmbeddingBagCollection, +) +from torchrec.distributed.types import ( + Awaitable, + LazyAwaitable, + ParameterSharding, + QuantizedCommCodecs, + ShardingEnv, + ShardingType, +) +from torchrec.distributed.utils import filter_state_dict +from torchrec.modules.itep_embedding_modules import ( + ITEPEmbeddingBagCollection, + ITEPEmbeddingCollection, +) +from torchrec.modules.itep_modules import GenericITEPModule, RowwiseShardedITEPModule +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor + + +@dataclass +class ITEPEmbeddingBagCollectionContext(EmbeddingBagCollectionContext): + is_reindexed: bool = False + + +class ShardingTypeGroup(Enum): + CW_GROUP = "column_wise_group" + RW_GROUP = "row_wise_group" + + +SHARDING_TYPE_TO_GROUP: Dict[str, ShardingTypeGroup] = { + ShardingType.ROW_WISE.value: ShardingTypeGroup.RW_GROUP, + ShardingType.TABLE_ROW_WISE.value: ShardingTypeGroup.RW_GROUP, + ShardingType.COLUMN_WISE.value: ShardingTypeGroup.CW_GROUP, + ShardingType.TABLE_WISE.value: ShardingTypeGroup.CW_GROUP, +} + + +class ShardedITEPEmbeddingBagCollection( + ShardedEmbeddingModule[ + KJTList, + List[torch.Tensor], + KeyedTensor, + ITEPEmbeddingBagCollectionContext, + ] +): + def __init__( + self, + module: ITEPEmbeddingBagCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + ebc_sharder: EmbeddingBagCollectionSharder, + env: ShardingEnv, + device: torch.device, + ) -> None: + super().__init__() + + self._device = device + self._env = env + + # Iteration counter for ITEP Module. Pinning on CPU because used for condition checking and checkpointing. 
+ self.register_buffer( + "_iter", torch.tensor(0, dtype=torch.int64, device=torch.device("cpu")) + ) + + self._embedding_bag_collection: ShardedEmbeddingBagCollection = ( + ebc_sharder.shard( + module._embedding_bag_collection, + table_name_to_parameter_sharding, + env=env, + device=device, + ) + ) + + self.table_name_to_sharding_type: Dict[str, str] = {} + for table_name in table_name_to_parameter_sharding.keys(): + self.table_name_to_sharding_type[table_name] = ( + table_name_to_parameter_sharding[table_name].sharding_type + ) + + # Group lookups, table_name_to_unpruned_hash_sizes by sharding type and pass to separate itep modules + (grouped_lookups, grouped_table_unpruned_size_map) = ( + self._group_lookups_and_table_unpruned_size_map( + module._itep_module.table_name_to_unpruned_hash_sizes, + ) + ) + + # Instantiate ITEP Module in sharded case, re-using metadata from non-sharded case + self._itep_module: GenericITEPModule = GenericITEPModule( + table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[ + ShardingTypeGroup.CW_GROUP + ], + lookups=grouped_lookups[ShardingTypeGroup.CW_GROUP], + pruning_interval=module._itep_module.pruning_interval, + enable_pruning=module._itep_module.enable_pruning, + pg=env.process_group, + ) + self._rowwise_itep_module: RowwiseShardedITEPModule = RowwiseShardedITEPModule( + table_name_to_sharding_type=self.table_name_to_sharding_type, + table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[ + ShardingTypeGroup.RW_GROUP + ], + lookups=grouped_lookups[ShardingTypeGroup.RW_GROUP], + pruning_interval=module._itep_module.pruning_interval, + enable_pruning=module._itep_module.enable_pruning, + pg=env.process_group, + ) + + def prefetch( + self, + dist_input: KJTList, + forward_stream: Optional[Union[torch.cuda.Stream, torch.mtia.Stream]] = None, + ctx: Optional[ITEPEmbeddingBagCollectionContext] = None, + ) -> None: + assert ( + ctx is not None + ), "ITEP Prefetch call requires ITEPEmbeddingBagCollectionContext" + dist_input = self._reindex(dist_input) + ctx.is_reindexed = True + self._embedding_bag_collection.prefetch(dist_input, forward_stream, ctx) + + # pyre-ignore + def input_dist( + self, + ctx: ITEPEmbeddingBagCollectionContext, + features: KeyedJaggedTensor, + force_insert: bool = False, + ) -> Awaitable[Awaitable[KJTList]]: + return self._embedding_bag_collection.input_dist(ctx, features) + + def _reindex(self, dist_input: KJTList) -> KJTList: + for i, (sharding, features) in enumerate( + zip( + self._embedding_bag_collection._sharding_types, + dist_input, + ) + ): + if SHARDING_TYPE_TO_GROUP[sharding] == ShardingTypeGroup.CW_GROUP: + remapped_kjt = self._itep_module(features, self._iter.item()) + else: + remapped_kjt = self._rowwise_itep_module(features, self._iter.item()) + dist_input[i] = remapped_kjt + return dist_input + + def compute( + self, + ctx: ITEPEmbeddingBagCollectionContext, + dist_input: KJTList, + ) -> List[torch.Tensor]: + if not ctx.is_reindexed: + dist_input = self._reindex(dist_input) + ctx.is_reindexed = True + + self._iter += 1 + return self._embedding_bag_collection.compute(ctx, dist_input) + + def output_dist( + self, + ctx: ITEPEmbeddingBagCollectionContext, + output: List[torch.Tensor], + ) -> LazyAwaitable[KeyedTensor]: + + ebc_awaitable = self._embedding_bag_collection.output_dist(ctx, output) + return ebc_awaitable + + def compute_and_output_dist( + self, ctx: ITEPEmbeddingBagCollectionContext, input: KJTList + ) -> LazyAwaitable[KeyedTensor]: + # Insert forward() function of GenericITEPModule into 
compute_and_output_dist() + for i, (sharding, features) in enumerate( + zip( + self._embedding_bag_collection._sharding_types, + input, + ) + ): + if SHARDING_TYPE_TO_GROUP[sharding] == ShardingTypeGroup.CW_GROUP: + remapped_kjt = self._itep_module(features, self._iter.item()) + else: + remapped_kjt = self._rowwise_itep_module(features, self._iter.item()) + input[i] = remapped_kjt + self._iter += 1 + ebc_awaitable = self._embedding_bag_collection.compute_and_output_dist( + ctx, input + ) + return ebc_awaitable + + def create_context(self) -> ITEPEmbeddingBagCollectionContext: + return ITEPEmbeddingBagCollectionContext() + + # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module` + # inconsistently. + def load_state_dict( + self, + state_dict: "OrderedDict[str, torch.Tensor]", + strict: bool = True, + ) -> _IncompatibleKeys: + missing_keys = [] + unexpected_keys = [] + self._iter = state_dict["_iter"] + for name, child_module in self._modules.items(): + if child_module is not None: + missing, unexpected = child_module.load_state_dict( + filter_state_dict(state_dict, name), + strict, + ) + missing_keys.extend(missing) + unexpected_keys.extend(unexpected) + return _IncompatibleKeys( + missing_keys=missing_keys, unexpected_keys=unexpected_keys + ) + + def _group_lookups_and_table_unpruned_size_map( + self, table_name_to_unpruned_hash_sizes: Dict[str, int] + ) -> Tuple[ + Dict[ShardingTypeGroup, List[nn.Module]], + Dict[ShardingTypeGroup, Dict[str, int]], + ]: + """ + Group ebc lookups and table_name_to_unpruned_hash_sizes by sharding types. + CW and TW are grouped into CW_GROUP, RW and TWRW are grouped into RW_GROUP. + + Return a tuple of (grouped_lookups, grouped _table_unpruned_size_map) + """ + grouped_lookups: Dict[ShardingTypeGroup, List[nn.Module]] = defaultdict(list) + grouped_table_unpruned_size_map: Dict[ShardingTypeGroup, Dict[str, int]] = ( + defaultdict(dict) + ) + for sharding_type, lookup in zip( + self._embedding_bag_collection._sharding_types, + self._embedding_bag_collection._lookups, + ): + sharding_group = SHARDING_TYPE_TO_GROUP[sharding_type] + # group lookups + grouped_lookups[sharding_group].append(lookup) + # group table_name_to_unpruned_hash_sizes + while isinstance(lookup, DistributedDataParallel): + lookup = lookup.module + for emb_config in lookup.grouped_configs: + for table in emb_config.embedding_tables: + if table.name in table_name_to_unpruned_hash_sizes.keys(): + grouped_table_unpruned_size_map[sharding_group][table.name] = ( + table_name_to_unpruned_hash_sizes[table.name] + ) + + return grouped_lookups, grouped_table_unpruned_size_map + + +class ITEPEmbeddingBagCollectionSharder( + BaseEmbeddingSharder[ITEPEmbeddingBagCollection] +): + def __init__( + self, + ebc_sharder: Optional[EmbeddingBagCollectionSharder] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + self._ebc_sharder: EmbeddingBagCollectionSharder = ( + ebc_sharder + or EmbeddingBagCollectionSharder( + qcomm_codecs_registry=self.qcomm_codecs_registry + ) + ) + + def shard( + self, + module: ITEPEmbeddingBagCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> ShardedITEPEmbeddingBagCollection: + + # Enforce GPU for ITEPEmbeddingBagCollection + if device is None: + device = torch.device("cuda") + + return ShardedITEPEmbeddingBagCollection( + module, + params, + 
ebc_sharder=self._ebc_sharder, + env=env, + device=device, + ) + + def shardable_parameters( + self, module: ITEPEmbeddingBagCollection + ) -> Dict[str, torch.nn.Parameter]: + return self._ebc_sharder.shardable_parameters(module._embedding_bag_collection) + + @property + def module_type(self) -> Type[ITEPEmbeddingBagCollection]: + return ITEPEmbeddingBagCollection + + def sharding_types(self, compute_device_type: str) -> List[str]: + types = list(SHARDING_TYPE_TO_GROUP.keys()) + return types + + +class ITEPEmbeddingCollectionContext(EmbeddingCollectionContext): + + def __init__(self) -> None: + super().__init__() + self.is_reindexed: bool = False + self.table_name_to_unpruned_hash_sizes: Dict[str, int] = {} + + +class ShardedITEPEmbeddingCollection( + ShardedEmbeddingModule[ + KJTList, + List[torch.Tensor], + Dict[str, JaggedTensor], + ITEPEmbeddingCollectionContext, + ] +): + def __init__( + self, + module: ITEPEmbeddingCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + ebc_sharder: EmbeddingCollectionSharder, + env: ShardingEnv, + device: torch.device, + ) -> None: + super().__init__() + + self._device = device + self._env = env + self.table_name_to_unpruned_hash_sizes: Dict[str, int] = ( + module._itep_module.table_name_to_unpruned_hash_sizes + ) + + # Iteration counter for ITEP Module. Pinning on CPU because used for condition checking and checkpointing. + self.register_buffer( + "_iter", torch.tensor(0, dtype=torch.int64, device=torch.device("cpu")) + ) + + self._embedding_collection: ShardedEmbeddingCollection = ebc_sharder.shard( + module._embedding_collection, + table_name_to_parameter_sharding, + env=env, + device=device, + ) + + self.table_name_to_sharding_type: Dict[str, str] = {} + for table_name in table_name_to_parameter_sharding.keys(): + self.table_name_to_sharding_type[table_name] = ( + table_name_to_parameter_sharding[table_name].sharding_type + ) + + # Group lookups, table_name_to_unpruned_hash_sizes by sharding type and pass to separate itep modules + (grouped_lookups, grouped_table_unpruned_size_map) = ( + self._group_lookups_and_table_unpruned_size_map( + module._itep_module.table_name_to_unpruned_hash_sizes, + ) + ) + + # Instantiate ITEP Module in sharded case, re-using metadata from non-sharded case + self._itep_module: GenericITEPModule = GenericITEPModule( + table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[ + ShardingTypeGroup.CW_GROUP + ], + lookups=grouped_lookups[ShardingTypeGroup.CW_GROUP], + pruning_interval=module._itep_module.pruning_interval, + enable_pruning=module._itep_module.enable_pruning, + pg=env.process_group, + ) + self._rowwise_itep_module: RowwiseShardedITEPModule = RowwiseShardedITEPModule( + table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[ + ShardingTypeGroup.RW_GROUP + ], + lookups=grouped_lookups[ShardingTypeGroup.RW_GROUP], + pruning_interval=module._itep_module.pruning_interval, + table_name_to_sharding_type=self.table_name_to_sharding_type, + enable_pruning=module._itep_module.enable_pruning, + pg=env.process_group, + ) + + # pyre-ignore + def input_dist( + self, + ctx: ITEPEmbeddingCollectionContext, + features: KeyedJaggedTensor, + force_insert: bool = False, + ) -> Awaitable[Awaitable[KJTList]]: + + ctx.table_name_to_unpruned_hash_sizes = self.table_name_to_unpruned_hash_sizes + return self._embedding_collection.input_dist(ctx, features) + + def compute( + self, + ctx: ITEPEmbeddingCollectionContext, + dist_input: KJTList, + ) -> List[torch.Tensor]: + for i, (sharding, 
features) in enumerate( + zip( + self._embedding_collection._sharding_type_to_sharding.keys(), + dist_input, + ) + ): + if SHARDING_TYPE_TO_GROUP[sharding] == ShardingTypeGroup.CW_GROUP: + remapped_kjt = self._itep_module(features, self._iter.item()) + else: + remapped_kjt = self._rowwise_itep_module(features, self._iter.item()) + dist_input[i] = remapped_kjt + self._iter += 1 + return self._embedding_collection.compute(ctx, dist_input) + + def output_dist( + self, + ctx: ITEPEmbeddingCollectionContext, + output: List[torch.Tensor], + ) -> LazyAwaitable[Dict[str, JaggedTensor]]: + + ec_awaitable = self._embedding_collection.output_dist(ctx, output) + return ec_awaitable + + def compute_and_output_dist( + self, ctx: ITEPEmbeddingCollectionContext, input: KJTList + ) -> LazyAwaitable[Dict[str, JaggedTensor]]: + # Insert forward() function of GenericITEPModule into compute_and_output_dist() + """ """ + for i, (sharding, features) in enumerate( + zip( + self._embedding_collection._sharding_type_to_sharding.keys(), + input, + ) + ): + if SHARDING_TYPE_TO_GROUP[sharding] == ShardingTypeGroup.CW_GROUP: + remapped_kjt = self._itep_module(features, self._iter.item()) + else: + remapped_kjt = self._rowwise_itep_module(features, self._iter.item()) + input[i] = remapped_kjt + self._iter += 1 + ec_awaitable = self._embedding_collection.compute_and_output_dist(ctx, input) + return ec_awaitable + + def create_context(self) -> ITEPEmbeddingCollectionContext: + return ITEPEmbeddingCollectionContext() + + # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module` + # inconsistently. + def load_state_dict( + self, + state_dict: "OrderedDict[str, torch.Tensor]", + strict: bool = True, + ) -> _IncompatibleKeys: + missing_keys = [] + unexpected_keys = [] + self._iter = state_dict["_iter"] + for name, child_module in self._modules.items(): + if child_module is not None: + missing, unexpected = child_module.load_state_dict( + filter_state_dict(state_dict, name), + strict, + ) + missing_keys.extend(missing) + unexpected_keys.extend(unexpected) + return _IncompatibleKeys( + missing_keys=missing_keys, unexpected_keys=unexpected_keys + ) + + def _group_lookups_and_table_unpruned_size_map( + self, table_name_to_unpruned_hash_sizes: Dict[str, int] + ) -> Tuple[ + Dict[ShardingTypeGroup, List[nn.Module]], + Dict[ShardingTypeGroup, Dict[str, int]], + ]: + """ + Group ebc lookups and table_name_to_unpruned_hash_sizes by sharding types. + CW and TW are grouped into CW_GROUP, RW and TWRW are grouped into RW_GROUP. 
+ + Return a tuple of (grouped_lookups, grouped_table_unpruned_size_map) + """ + grouped_lookups: Dict[ShardingTypeGroup, List[nn.Module]] = defaultdict(list) + grouped_table_unpruned_size_map: Dict[ShardingTypeGroup, Dict[str, int]] = ( + defaultdict(dict) + ) + for sharding_type, lookup in zip( + self._embedding_collection._sharding_types, + self._embedding_collection._lookups, + ): + sharding_group = SHARDING_TYPE_TO_GROUP[sharding_type] + # group lookups + grouped_lookups[sharding_group].append(lookup) + # group table_name_to_unpruned_hash_sizes + while isinstance(lookup, DistributedDataParallel): + lookup = lookup.module + for emb_config in lookup.grouped_configs: + for table in emb_config.embedding_tables: + if table.name in table_name_to_unpruned_hash_sizes.keys(): + grouped_table_unpruned_size_map[sharding_group][table.name] = ( + table_name_to_unpruned_hash_sizes[table.name] + ) + + return grouped_lookups, grouped_table_unpruned_size_map + + +class ITEPEmbeddingCollectionSharder(BaseEmbeddingSharder[ITEPEmbeddingCollection]): + def __init__( + self, + ebc_sharder: Optional[EmbeddingCollectionSharder] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + self._ebc_sharder: EmbeddingCollectionSharder = ( + ebc_sharder + or EmbeddingCollectionSharder( + qcomm_codecs_registry=self.qcomm_codecs_registry + ) + ) + + def shard( + self, + module: ITEPEmbeddingCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> ShardedITEPEmbeddingCollection: + + # Enforce GPU for ITEPEmbeddingCollection + if device is None: + device = torch.device("cuda") + + return ShardedITEPEmbeddingCollection( + module, + params, + ebc_sharder=self._ebc_sharder, + env=env, + device=device, + ) + + def shardable_parameters( + self, module: ITEPEmbeddingCollection + ) -> Dict[str, torch.nn.Parameter]: + return self._ebc_sharder.shardable_parameters(module._embedding_collection) + + @property + def module_type(self) -> Type[ITEPEmbeddingCollection]: + return ITEPEmbeddingCollection + + def sharding_types(self, compute_device_type: str) -> List[str]: + types = list(SHARDING_TYPE_TO_GROUP.keys()) + return types diff --git a/torchrec/distributed/keyed_jagged_tensor_pool.py b/torchrec/distributed/keyed_jagged_tensor_pool.py new file mode 100644 index 000000000..3c605c943 --- /dev/null +++ b/torchrec/distributed/keyed_jagged_tensor_pool.py @@ -0,0 +1,717 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree.
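A small check illustrating how the SHARDING_TYPE_TO_GROUP mapping defined in itep_embeddingbag.py above routes sharding types to the two ITEP module groups. All names are taken from that hunk; only the standalone script framing is assumed.

from torchrec.distributed.itep_embeddingbag import (
    SHARDING_TYPE_TO_GROUP,
    ShardingTypeGroup,
)
from torchrec.distributed.types import ShardingType

# Column-wise and table-wise tables go through the dense GenericITEPModule path ...
assert SHARDING_TYPE_TO_GROUP[ShardingType.COLUMN_WISE.value] is ShardingTypeGroup.CW_GROUP
assert SHARDING_TYPE_TO_GROUP[ShardingType.TABLE_WISE.value] is ShardingTypeGroup.CW_GROUP

# ... while row-wise and table-row-wise tables use RowwiseShardedITEPModule.
assert SHARDING_TYPE_TO_GROUP[ShardingType.ROW_WISE.value] is ShardingTypeGroup.RW_GROUP
assert SHARDING_TYPE_TO_GROUP[ShardingType.TABLE_ROW_WISE.value] is ShardingTypeGroup.RW_GROUP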
+ +# pyre-strict + +#!/usr/bin/env python3 + +from typing import cast, Dict, List, Optional, Tuple, Type, Union + +import torch +from torchrec.distributed.object_pool import ShardedObjectPool +from torchrec.distributed.sharding.rw_kjt_pool_sharding import ( + InferRwKeyedJaggedTensorPoolOutputDist, + InferRwKeyedJaggedTensorPoolSharding, + KeyedJaggedTensorPoolRwReplicatedSharding, + KeyedJaggedTensorPoolRwSharding, +) +from torchrec.distributed.sharding.rw_pool_sharding import InferRwObjectPoolInputDist +from torchrec.distributed.tensor_sharding import ObjectPoolShardingContext + +from torchrec.distributed.types import ( + Awaitable, + LazyAwaitable, + ModuleSharder, + ObjectPoolShardingPlan, + ObjectPoolShardingType, + ShardingEnv, +) +from torchrec.modules.keyed_jagged_tensor_pool import KeyedJaggedTensorPool +from torchrec.modules.object_pool_lookups import ( + KeyedJaggedTensorPoolLookup, + TensorJaggedIndexSelectLookup, + UVMCachingInt32Lookup, + UVMCachingInt64Lookup, +) +from torchrec.modules.utils import deterministic_dedup, jagged_index_select_with_empty +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor + + +class KeyedJaggedTensorPoolAwaitable(LazyAwaitable[KeyedJaggedTensor]): + def __init__( + self, + awaitable: Awaitable[JaggedTensor], + keys: List[str], + device: torch.device, + unbucketize_permute: torch.Tensor, + ) -> None: + super().__init__() + self._awaitable = awaitable + self._unbucketize_permute = unbucketize_permute + self._keys = keys + self._device = device + + def _wait_impl(self) -> KeyedJaggedTensor: + # we could un-permute, but perhaps unnecessary for caching use case. + jt = self._awaitable.wait() + + # if we have an empty found, we should not perform any lookups and just + # return an empty KJT. + if jt.lengths().size()[0] == 0: + return KeyedJaggedTensor.empty( + is_weighted=jt.weights_or_none() is not None, + device=self._device, + values_dtype=jt.values().dtype, + lengths_dtype=jt.lengths().dtype, + weights_dtype=getattr(jt.weights_or_none(), "dtype", None), + ) + + """ + We need to permute the row order KJT based on the unbucketize permute tensor + to respect the original order that it came in. + """ + + unbucketize_id_permute = ( + torch.arange(jt.lengths().shape[0], device=self._device) + .view(-1, len(self._keys))[self._unbucketize_permute] + .flatten() + ) + + """ + Since the all to all will return to us in a row manner format, we need to regroup + using jaggeed index_select to key order. + For example, we will receive 0,2,3,4,5,6,7. But we need it to be in [0,2,5,6,3,4,7] order. + + F1 F2 + [0,2] . [3,4] + [5,6] [7] + + Can remove if we can write efficient kernel that can return in feature order. This would also + require splits to be transposed and flattened, to be put in feature order. 
+ """ + + row_major_to_feature_major_permute = ( + torch.arange(jt.lengths().shape[0], device=self._device) + .view(-1, len(self._keys)) + .t() + .flatten() + ) + """ + The below is equivalent to doing + reorder_v = jagged_index_select(values, unbucketize_id_permute) + reorder_v = jagged_index_select(reorder_v, row_major_to_feature_major_permute) + """ + + indices = unbucketize_id_permute[row_major_to_feature_major_permute] + reorder_l = jt.lengths()[indices] + reorder_o = torch.ops.fbgemm.asynchronous_inclusive_cumsum(reorder_l) + reorder_v = jagged_index_select_with_empty( + jt.values().unsqueeze(-1), indices, jt.offsets()[1:], reorder_o + ) + + reorder_w = ( + jagged_index_select_with_empty( + jt.weights().unsqueeze(-1), + indices, + jt.offsets()[1:], + reorder_o, + ) + if jt.weights_or_none() is not None + else None + ) + + return KeyedJaggedTensor( + keys=self._keys, + values=reorder_v.flatten(), + weights=reorder_w.flatten() if reorder_w is not None else None, + lengths=reorder_l, + ) + + +class ShardedKeyedJaggedTensorPool( + ShardedObjectPool[ + KeyedJaggedTensor, # Out + JaggedTensor, # DistOut + ObjectPoolShardingContext, # Ctx + ] +): + """ + Sharded implementation of `KeyedJaggedTensorPool` + + When dealing with a large pool that cannot fit in a single device memory + (i.e. HBM / UVM / CPU etc), this module handles sharding the pool row-wise, including + orchestrating the communication between ranks for distributed lookup and update. + + Args: + pool_size (int): total number of batches that can be stored in the pool + values_dtype (torch.dtype): dtype of the KJT values in the pool + feature_max_lengths (Dict[str,int]): Mapping from feature name in KJT + to the maximum size of the jagged slices for the feature. + is_weighted (bool): whether KJT values have weights that need to be stored. + sharding_env (ShardingEnv): sharding environment (e.g. 
world_size, ranks, etc) + sharding_plan (ObjectPoolShardingPlan): info about sharding strategy + device (Optional[torch.device]): default device + enable_uvm (bool): if set to true, the pool will be allocated on UVM + + Example:: + # Example on 2 GPUs + # on rank 0, update ids [2,0] with values + # ids f1 f2 + # 2 [1] [2, 3] + # 0 [4,5] [6] + sharded_keyed_jagged_tensor_pool.update( + ids=torch.Tensor([2,0],dtype=torch.int,device="cuda:0") + values=KeyedJaggedTensor.from_lengths_sync( + keys=["f1","f2"], + values=torch.Tensor([1,2,3,4,5,6],device="cuda:0"), + lengths=torch.Tensor([1,2,2,1],device="cuda:0") + ) + ) + + # on rank 1, update ids [1,3] with values + # ids f1 f2 + # 1 [7,8] [] + # 3 [9,10,11] [12] + sharded_keyed_jagged_tensor_pool.update( + ids=torch.Tensor([1,3],dtype=torch.int,device="cuda:1") + values=KeyedJaggedTensor.from_lengths_sync( + keys=["f1","f2"], + values=torch.Tensor([7,8,9,10,11,12],device="cuda:1"), + lengths=torch.Tensor([2,0,3,1],device="cuda:1") + ) + ) + + # At this point the global state is: + # ids f1 f2 + # 0 [2,3] [6] <- rank 0 + # 1 [7,8] [9,10,11] <- rank 0 + # 2 [1] [4,5] <- rank 1 + # 3 [] [12] <- rank 1 + + """ + + def __init__( + self, + pool_size: int, + feature_max_lengths: Dict[str, int], + values_dtype: torch.dtype, + is_weighted: bool, + sharding_env: ShardingEnv, + sharding_plan: ObjectPoolShardingPlan, + device: Optional[torch.device] = None, + # TODO add quantized comms codec registry + enable_uvm: bool = False, + ) -> None: + + super().__init__() + self._pool_size = pool_size + self._values_dtype = values_dtype + self._sharding_env = sharding_env + self._device: torch.device = device or torch.device("cuda") + self._is_weighted = is_weighted + self._sharding_plan = sharding_plan + self._feature_max_lengths = feature_max_lengths + self.register_buffer( + "_feature_max_lengths_t", + torch.tensor( + list(feature_max_lengths.values()), + dtype=torch.int32, + device=self._device, + ), + persistent=False, + ) + self._features: List[str] = list(feature_max_lengths.keys()) + self._enable_uvm = enable_uvm + + # pyre-fixme[4]: Attribute must be annotated. 
+ self._permute_feature = None + if sharding_plan.sharding_type == ObjectPoolShardingType.ROW_WISE: + self._sharding: KeyedJaggedTensorPoolRwSharding = ( + KeyedJaggedTensorPoolRwSharding( + env=self._sharding_env, + device=self._device, + pool_size=self._pool_size, + num_features=len(feature_max_lengths), + ) + ) + elif sharding_plan.sharding_type == ObjectPoolShardingType.REPLICATED_ROW_WISE: + self._sharding: KeyedJaggedTensorPoolRwReplicatedSharding = ( + KeyedJaggedTensorPoolRwReplicatedSharding( + env=self._sharding_env, + device=self._device, + pool_size=self._pool_size, + num_features=len(feature_max_lengths), + ) + ) + + else: + raise NotImplementedError( + f"Sharding type {self._sharding_plan.sharding_type} is not implemented" + ) + + # pyre-ignore + self._lookup: KeyedJaggedTensorPoolLookup = None + if self._enable_uvm: + if values_dtype == torch.int64: + self._lookup = UVMCachingInt64Lookup( + self._sharding.local_pool_size, + feature_max_lengths, + is_weighted, + self._device, + ) + if values_dtype == torch.int32: + self._lookup = UVMCachingInt32Lookup( + self._sharding.local_pool_size, + feature_max_lengths, + is_weighted, + self._device, + ) + else: + self._lookup = TensorJaggedIndexSelectLookup( + self._sharding.local_pool_size, + values_dtype, + feature_max_lengths, + is_weighted, + self._device, + ) + if self._lookup is None: + raise ValueError( + f"Cannot create lookup for {self._enable_uvm=} {self._values_dtype}" + ) + + for fqn, tensor in self._lookup.states_to_register(): + self.register_buffer( + fqn, + tensor, + ) + + # pyre-ignore + self._lookup_ids_dist_impl = self._sharding.create_lookup_ids_dist() + # pyre-ignore + self._lookup_values_dist_impl = self._sharding.create_lookup_values_dist() + # pyre-ignore + self._update_ids_dist_impl = self._sharding.create_update_ids_dist() + # pyre-ignore + self._update_values_dist_impl = self._sharding.create_update_values_dist() + + self._initialize_torch_state(self._lookup, sharding_plan.sharding_type) + + @property + def pool_size(self) -> int: + return self._pool_size + + @property + def feature_max_lengths(self) -> Dict[str, int]: + return self._feature_max_lengths + + @property + def values_dtype(self) -> torch.dtype: + return self._values_dtype + + @property + def is_weighted(self) -> bool: + return self._is_weighted + + @property + def device(self) -> torch.device: + torch._assert(self._device is not None, "self._device should already be set") + return self._device + + def _initialize_torch_state( + self, lookup: KeyedJaggedTensorPoolLookup, sharding_type: ObjectPoolShardingType + ) -> None: + for fqn, tensor in self._sharding.get_sharded_states_to_register(self._lookup): + self.register_buffer(fqn, tensor) + # somewhat hacky. ideally, we should be able to invoke this method on + # any update to the lookup's key_lengths field. + lengths, offsets = lookup._infer_jagged_lengths_inclusive_offsets() + # pyre-fixme[16]: `KeyedJaggedTensorPoolLookup` has no attribute `_lengths`. + lookup._lengths = lengths + # pyre-fixme[16]: `KeyedJaggedTensorPoolLookup` has no attribute `_offsets`. + lookup._offsets = offsets + + def _lookup_ids_dist( + self, ctx: ObjectPoolShardingContext, ids: torch.Tensor + ) -> Awaitable[Awaitable[torch.Tensor]]: + return self._lookup_ids_dist_impl(ctx=ctx, ids=ids) + + def _update_preproc(self, values: KeyedJaggedTensor) -> KeyedJaggedTensor: + """ + 1. Permute/filter KJT keys to be the same as in feature_max_lengths + 2. 
Ensure the max_lengths of input is within the feature_max_lengths + """ + if self._permute_feature is None: + self._permute_feature = [] + for feature in self._feature_max_lengths.keys(): + for j, kjt_feature in enumerate(values.keys()): + if feature == kjt_feature: + self._permute_feature.append(j) + + valid_input = values.permute(self._permute_feature) + # can disable below check if expensive + max_elements, _max_indices = ( + valid_input.lengths() + .reshape(len(self._feature_max_lengths.keys()), -1) + .max(dim=1) + ) + + assert torch.all( + max_elements <= self._feature_max_lengths_t + ).item(), "input KJT has a feature that exceeds specified max lengths" + + return valid_input + + def _update_local( + self, + ctx: ObjectPoolShardingContext, + ids: torch.Tensor, + values: JaggedTensor, + ) -> None: + if ids.size(0) == 0: + return + jt = values + deduped_ids, dedup_permutation = deterministic_dedup(ids) + + device = ids.device + arange_idx = torch.arange(len(jt.lengths()), device=device) + value_dedup_permute = (arange_idx.view(-1, len(self._feature_max_lengths)))[ + dedup_permutation, : + ].flatten() + + deduped_lengths = jt.lengths()[value_dedup_permute] + deduped_offsets = torch.ops.fbgemm.asynchronous_inclusive_cumsum( + deduped_lengths + ) + deduped_values = jagged_index_select_with_empty( + jt.values().unsqueeze(-1), + value_dedup_permute, + jt.offsets()[1:], + deduped_offsets, + ) + + deduped_values, deduped_lengths = ( + deduped_values.flatten(), + deduped_lengths.flatten(), + ) + + deduped_weights = None + if jt.weights_or_none() is not None: + deduped_weights = jagged_index_select_with_empty( + jt.weights().unsqueeze(-1), + value_dedup_permute, + jt.offsets()[1:], + deduped_offsets, + ) + deduped_weights = deduped_weights.flatten() + + self._lookup.update( + deduped_ids, + JaggedTensor( + values=deduped_values, + lengths=deduped_lengths, + weights=deduped_weights, + ), + ) + + def _lookup_local( + self, ctx: ObjectPoolShardingContext, ids: torch.Tensor + ) -> JaggedTensor: + return self._lookup.lookup(ids) + + def _lookup_values_dist( + self, + ctx: ObjectPoolShardingContext, + values: JaggedTensor, + ) -> LazyAwaitable[KeyedJaggedTensor]: + return KeyedJaggedTensorPoolAwaitable( + awaitable=self._lookup_values_dist_impl(ctx, values), + unbucketize_permute=ctx.unbucketize_permute, + keys=self._features, + device=self._device, + ) + + def _update_ids_dist( + self, ctx: ObjectPoolShardingContext, ids: torch.Tensor + ) -> Awaitable[Awaitable[torch.Tensor]]: + return self._update_ids_dist_impl(ctx=ctx, ids=ids) + + def _update_values_dist( + self, ctx: ObjectPoolShardingContext, values: KeyedJaggedTensor + ) -> Awaitable[JaggedTensor]: + return self._update_values_dist_impl(values, ctx) + + def create_context(self) -> ObjectPoolShardingContext: + return cast(ObjectPoolShardingContext, self._sharding.create_context()) + + +@torch.fx.wrap +def _get_reorder_values_lengths_weights( + keys: List[str], + jt: JaggedTensor, + # not actually optional, just making torchscript type happy. 
+ unbucketize_permute: Optional[torch.Tensor], + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + unbucketize_id_permute = ( + torch.arange(jt.lengths().shape[0], device=device) + .view(-1, len(keys))[unbucketize_permute] + .flatten() + ) + row_major_to_feature_major_permute = ( + torch.arange(jt.lengths().shape[0], device=device) + .view(-1, len(keys)) + .t() + .flatten() + ) + indices = unbucketize_id_permute[row_major_to_feature_major_permute] + reorder_l = jt.lengths()[indices] + reorder_o = torch.ops.fbgemm.asynchronous_inclusive_cumsum(reorder_l) + reorder_v = jagged_index_select_with_empty( + jt.values().unsqueeze(-1), indices, jt.offsets()[1:], reorder_o + ) + reorder_w = ( + jagged_index_select_with_empty( + jt.weights().unsqueeze(-1), + indices, + jt.offsets()[1:], + reorder_o, + ).flatten() + if jt.weights_or_none() is not None + else None + ) + + return (reorder_v.flatten(), reorder_l.flatten(), reorder_w) + + +class ShardedInferenceKeyedJaggedTensorPool( + ShardedObjectPool[KeyedJaggedTensor, List[torch.Tensor], ObjectPoolShardingContext], +): + _local_kjt_pool_shards: torch.nn.ModuleList + _world_size: int + _device: torch.device + + def __init__( + self, + pool_size: int, + feature_max_lengths: Dict[str, int], + values_dtype: torch.dtype, + is_weighted: bool, + sharding_env: ShardingEnv, + sharding_plan: ObjectPoolShardingPlan, + module: KeyedJaggedTensorPool, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + + self._pool_size = pool_size + self._values_dtype = values_dtype + self._sharding_env = sharding_env + self._world_size = self._sharding_env.world_size + self._device = device or torch.device("cuda") + self._sharding_plan = sharding_plan + + self._is_weighted = is_weighted + self._feature_max_lengths = feature_max_lengths + + torch._assert( + self._sharding_plan.inference, "Plan needs to have inference enabled" + ) + + if self._sharding_plan.sharding_type == ObjectPoolShardingType.ROW_WISE: + # pyre-fixme[4]: Attribute must be annotated. + self._sharding = InferRwKeyedJaggedTensorPoolSharding( + env=self._sharding_env, + device=self._device, + pool_size=self._pool_size, + ) + else: + raise NotImplementedError( + f"Sharding type {self._sharding_plan.sharding_type} is not implemented" + ) + + self._local_kjt_pool_shards = torch.nn.ModuleList() + offset = 0 + for rank, this_rank_size in zip( + range(self._world_size), self._sharding.local_pool_size_per_rank + ): + shard_device = ( + torch.device("cpu") + if device == torch.device("cpu") + else torch.device("cuda", rank) + ) + self._local_kjt_pool_shards.append( + TensorJaggedIndexSelectLookup( + this_rank_size, + self._values_dtype, + feature_max_lengths, + self._is_weighted, + shard_device, + ) + ) + if module._device != torch.device("meta"): + self._local_kjt_pool_shards[rank]._values.copy_( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ... + module.values[offset : offset + this_rank_size] + ) + self._local_kjt_pool_shards[rank]._key_lengths.copy_( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ... 
+ module.key_lengths[offset : offset + this_rank_size] + ) + jagged_lengths, jagged_offsets = self._local_kjt_pool_shards[ + rank + ]._infer_jagged_lengths_inclusive_offsets() + self._local_kjt_pool_shards[rank]._jagged_lengths = jagged_lengths + self._local_kjt_pool_shards[rank]._jagged_offsets = jagged_offsets + offset += this_rank_size + + # TODO: move these to class type declarations + # this can be somewhat tricky w/ torchscript since these are + # abstract classes. + self._lookup_ids_dist_impl: InferRwObjectPoolInputDist = torch.jit.annotate( + InferRwObjectPoolInputDist, + self._sharding.create_lookup_ids_dist(), + ) + + self._lookup_values_dist_impl: InferRwKeyedJaggedTensorPoolOutputDist = ( + torch.jit.annotate( + InferRwKeyedJaggedTensorPoolOutputDist, + self._sharding.create_lookup_values_dist(), + ) + ) + + @property + def pool_size(self) -> int: + return self._pool_size + + @property + def dim(self) -> int: + # pyre-fixme[7]: Expected `int` but got `Union[Tensor, Module]`. + return self._dim + + @property + def dtype(self) -> torch.dtype: + return self._values_dtype + + @property + def device(self) -> torch.device: + torch._assert(self._device is not None, "self._device should already be set") + return self._device + + def create_context(self) -> ObjectPoolShardingContext: + raise NotImplementedError("create_context() is not implemented") + + # pyre-ignore + def _lookup_ids_dist( + self, + ids: torch.Tensor, + ) -> Tuple[List[torch.Tensor], torch.Tensor]: + return self._lookup_ids_dist_impl(ids) + + # pyre-ignore + def _lookup_local( + self, + dist_input: List[torch.Tensor], + ) -> List[JaggedTensor]: + ret = torch.jit.annotate(List[JaggedTensor], []) + for i, shard in enumerate(self._local_kjt_pool_shards): + ret.append(shard(dist_input[i])) + return ret + + # pyre-ignore + def _lookup_values_dist( + self, + lookups: List[JaggedTensor], + ) -> JaggedTensor: + return self._lookup_values_dist_impl(lookups) + + # pyre-ignore + def forward(self, ids: torch.Tensor) -> KeyedJaggedTensor: + dist_input, unbucketize_permute = self._lookup_ids_dist(ids) + lookup = self._lookup_local(dist_input) + # Here we are playing a trick to workaround a fx tracing issue, + # as proxy is not iteratable. + lookup_list = [] + for i in range(self._world_size): + lookup_list.append(lookup[i]) + + jt = self._lookup_values_dist(lookup_list) + keys = list(self._feature_max_lengths.keys()) + reorder_v, reorder_l, reorder_w = _get_reorder_values_lengths_weights( + keys, jt, unbucketize_permute, self._device + ) + + ret = KeyedJaggedTensor.from_lengths_sync( + keys=keys, + values=reorder_v, + weights=reorder_w, + lengths=reorder_l, + ) + return ret + + # pyre-ignore + def _update_ids_dist( + self, + ctx: ObjectPoolShardingContext, + ids: torch.Tensor, + ) -> None: + raise NotImplementedError("Inference does not currently support update") + + # pyre-ignore + def _update_values_dist(self, ctx: ObjectPoolShardingContext, values: torch.Tensor): + raise NotImplementedError("Inference does not currently support update") + + def _update_local( + self, + ctx: ObjectPoolShardingContext, + ids: torch.Tensor, + values: List[torch.Tensor], + ) -> None: + raise NotImplementedError("Inference does not support update") + + # pyre-fixme[7]: Expected `KeyedJaggedTensor` but got implicit return value of + # `None`. 
+ def _update_preproc(self, values: KeyedJaggedTensor) -> KeyedJaggedTensor: + pass + + +class KeyedJaggedTensorPoolSharder(ModuleSharder[KeyedJaggedTensorPool]): + def __init__(self) -> None: + super().__init__() + + def shard( + self, + module: KeyedJaggedTensorPool, + plan: ObjectPoolShardingPlan, + env: ShardingEnv, + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> Union[ShardedKeyedJaggedTensorPool, ShardedInferenceKeyedJaggedTensorPool]: + if plan.inference: + return ShardedInferenceKeyedJaggedTensorPool( + pool_size=module.pool_size, + feature_max_lengths=module.feature_max_lengths, + values_dtype=module.values_dtype, + is_weighted=module.is_weighted, + sharding_env=env, + sharding_plan=plan, + module=module, + device=device, + ) + return ShardedKeyedJaggedTensorPool( + module.pool_size, + module.feature_max_lengths, + module.values_dtype, + module.is_weighted, + sharding_plan=plan, + sharding_env=env, + device=device, + enable_uvm=module._enable_uvm, + ) + + @property + def module_type(self) -> Type[KeyedJaggedTensorPool]: + return KeyedJaggedTensorPool diff --git a/torchrec/distributed/mc_embedding.py b/torchrec/distributed/mc_embedding.py new file mode 100644 index 000000000..0d939632e --- /dev/null +++ b/torchrec/distributed/mc_embedding.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +from typing import Any, cast, Dict, List, Optional, Type + +import torch + +from torchrec.distributed.embedding import ( + EmbeddingCollectionContext, + EmbeddingCollectionSharder, + ShardedEmbeddingCollection, +) + +from torchrec.distributed.embedding_types import KJTList +from torchrec.distributed.mc_embedding_modules import ( + BaseManagedCollisionEmbeddingCollectionSharder, + BaseShardedManagedCollisionEmbeddingCollection, +) +from torchrec.distributed.mc_modules import ManagedCollisionCollectionSharder +from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext +from torchrec.distributed.types import ( + ParameterSharding, + QuantizedCommCodecs, + ShardingEnv, +) +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class ManagedCollisionEmbeddingCollectionContext(EmbeddingCollectionContext): + + def __init__( + self, + sharding_contexts: Optional[List[SequenceShardingContext]] = None, + input_features: Optional[List[KeyedJaggedTensor]] = None, + reverse_indices: Optional[List[torch.Tensor]] = None, + evictions_per_table: Optional[Dict[str, Optional[torch.Tensor]]] = None, + remapped_kjt: Optional[KJTList] = None, + ) -> None: + super().__init__(sharding_contexts, input_features, reverse_indices) + self.evictions_per_table: Optional[Dict[str, Optional[torch.Tensor]]] = ( + evictions_per_table + ) + self.remapped_kjt: Optional[KJTList] = remapped_kjt + + def record_stream(self, stream: torch.Stream) -> None: + super().record_stream(stream) + if self.evictions_per_table: + # pyre-ignore + for value in self.evictions_per_table.values(): + if value is None: + continue + value.record_stream(stream) + if self.remapped_kjt is not None: + self.remapped_kjt.record_stream(stream) + + +class ShardedManagedCollisionEmbeddingCollection( + BaseShardedManagedCollisionEmbeddingCollection[ + 
ManagedCollisionEmbeddingCollectionContext + ] +): + def __init__( + self, + module: ManagedCollisionEmbeddingCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + ec_sharder: EmbeddingCollectionSharder, + mc_sharder: ManagedCollisionCollectionSharder, + # TODO - maybe we need this to manage unsharded/sharded consistency/state consistency + env: ShardingEnv, + device: torch.device, + ) -> None: + super().__init__( + module, + table_name_to_parameter_sharding, + ec_sharder, + mc_sharder, + env, + device, + ) + + # For consistency with embeddingbag + @property + def _embedding_collection(self) -> ShardedEmbeddingCollection: + return cast(ShardedEmbeddingCollection, self._embedding_module) + + def create_context( + self, + ) -> ManagedCollisionEmbeddingCollectionContext: + return ManagedCollisionEmbeddingCollectionContext(sharding_contexts=[]) + + +class ManagedCollisionEmbeddingCollectionSharder( + BaseManagedCollisionEmbeddingCollectionSharder[ManagedCollisionEmbeddingCollection] +): + def __init__( + self, + ec_sharder: Optional[EmbeddingCollectionSharder] = None, + mc_sharder: Optional[ManagedCollisionCollectionSharder] = None, + fused_params: Optional[Dict[str, Any]] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__( + ec_sharder + or EmbeddingCollectionSharder( + qcomm_codecs_registry=qcomm_codecs_registry, + fused_params=fused_params, + ), + mc_sharder or ManagedCollisionCollectionSharder(), + qcomm_codecs_registry=qcomm_codecs_registry, + ) + + def shard( + self, + module: ManagedCollisionEmbeddingCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> ShardedManagedCollisionEmbeddingCollection: + + if device is None: + device = torch.device("cuda") + + return ShardedManagedCollisionEmbeddingCollection( + module, + params, + # pyre-ignore [6] + ec_sharder=self._e_sharder, + mc_sharder=self._mc_sharder, + env=env, + device=device, + ) + + @property + def module_type(self) -> Type[ManagedCollisionEmbeddingCollection]: + return ManagedCollisionEmbeddingCollection diff --git a/torchrec/distributed/mc_embedding_modules.py b/torchrec/distributed/mc_embedding_modules.py new file mode 100644 index 000000000..b817f020a --- /dev/null +++ b/torchrec/distributed/mc_embedding_modules.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
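An illustrative sketch of how the ManagedCollisionEmbeddingCollectionSharder added in mc_embedding.py above might be instantiated. The class and its module_type property come from that hunk; the DistributedModelParallel wiring mentioned in the comment is a common pattern and is an assumption, not taken from this diff.

from torchrec.distributed.mc_embedding import (
    ManagedCollisionEmbeddingCollectionSharder,
)

# Defaults construct an EmbeddingCollectionSharder and ManagedCollisionCollectionSharder.
sharder = ManagedCollisionEmbeddingCollectionSharder()
assert sharder.module_type.__name__ == "ManagedCollisionEmbeddingCollection"

# Typically passed alongside the default sharders when wrapping a model, e.g.
# DistributedModelParallel(model, sharders=[sharder, ...])  # assumed wiring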
+ +# pyre-strict + +import logging +from typing import Any, Dict, Iterator, List, Optional, Tuple, TypeVar, Union + +import torch +from torch.autograd.profiler import record_function +from torchrec.distributed.embedding import ( + EmbeddingCollectionSharder, + ShardedEmbeddingCollection, +) + +from torchrec.distributed.embedding_types import ( + BaseEmbeddingSharder, + EmbeddingComputeKernel, + KJTList, + ShardedEmbeddingModule, +) +from torchrec.distributed.embeddingbag import ( + EmbeddingBagCollectionSharder, + ShardedEmbeddingBagCollection, +) +from torchrec.distributed.mc_modules import ( + ManagedCollisionCollectionSharder, + ShardedManagedCollisionCollection, +) +from torchrec.distributed.types import ( + Awaitable, + LazyAwaitable, + Multistreamable, + NoWait, + ParameterSharding, + QuantizedCommCodecs, + ShardingEnv, +) +from torchrec.distributed.utils import append_prefix +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) +from torchrec.modules.mc_embedding_modules import ( + BaseManagedCollisionEmbeddingCollection, + ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + + +logger: logging.Logger = logging.getLogger(__name__) + + +ShrdCtx = TypeVar("ShrdCtx", bound=Multistreamable) + + +class BaseShardedManagedCollisionEmbeddingCollection( + ShardedEmbeddingModule[ + KJTList, + List[torch.Tensor], + Tuple[LazyAwaitable[KeyedTensor], LazyAwaitable[Optional[KeyedJaggedTensor]]], + ShrdCtx, + ] +): + def __init__( + self, + module: Union[ + ManagedCollisionEmbeddingBagCollection, ManagedCollisionEmbeddingCollection + ], + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + e_sharder: Union[EmbeddingBagCollectionSharder, EmbeddingCollectionSharder], + mc_sharder: ManagedCollisionCollectionSharder, + # TODO - maybe we need this to manage unsharded/sharded consistency/state consistency + env: ShardingEnv, + device: torch.device, + ) -> None: + super().__init__() + + self._device = device + self._env = env + + if isinstance(module, ManagedCollisionEmbeddingBagCollection): + assert isinstance(e_sharder, EmbeddingBagCollectionSharder) + assert isinstance(module._embedding_module, EmbeddingBagCollection) + self.bagged: bool = True + + self._embedding_module: ShardedEmbeddingBagCollection = e_sharder.shard( + module._embedding_module, + table_name_to_parameter_sharding, + env=env, + device=device, + ) + else: + assert isinstance(e_sharder, EmbeddingCollectionSharder) + assert isinstance(module._embedding_module, EmbeddingCollection) + self.bagged: bool = False + + self._embedding_module: ShardedEmbeddingCollection = e_sharder.shard( + module._embedding_module, + table_name_to_parameter_sharding, + env=env, + device=device, + ) + # TODO: This is a hack since _embedding_module doesn't need input + # dist, so eliminating it so all fused a2a will ignore it. 
+ self._embedding_module._has_uninitialized_input_dist = False + embedding_shardings = ( + self._embedding_module._embedding_shardings + if isinstance(self._embedding_module, ShardedEmbeddingBagCollection) + else list(self._embedding_module._sharding_type_to_sharding.values()) + ) + self._managed_collision_collection: ShardedManagedCollisionCollection = ( + mc_sharder.shard( + module._managed_collision_collection, + table_name_to_parameter_sharding, + env=env, + device=device, + embedding_shardings=embedding_shardings, + use_index_dedup=( + e_sharder._use_index_dedup + if isinstance(e_sharder, EmbeddingCollectionSharder) + else False + ), + ) + ) + self._return_remapped_features: bool = module._return_remapped_features + self._allow_in_place_embed_weight_update: bool = ( + module._allow_in_place_embed_weight_update + ) + + # pyre-ignore + self._table_to_tbe_and_index = {} + for lookup in self._embedding_module._lookups: + # pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not + # a function. + for emb_module in lookup._emb_modules: + for table_idx, table in enumerate(emb_module._config.embedding_tables): + self._table_to_tbe_and_index[table.name] = ( + emb_module._emb_module, + torch.tensor([table_idx], dtype=torch.int, device=self._device), + ) + self._buffer_ids: torch.Tensor = torch.tensor( + [0], device=self._device, dtype=torch.int + ) + + # pyre-ignore + def input_dist( + self, + ctx: ShrdCtx, + features: KeyedJaggedTensor, + ) -> Awaitable[Awaitable[KJTList]]: + # TODO: resolve incompatiblity with different contexts + return self._managed_collision_collection.input_dist( + # pyre-fixme [6] + ctx, + features, + ) + + def _evict(self, evictions_per_table: Dict[str, Optional[torch.Tensor]]) -> None: + open_slots = None + for table, evictions_indices_for_table in evictions_per_table.items(): + if evictions_indices_for_table is not None: + (tbe, logical_table_ids) = self._table_to_tbe_and_index[table] + pruned_indices_offsets = torch.tensor( + [0, evictions_indices_for_table.shape[0]], + dtype=torch.long, + device=self._device, + ) + if open_slots is None: + open_slots = self._managed_collision_collection.open_slots() + logger.info( + f"Table {table}: inserting {evictions_indices_for_table.numel()} ids with {open_slots[table].item()} open slots" + ) + with torch.no_grad(): + # embeddings, and optimizer state will be reset + tbe.reset_embedding_weight_momentum( + pruned_indices=evictions_indices_for_table.long(), + pruned_indices_offsets=pruned_indices_offsets, + logical_table_ids=logical_table_ids, + buffer_ids=self._buffer_ids, + ) + + if self.bagged: + table_weight_param = ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | ModuleDict + # | Module` has no attribute `get_parameter`. + self._embedding_module.embedding_bags.get_parameter( + f"{table}.weight" + ) + ) + else: + table_weight_param = ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | ModuleDict + # | Module` has no attribute `get_parameter`. + self._embedding_module.embeddings.get_parameter( + f"{table}.weight" + ) + ) + + init_fn = self._embedding_module._table_name_to_config[ + table + ].init_fn + # Set evicted indices to original init_fn instead of all zeros + if self._allow_in_place_embed_weight_update: + # In-place update with .data to bypass PyTorch's autograd tracking. + # This is required for model training with multiple forward passes where the autograd graph + # is already created. 
Direct tensor modification would trigger PyTorch's in-place operation + # checks and invalidate gradients, while .data allows safe reinitialization of evicted + # embeddings without affecting the computational graph. + # pyre-ignore [29] + table_weight_param.data[evictions_indices_for_table] = init_fn( + table_weight_param[evictions_indices_for_table] + ) + else: + # pyre-ignore [29] + table_weight_param[evictions_indices_for_table] = init_fn( + table_weight_param[evictions_indices_for_table] + ) + + def compute( + self, + ctx: ShrdCtx, + dist_input: KJTList, + ) -> List[torch.Tensor]: + with record_function("## compute:mcc ##"): + remapped_kjt = self._managed_collision_collection.compute( + # pyre-fixme [6] + ctx, + dist_input, + ) + evictions_per_table = self._managed_collision_collection.evict() + + self._evict(evictions_per_table) + ctx.remapped_kjt = remapped_kjt + ctx.evictions_per_table = evictions_per_table + + # pyre-ignore + return self._embedding_module.compute(ctx, remapped_kjt) + + # pyre-ignore + def output_dist( + self, + ctx: ShrdCtx, + output: List[torch.Tensor], + ) -> Tuple[LazyAwaitable[KeyedTensor], LazyAwaitable[Optional[KeyedJaggedTensor]]]: + + # pyre-ignore [6] + ebc_awaitable = self._embedding_module.output_dist(ctx, output) + + if self._return_remapped_features: + kjt_awaitable = self._managed_collision_collection.output_dist( + # pyre-fixme [6] + ctx, + # pyre-ignore [16] + ctx.remapped_kjt, + ) + else: + kjt_awaitable = NoWait(None) + + # pyre-ignore + return ebc_awaitable, kjt_awaitable + + def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: + for fqn, _ in self.named_parameters(): + yield append_prefix(prefix, fqn) + for fqn, _ in self.named_buffers(): + yield append_prefix(prefix, fqn) + + +M = TypeVar("M", bound=BaseManagedCollisionEmbeddingCollection) + + +class BaseManagedCollisionEmbeddingCollectionSharder(BaseEmbeddingSharder[M]): + def __init__( + self, + e_sharder: Union[EmbeddingBagCollectionSharder, EmbeddingCollectionSharder], + mc_sharder: ManagedCollisionCollectionSharder, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + self._e_sharder: Union[ + EmbeddingBagCollectionSharder, EmbeddingCollectionSharder + ] = e_sharder + self._mc_sharder: ManagedCollisionCollectionSharder = mc_sharder + + def shardable_parameters( + self, module: BaseManagedCollisionEmbeddingCollection + ) -> Dict[str, torch.nn.Parameter]: + # pyre-ignore + return self._e_sharder.shardable_parameters(module._embedding_module) + + def compute_kernels( + self, + sharding_type: str, + compute_device_type: str, + ) -> List[str]: + return [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ] + + def sharding_types(self, compute_device_type: str) -> List[str]: + return list( + set.intersection( + set(self._e_sharder.sharding_types(compute_device_type)), + set(self._mc_sharder.sharding_types(compute_device_type)), + ) + ) + + @property + def fused_params(self) -> Optional[Dict[str, Any]]: + # TODO: to be deprecate after planner get cache_load_factor from ParameterConstraints + return self._e_sharder.fused_params diff --git a/torchrec/distributed/mc_embeddingbag.py b/torchrec/distributed/mc_embeddingbag.py new file mode 100644 index 000000000..e94d42d59 --- /dev/null +++ b/torchrec/distributed/mc_embeddingbag.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
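A standalone illustration of the in-place-update rationale described in the eviction comment above: writing through .data sidesteps autograd's in-place checks, so evicted rows can be re-initialized between a forward pass and backward(). The toy parameter shape and eviction indices are assumptions for the sketch.

import torch

weight = torch.nn.Parameter(torch.randn(4, 2))
loss = (weight * 2.0).sum()          # forward pass records an autograd graph over `weight`
evicted_rows = torch.tensor([1, 3])

# Re-initialize evicted rows through .data, mirroring the
# _allow_in_place_embed_weight_update branch above; autograd does not track this write.
weight.data[evicted_rows] = torch.zeros(2)

loss.backward()                      # backward still succeeds

# Assigning directly (weight[evicted_rows] = 0.0) would instead raise a RuntimeError,
# because autograd rejects in-place operations on a leaf tensor that requires grad.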
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +from dataclasses import dataclass +from typing import Any, cast, Dict, Optional, Type + +import torch +from torchrec.distributed.embedding_types import KJTList +from torchrec.distributed.embeddingbag import ( + EmbeddingBagCollectionContext, + EmbeddingBagCollectionSharder, + ShardedEmbeddingBagCollection, +) +from torchrec.distributed.mc_embedding_modules import ( + BaseManagedCollisionEmbeddingCollectionSharder, + BaseShardedManagedCollisionEmbeddingCollection, +) +from torchrec.distributed.mc_modules import ManagedCollisionCollectionSharder +from torchrec.distributed.types import ( + ParameterSharding, + QuantizedCommCodecs, + ShardingEnv, +) +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection + + +@dataclass +class ManagedCollisionEmbeddingBagCollectionContext(EmbeddingBagCollectionContext): + evictions_per_table: Optional[Dict[str, Optional[torch.Tensor]]] = None + remapped_kjt: Optional[KJTList] = None + + def record_stream(self, stream: torch.Stream) -> None: + super().record_stream(stream) + if self.evictions_per_table: + # pyre-ignore + for value in self.evictions_per_table.values(): + if value is None: + continue + value.record_stream(stream) + if self.remapped_kjt is not None: + self.remapped_kjt.record_stream(stream) + + +class ShardedManagedCollisionEmbeddingBagCollection( + BaseShardedManagedCollisionEmbeddingCollection[ + ManagedCollisionEmbeddingBagCollectionContext + ] +): + def __init__( + self, + module: ManagedCollisionEmbeddingBagCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + ebc_sharder: EmbeddingBagCollectionSharder, + mc_sharder: ManagedCollisionCollectionSharder, + # TODO - maybe we need this to manage unsharded/sharded consistency/state consistency + env: ShardingEnv, + device: torch.device, + ) -> None: + super().__init__( + module, + table_name_to_parameter_sharding, + ebc_sharder, + mc_sharder, + env, + device, + ) + + # For backwards compat, some references still to self._embedding_bag_collection + @property + def _embedding_bag_collection(self) -> ShardedEmbeddingBagCollection: + return cast(ShardedEmbeddingBagCollection, self._embedding_module) + + def create_context( + self, + ) -> ManagedCollisionEmbeddingBagCollectionContext: + return ManagedCollisionEmbeddingBagCollectionContext(sharding_contexts=[]) + + +class ManagedCollisionEmbeddingBagCollectionSharder( + BaseManagedCollisionEmbeddingCollectionSharder[ + ManagedCollisionEmbeddingBagCollection + ] +): + def __init__( + self, + ebc_sharder: Optional[EmbeddingBagCollectionSharder] = None, + mc_sharder: Optional[ManagedCollisionCollectionSharder] = None, + fused_params: Optional[Dict[str, Any]] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__( + ebc_sharder + or EmbeddingBagCollectionSharder( + fused_params=fused_params, qcomm_codecs_registry=qcomm_codecs_registry + ), + mc_sharder or ManagedCollisionCollectionSharder(), + qcomm_codecs_registry=qcomm_codecs_registry, + ) + + def shard( + self, + module: ManagedCollisionEmbeddingBagCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> ShardedManagedCollisionEmbeddingBagCollection: + + if device is None: + device = 
torch.device("cuda") + + return ShardedManagedCollisionEmbeddingBagCollection( + module, + params, + # pyre-ignore [6] + ebc_sharder=self._e_sharder, + mc_sharder=self._mc_sharder, + env=env, + device=device, + ) + + @property + def module_type(self) -> Type[ManagedCollisionEmbeddingBagCollection]: + return ManagedCollisionEmbeddingBagCollection diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py new file mode 100644 index 000000000..34e4ac672 --- /dev/null +++ b/torchrec/distributed/mc_modules.py @@ -0,0 +1,1342 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import copy +import itertools +import logging +import math +from collections import defaultdict, OrderedDict +from dataclasses import dataclass +from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Type, Union + +import torch +import torch.distributed as dist + +from torch import nn +from torch.distributed._shard.sharded_tensor import Shard, ShardMetadata + +from torchrec.distributed.embedding_sharding import ( + EmbeddingSharding, + EmbeddingShardingContext, + EmbeddingShardingInfo, + KJTListSplitsAwaitable, +) +from torchrec.distributed.embedding_types import ( + BaseEmbeddingSharder, + GroupedEmbeddingConfig, + KJTList, + ListOfKJTList, +) + +from torchrec.distributed.sharding.rw_sequence_sharding import ( + RwSequenceEmbeddingDist, + RwSequenceEmbeddingSharding, +) +from torchrec.distributed.sharding.rw_sharding import ( + BaseRwEmbeddingSharding, + InferRwSparseFeaturesDist, + RwSparseFeaturesDist, +) +from torchrec.distributed.sharding.sequence_sharding import ( + InferSequenceShardingContext, + SequenceShardingContext, +) +from torchrec.distributed.types import ( + Awaitable, + LazyAwaitable, + ParameterSharding, + QuantizedCommCodecs, + ShardedModule, + ShardedTensor, + ShardingEnv, + ShardingType, +) +from torchrec.distributed.utils import append_prefix +from torchrec.modules.mc_modules import ManagedCollisionCollection +from torchrec.modules.utils import construct_jagged_tensors +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from torchrec.streamable import Multistreamable + + +@dataclass +class EmbeddingCollectionContext(Multistreamable): + sharding_contexts: List[ + Union[InferSequenceShardingContext, SequenceShardingContext] + ] + + def record_stream(self, stream: torch.Stream) -> None: + for ctx in self.sharding_contexts: + ctx.record_stream(stream) + + +class ManagedCollisionCollectionContext(EmbeddingCollectionContext): + pass + + +@torch.fx.wrap +def _fx_global_to_local_index( + feature_dict: Dict[str, JaggedTensor], feature_to_offset: Dict[str, int] +) -> Dict[str, JaggedTensor]: + for feature, jt in feature_dict.items(): + jt._values = jt.values() - feature_to_offset[feature] + return feature_dict + + +@torch.fx.wrap +def _fx_jt_dict_add_offset( + feature_dict: Dict[str, JaggedTensor], feature_to_offset: Dict[str, int] +) -> Dict[str, JaggedTensor]: + for feature, jt in feature_dict.items(): + jt._values = jt.values() + feature_to_offset[feature] + return feature_dict + + +@torch.fx.wrap +def _get_length_per_key(kjt: KeyedJaggedTensor) -> torch.Tensor: + return torch.tensor(kjt.length_per_key()) + + +logger: logging.Logger = logging.getLogger(__name__) + + +class 
ManagedCollisionCollectionAwaitable(LazyAwaitable[KeyedJaggedTensor]): + def __init__( + self, + awaitables_per_sharding: List[Awaitable[torch.Tensor]], + features_per_sharding: List[KeyedJaggedTensor], + embedding_names_per_sharding: List[List[str]], + need_indices: bool = False, + features_to_permute_indices: Optional[Dict[str, List[int]]] = None, + reverse_indices: Optional[List[torch.Tensor]] = None, + ) -> None: + super().__init__() + self._awaitables_per_sharding = awaitables_per_sharding + self._features_per_sharding = features_per_sharding + self._need_indices = need_indices + self._features_to_permute_indices = features_to_permute_indices + self._embedding_names_per_sharding = embedding_names_per_sharding + self._reverse_indices = reverse_indices + + def _wait_impl(self) -> KeyedJaggedTensor: + jt_dict: Dict[str, JaggedTensor] = {} + for i, (w, f, e) in enumerate( + zip( + self._awaitables_per_sharding, + self._features_per_sharding, + self._embedding_names_per_sharding, + ) + ): + reverse_indices = ( + self._reverse_indices[i] if self._reverse_indices else None + ) + + jt_dict.update( + construct_jagged_tensors( + embeddings=w.wait(), + features=f, + embedding_names=e, + need_indices=self._need_indices, + features_to_permute_indices=self._features_to_permute_indices, + reverse_indices=reverse_indices, + ) + ) + # TODO: find better solution + for jt in jt_dict.values(): + jt._values = jt.values().flatten() + return KeyedJaggedTensor.from_jt_dict(jt_dict) + + +def create_mc_sharding( + sharding_type: str, + sharding_infos: List[EmbeddingShardingInfo], + env: ShardingEnv, + device: Optional[torch.device] = None, +) -> EmbeddingSharding[ + SequenceShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor +]: + if sharding_type == ShardingType.ROW_WISE.value: + return RwSequenceEmbeddingSharding( + sharding_infos=sharding_infos, + env=env, + device=device, + ) + else: + raise ValueError(f"Sharding not supported {sharding_type}") + + +class ShardedManagedCollisionCollection( + ShardedModule[ + KJTList, + KJTList, + KeyedJaggedTensor, + ManagedCollisionCollectionContext, + ] +): + def __init__( + self, + module: ManagedCollisionCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + env: ShardingEnv, + device: torch.device, + embedding_shardings: List[ + EmbeddingSharding[ + EmbeddingShardingContext, + KeyedJaggedTensor, + torch.Tensor, + torch.Tensor, + ] + ], + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + use_index_dedup: bool = False, + ) -> None: + super().__init__() + self.need_preprocess: bool = module.need_preprocess + self._device = device + self._env = env + self._table_name_to_parameter_sharding: Dict[str, ParameterSharding] = ( + copy.deepcopy(table_name_to_parameter_sharding) + ) + # TODO: create a MCSharding type instead of leveraging EmbeddingSharding + self._embedding_shardings = embedding_shardings + + self._embedding_names_per_sharding: List[List[str]] = [] + for sharding in self._embedding_shardings: + # TODO: support TWRW sharding + assert isinstance( + sharding, BaseRwEmbeddingSharding + ), "Only ROW_WISE sharding is supported." 
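+ # NOTE: row-wise layout is assumed because each managed collision module is
+ # later rebuilt with a contiguous per-rank output_id_range derived from the
+ # shard's row offsets and sizes (see _create_managed_collision_modules).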
+ self._embedding_names_per_sharding.append(sharding.embedding_names()) + + self._feature_to_table: Dict[str, str] = module._feature_to_table + self._table_to_features: Dict[str, List[str]] = module._table_to_features + self._has_uninitialized_input_dists: bool = True + self._input_dists: List[nn.Module] = [] + self._managed_collision_modules = nn.ModuleDict() + self._create_managed_collision_modules(module) + self._output_dists: List[nn.Module] = [] + self._create_output_dists() + self._use_index_dedup = use_index_dedup + self._initialize_torch_state() + + def _initialize_torch_state(self) -> None: + self._model_parallel_mc_buffer_name_to_sharded_tensor = OrderedDict() + shardable_params = set( + self.sharded_parameter_names(prefix="_managed_collision_modules") + ) + + for fqn, tensor in self.state_dict().items(): + if fqn not in shardable_params: + continue + table_name = fqn.split(".")[ + 1 + ] # "_managed_collision_modules.." + shard_offset, shard_size, global_size = self._mc_module_name_shard_metadata[ + table_name + ] + sharded_sizes = list(tensor.shape) + sharded_sizes[0] = shard_size + shard_offsets = [0] * len(sharded_sizes) + shard_offsets[0] = shard_offset + global_sizes = list(tensor.shape) + global_sizes[0] = global_size + + self._model_parallel_mc_buffer_name_to_sharded_tensor[fqn] = ( + ShardedTensor._init_from_local_shards( + [ + Shard( + tensor=tensor, + metadata=ShardMetadata( + shard_offsets=shard_offsets, + shard_sizes=sharded_sizes, + placement=(f"rank:{self._env.rank}/{tensor.device}"), + ), + ) + ], + torch.Size(global_sizes), + process_group=self._env.process_group, + ) + ) + + def _post_state_dict_hook( + module: ShardedManagedCollisionCollection, + destination: Dict[str, torch.Tensor], + prefix: str, + _local_metadata: Dict[str, Any], + ) -> None: + for ( + mc_buffer_name, + sharded_tensor, + ) in module._model_parallel_mc_buffer_name_to_sharded_tensor.items(): + destination_key = f"{prefix}{mc_buffer_name}" + destination[destination_key] = sharded_tensor + + def _load_state_dict_pre_hook( + module: "ShardedManagedCollisionCollection", + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> None: + for ( + mc_buffer_name, + _sharded_tensor, + ) in module._model_parallel_mc_buffer_name_to_sharded_tensor.items(): + key = f"{prefix}{mc_buffer_name}" + if key in state_dict: + if isinstance(state_dict[key], ShardedTensor): + local_shards = state_dict[key].local_shards() + state_dict[key] = local_shards[0].tensor + else: + raise RuntimeError( + f"Unexpected state_dict key type {type(state_dict[key])} found for {key}" + ) + + self._register_state_dict_hook(_post_state_dict_hook) + self._register_load_state_dict_pre_hook( + _load_state_dict_pre_hook, with_module=True + ) + + def _create_managed_collision_modules( + self, module: ManagedCollisionCollection + ) -> None: + + self._mc_module_name_shard_metadata: DefaultDict[str, List[int]] = defaultdict() + # To map mch output indices from local to global. key: table_name + self._table_to_offset: Dict[str, int] = {} + + # the split sizes of tables belonging to each sharding. outer len is # shardings + self._sharding_per_table_feature_splits: List[List[int]] = [] + self._input_size_per_table_feature_splits: List[List[int]] = [] + # the split sizes of features per sharding. len is # shardings + self._sharding_feature_splits: List[int] = [] + # the split sizes of features per table. 
len is # tables sum over all shardings + self._table_feature_splits: List[int] = [] + self._feature_names: List[str] = [] + + # table names of each sharding + self._sharding_tables: List[List[str]] = [] + self._sharding_features: List[List[str]] = [] + + for sharding in self._embedding_shardings: + assert isinstance(sharding, BaseRwEmbeddingSharding) + self._sharding_tables.append([]) + self._sharding_features.append([]) + self._sharding_per_table_feature_splits.append([]) + self._input_size_per_table_feature_splits.append([]) + + grouped_embedding_configs: List[GroupedEmbeddingConfig] = ( + sharding._grouped_embedding_configs + ) + self._sharding_feature_splits.append(len(sharding.feature_names())) + + num_sharding_features = 0 + for group_config in grouped_embedding_configs: + for table in group_config.embedding_tables: + # pyre-ignore [16] + new_min_output_id = table.local_metadata.shard_offsets[0] + # pyre-ignore [16] + new_range_size = table.local_metadata.shard_sizes[0] + output_segments = [ + x.shard_offsets[0] + # pyre-ignore [16] + for x in table.global_metadata.shards_metadata + ] + [table.num_embeddings] + mc_module = module._managed_collision_modules[table.name] + + self._sharding_tables[-1].append(table.name) + self._sharding_features[-1].extend(table.feature_names) + self._feature_names.extend(table.feature_names) + self._managed_collision_modules[table.name] = ( + mc_module.rebuild_with_output_id_range( + output_id_range=( + new_min_output_id, + new_min_output_id + new_range_size, + ), + output_segments=output_segments, + device=self._device, + ) + ) + zch_size = self._managed_collision_modules[table.name].output_size() + input_size = self._managed_collision_modules[ + table.name + ].input_size() + zch_size_by_rank = [ + torch.zeros(1, dtype=torch.int64, device=self._device) + for _ in range(self._env.world_size) + ] + if self.training and self._env.world_size > 1: + dist.all_gather( + zch_size_by_rank, + torch.tensor( + [zch_size], dtype=torch.int64, device=self._device + ), + group=self._env.process_group, + ) + else: + zch_size_by_rank[0] = torch.tensor( + [zch_size], dtype=torch.int64, device=self._device + ) + + # Calculate the sum of all ZCH sizes from rank 0 to list + # index. The last item is the sum of all elements in zch_size_by_rank + zch_size_cumsum = torch.cumsum( + torch.cat(zch_size_by_rank), dim=0 + ).tolist() + + zch_size_sum_before_this_rank = ( + zch_size_cumsum[self._env.rank] - zch_size + ) + # pyre-fixme[6]: For 2nd argument expected `int` + self._mc_module_name_shard_metadata[table.name] = ( + zch_size_sum_before_this_rank, + zch_size, + zch_size_cumsum[-1], + ) + self._table_to_offset[table.name] = new_min_output_id + + self._table_feature_splits.append(len(table.feature_names)) + self._sharding_per_table_feature_splits[-1].append( + self._table_feature_splits[-1] + ) + self._input_size_per_table_feature_splits[-1].append( + input_size, + ) + num_sharding_features += self._table_feature_splits[-1] + + assert num_sharding_features == len( + sharding.feature_names() + ), f"Shared feature is not supported. 
{num_sharding_features=}, {self._sharding_per_table_feature_splits[-1]=}" + + if self._sharding_features[-1] != sharding.feature_names(): + logger.warn( + "The order of tables of this sharding is altered due to grouping: " + f"{self._sharding_features[-1]=} vs {sharding.feature_names()=}" + ) + + logger.info(f"{self._table_feature_splits=}") + logger.info(f"{self._sharding_per_table_feature_splits=}") + logger.info(f"{self._input_size_per_table_feature_splits=}") + logger.info(f"{self._feature_names=}") + logger.info(f"{self._table_to_offset=}") + logger.info(f"{self._sharding_tables=}") + logger.info(f"{self._sharding_features=}") + + def _create_input_dists( + self, + input_feature_names: List[str], + ) -> None: + for sharding, sharding_features in zip( + self._embedding_shardings, + self._sharding_features, + ): + assert isinstance(sharding, BaseRwEmbeddingSharding) + feature_num_buckets: List[int] = [ + self._managed_collision_modules[self._feature_to_table[f]].buckets() + for f in sharding_features + ] + + input_sizes: List[int] = [ + self._managed_collision_modules[self._feature_to_table[f]].input_size() + for f in sharding_features + ] + + feature_hash_sizes: List[int] = [] + feature_total_num_buckets: List[int] = [] + for input_size, num_buckets in zip( + input_sizes, + feature_num_buckets, + ): + feature_hash_sizes.append(input_size) + feature_total_num_buckets.append(num_buckets) + + input_dist = RwSparseFeaturesDist( + # pyre-ignore [6] + pg=sharding._pg, + num_features=sharding._get_num_features(), + feature_hash_sizes=feature_hash_sizes, + feature_total_num_buckets=feature_total_num_buckets, + device=sharding._device, + is_sequence=True, + has_feature_processor=sharding._has_feature_processor, + need_pos=False, + keep_original_indices=True, + ) + self._input_dists.append(input_dist) + + # pyre-fixme[16]: `ShardedManagedCollisionCollection` has no attribute + # `_features_order`. + self._features_order: List[int] = [] + for f in self._feature_names: + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `append`. + self._features_order.append(input_feature_names.index(f)) + self._features_order = ( + [] + if self._features_order == list(range(len(input_feature_names))) + else self._features_order + ) + self.register_buffer( + "_features_order_tensor", + torch.tensor(self._features_order, device=self._device, dtype=torch.int32), + persistent=False, + ) + if self._use_index_dedup: + self._create_dedup_indices() + + def _create_output_dists( + self, + ) -> None: + for sharding in self._embedding_shardings: + assert isinstance(sharding, BaseRwEmbeddingSharding) + self._output_dists.append( + RwSequenceEmbeddingDist( + # pyre-ignore [6] + sharding._pg, + sharding._get_num_features(), + sharding._device, + ) + ) + + def _create_dedup_indices(self) -> None: + # validate we can linearize the features irrespective of feature split + assert ( + list( + itertools.accumulate( + [ + hash_input + for input_split in self._input_size_per_table_feature_splits + for hash_input in input_split + ] + ) + )[-1] + <= torch.iinfo(torch.int64).max + ), "EC Dedup requires the mc collection to have a cumuluative 'hash_input_size' kwarg to be less than max int64. Please reduce values of individual tables to meet this constraint (ie. 2**54 is typically a good value)." 
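+ # The offsets built below linearize the per-table id spaces into one global
+ # range so that fbgemm's jagged_unique_indices can dedup across tables in a
+ # single pass. Illustrative example (values assumed, not from a real config):
+ # two tables with one feature each and input sizes 1000 and 500 yield
+ #     hash_offsets    = [0, 1000, 1500]  # last entry closes the range
+ #     feature_offsets = [0, 1, 2]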
+ for i, (feature_splits, input_splits) in enumerate( + zip( + self._sharding_per_table_feature_splits, + self._input_size_per_table_feature_splits, + ) + ): + cum_f = 0 + cum_i = 0 + hash_offsets = [] + feature_offsets = [] + N = math.ceil(math.log2(len(feature_splits))) + for features, hash_size in zip(feature_splits, input_splits): + hash_offsets += [cum_i for _ in range(features)] + feature_offsets += [cum_f for _ in range(features)] + cum_f += features + cum_i += (2 ** (63 - N) - 1) if hash_size == 0 else hash_size + assert ( + cum_i <= torch.iinfo(torch.int64).max + ), f"Index exceeds max int64, {cum_i=}" + hash_offsets += [cum_i] + feature_offsets += [cum_f] + self.register_buffer( + "_dedup_hash_offsets_{}".format(i), + torch.tensor(hash_offsets, dtype=torch.int64, device=self._device), + persistent=False, + ) + self.register_buffer( + "_dedup_feature_offsets_{}".format(i), + torch.tensor(feature_offsets, dtype=torch.int64, device=self._device), + persistent=False, + ) + + def _dedup_indices( + self, + ctx: ManagedCollisionCollectionContext, + features: List[KeyedJaggedTensor], + ) -> List[KeyedJaggedTensor]: + features_by_sharding = [] + + for i, kjt in enumerate(features): + hash_offsets = self.get_buffer(f"_dedup_hash_offsets_{i}") + feature_offsets = self.get_buffer(f"_dedup_feature_offsets_{i}") + ( + lengths, + offsets, + unique_indices, + reverse_indices, + ) = torch.ops.fbgemm.jagged_unique_indices( + hash_offsets, + feature_offsets, + kjt.offsets().to(torch.int64), + kjt.values().to(torch.int64), + ) + dedup_features = KeyedJaggedTensor( + keys=kjt.keys(), + lengths=lengths, + offsets=offsets, + values=unique_indices, + ) + + ctx.input_features.append(kjt) # pyre-ignore + ctx.reverse_indices.append(reverse_indices) # pyre-ignore + features_by_sharding.append(dedup_features) + return features_by_sharding + + # pyre-ignore [14] + def input_dist( + self, + ctx: ManagedCollisionCollectionContext, + features: KeyedJaggedTensor, + ) -> Awaitable[Awaitable[KJTList]]: + if self._has_uninitialized_input_dists: + self._create_input_dists(input_feature_names=features.keys()) + self._has_uninitialized_input_dists = False + + with torch.no_grad(): + if self._features_order: + features = features.permute( + # pyre-fixme[6]: For 1st argument expected `List[int]` but got + # `Union[Module, Tensor]`. + self._features_order, + # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]` + # but got `Union[Module, Tensor]`. + self._features_order_tensor, + ) + + feature_splits: List[KeyedJaggedTensor] = [] + if self.need_preprocess: + # NOTE: No shared features allowed! 
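+ # A shared feature would have to appear in more than one per-table split, so
+ # the split-by-table preprocessing below cannot represent it.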
+ assert ( + len(self._sharding_feature_splits) == 1 + ), "Preprocing only support single sharding type (row-wise)" + table_splits = features.split(self._table_feature_splits) + ti: int = 0 + for i, tables in enumerate(self._sharding_tables): + output: Dict[str, JaggedTensor] = {} + for table in tables: + kjt: KeyedJaggedTensor = table_splits[ti] + mc_module = self._managed_collision_modules[table] + # TODO: change to Dict[str, Tensor] + mc_input: Dict[str, JaggedTensor] = { + table: JaggedTensor( + values=kjt.values(), + lengths=kjt.lengths(), + ) + } + mc_input = mc_module.preprocess(mc_input) + output.update(mc_input) + ti += 1 + shard_kjt = KeyedJaggedTensor( + keys=self._sharding_features[i], + values=torch.cat([jt.values() for jt in output.values()]), + lengths=torch.cat([jt.lengths() for jt in output.values()]), + ) + feature_splits.append(shard_kjt) + else: + feature_splits = features.split(self._sharding_feature_splits) + + if self._use_index_dedup: + feature_splits = self._dedup_indices(ctx, feature_splits) + + awaitables = [] + for feature_split, input_dist in zip(feature_splits, self._input_dists): + awaitables.append(input_dist(feature_split)) + ctx.sharding_contexts.append( + SequenceShardingContext( + features_before_input_dist=features, + unbucketize_permute_tensor=( + input_dist.unbucketize_permute_tensor + if isinstance(input_dist, RwSparseFeaturesDist) + else None + ), + ) + ) + + return KJTListSplitsAwaitable(awaitables, ctx) + + def _kjt_list_to_tensor_list( + self, + kjt_list: KJTList, + ) -> List[torch.Tensor]: + remapped_ids_ret: List[torch.Tensor] = [] + # TODO: find a better solution, could be padding + for kjt, tables, splits in zip( + kjt_list, self._sharding_tables, self._sharding_per_table_feature_splits + ): + if len(splits) > 1: + feature_splits = kjt.split(splits) + vals: List[torch.Tensor] = [] + # assert len(feature_splits) == len(sharding.embedding_tables()) + for feature_split, table in zip(feature_splits, tables): + offset = self._table_to_offset[table] + vals.append(feature_split.values() + offset) + remapped_ids_ret.append(torch.cat(vals).view(-1, 1)) + else: + remapped_ids_ret.append(kjt.values() + self._table_to_offset[tables[0]]) + return remapped_ids_ret + + def global_to_local_index( + self, + jt_dict: Dict[str, JaggedTensor], + ) -> Dict[str, JaggedTensor]: + for table, jt in jt_dict.items(): + jt._values = jt.values() - self._table_to_offset[table] + return jt_dict + + def compute( + self, + ctx: ManagedCollisionCollectionContext, + dist_input: KJTList, + ) -> KJTList: + remapped_kjts: List[KeyedJaggedTensor] = [] + + # per shard + for features, sharding_ctx, tables, splits, fns in zip( + dist_input, + ctx.sharding_contexts, + self._sharding_tables, + self._sharding_per_table_feature_splits, + self._sharding_features, + ): + assert isinstance(sharding_ctx, SequenceShardingContext) + sharding_ctx.lengths_after_input_dist = features.lengths().view( + -1, features.stride() + ) + + values: torch.Tensor + if len(splits) > 1: + # features per shard split by tables + feature_splits = features.split(splits) + output: Dict[str, JaggedTensor] = {} + for table, kjt in zip(tables, feature_splits): + # TODO: Dict[str, Tensor] + mc_input: Dict[str, JaggedTensor] = { + table: JaggedTensor( + values=kjt.values(), + lengths=kjt.lengths(), + # TODO: improve this temp solution by passing real weights + weights=torch.tensor(kjt.length_per_key()), + ) + } + mcm = self._managed_collision_modules[table] + mc_input = mcm.profile(mc_input) + mc_input = 
mcm.remap(mc_input) + mc_input = self.global_to_local_index(mc_input) + output.update(mc_input) + values = torch.cat([jt.values() for jt in output.values()]) + else: + table: str = tables[0] + mc_input: Dict[str, JaggedTensor] = { + table: JaggedTensor( + values=features.values(), + lengths=features.lengths(), + # TODO: improve this temp solution by passing real weights + weights=torch.tensor(features.length_per_key()), + ) + } + mcm = self._managed_collision_modules[table] + mc_input = mcm.profile(mc_input) + mc_input = mcm.remap(mc_input) + mc_input = self.global_to_local_index(mc_input) + values = mc_input[table].values() + + remapped_kjts.append( + KeyedJaggedTensor( + keys=fns, + values=values, + lengths=features.lengths(), + # original weights instead of features splits + weights=features.weights_or_none(), + ) + ) + return KJTList(remapped_kjts) + + def evict(self) -> Dict[str, Optional[torch.Tensor]]: + evictions: Dict[str, Optional[torch.Tensor]] = {} + for ( + table, + managed_collision_module, + ) in self._managed_collision_modules.items(): + global_indices_to_evict = managed_collision_module.evict() + local_indices_to_evict = None + if global_indices_to_evict is not None: + local_indices_to_evict = ( + global_indices_to_evict - self._table_to_offset[table] + ) + evictions[table] = local_indices_to_evict + return evictions + + def open_slots(self) -> Dict[str, torch.Tensor]: + open_slots: Dict[str, torch.Tensor] = {} + for ( + table, + managed_collision_module, + ) in self._managed_collision_modules.items(): + open_slots[table] = managed_collision_module.open_slots() + return open_slots + + def output_dist( + self, + ctx: ManagedCollisionCollectionContext, + output: KJTList, + ) -> LazyAwaitable[KeyedJaggedTensor]: + global_remapped = self._kjt_list_to_tensor_list(output) + awaitables_per_sharding: List[Awaitable[torch.Tensor]] = [] + features_before_all2all_per_sharding: List[KeyedJaggedTensor] = [] + + for odist, remapped_ids, sharding_ctx in zip( + self._output_dists, + global_remapped, + ctx.sharding_contexts, + ): + awaitables_per_sharding.append(odist(remapped_ids, sharding_ctx)) + features_before_all2all_per_sharding.append( + # pyre-fixme[6]: For 1st argument expected `KeyedJaggedTensor` but + # got `Optional[KeyedJaggedTensor]`. 
+ sharding_ctx.features_before_input_dist + ) + return ManagedCollisionCollectionAwaitable( + awaitables_per_sharding=awaitables_per_sharding, + features_per_sharding=features_before_all2all_per_sharding, + embedding_names_per_sharding=self._embedding_names_per_sharding, + need_indices=False, + features_to_permute_indices=None, + ) + + def create_context(self) -> ManagedCollisionCollectionContext: + return ManagedCollisionCollectionContext(sharding_contexts=[]) + + def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: + for name, module in self._managed_collision_modules.items(): + module_prefix = append_prefix(prefix, name) + for name, _ in module.named_buffers(): + if name in [ + "_output_segments_tensor", + "_current_iter_tensor", + "_scalar_logger._scalar_logger_steps", + ]: + continue + if name in module._non_persistent_buffers_set: + continue + yield append_prefix(module_prefix, name) + for name, _ in module.named_parameters(): + yield append_prefix(module_prefix, name) + + @property + def unsharded_module_type(self) -> Type[ManagedCollisionCollection]: + return ManagedCollisionCollection + + +class ManagedCollisionCollectionSharder( + BaseEmbeddingSharder[ManagedCollisionCollection] +): + def __init__( + self, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + + def shard( + self, + module: ManagedCollisionCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + embedding_shardings: List[ + EmbeddingSharding[ + EmbeddingShardingContext, + KeyedJaggedTensor, + torch.Tensor, + torch.Tensor, + ] + ], + device: Optional[torch.device] = None, + use_index_dedup: bool = False, + ) -> ShardedManagedCollisionCollection: + + if device is None: + device = torch.device("cpu") + + return ShardedManagedCollisionCollection( + module, + params, + env=env, + device=device, + embedding_shardings=embedding_shardings, + use_index_dedup=use_index_dedup, + ) + + def shardable_parameters( + self, module: ManagedCollisionCollection + ) -> Dict[str, torch.nn.Parameter]: + # TODO: standalone sharding + raise NotImplementedError() + + @property + def module_type(self) -> Type[ManagedCollisionCollection]: + return ManagedCollisionCollection + + def sharding_types(self, compute_device_type: str) -> List[str]: + types = [ + ShardingType.ROW_WISE.value, + ] + return types + + +@torch.fx.wrap +def _cat_jagged_values(jd: Dict[str, JaggedTensor]) -> torch.Tensor: + return torch.cat([jt.values() for jt in jd.values()]) + + +@torch.fx.wrap +def update_jagged_tensor_dict( + output: Dict[str, JaggedTensor], new_dict: Dict[str, JaggedTensor] +) -> Dict[str, JaggedTensor]: + output.update(new_dict) + return output + + +class ShardedMCCRemapper(nn.Module): + def __init__( + self, + table_feature_splits: List[int], + fns: List[str], + managed_collision_modules: nn.ModuleDict, + shard_metadata: Dict[str, List[int]], + ) -> None: + super().__init__() + self._table_feature_splits: List[int] = table_feature_splits + self._fns: List[str] = fns + self.zchs = managed_collision_modules + logger.info(f"registered zchs: {self.zchs=}") + + # shard_size, shard_offset + self._shard_metadata: Dict[str, List[int]] = shard_metadata + self._table_to_offset: Dict[str, int] = { + table: offset[0] for table, offset in shard_metadata.items() + } + + def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: + # features per shard split by tables + feature_splits = 
features.split(self._table_feature_splits) + output: Dict[str, JaggedTensor] = {} + for i, (table, mc_module) in enumerate(self.zchs.items()): + kjt: KeyedJaggedTensor = feature_splits[i] + mc_input: Dict[str, JaggedTensor] = { + table: JaggedTensor( + values=kjt.values(), + lengths=kjt.lengths(), + weights=_get_length_per_key(kjt), + ) + } + remapped_input = mc_module(mc_input) + mc_input = self.global_to_local_index(remapped_input) + output[table] = remapped_input[table] + + values: torch.Tensor = _cat_jagged_values(output) + return KeyedJaggedTensor( + keys=self._fns, + values=values, + lengths=features.lengths(), + # original weights instead of features splits + weights=features.weights_or_none(), + ) + + def global_to_local_index( + self, + jt_dict: Dict[str, JaggedTensor], + ) -> Dict[str, JaggedTensor]: + return _fx_global_to_local_index(jt_dict, self._table_to_offset) + + +class ShardedQuantManagedCollisionCollection( + ShardedModule[ + KJTList, + KJTList, + KeyedJaggedTensor, + ManagedCollisionCollectionContext, + ] +): + def __init__( + self, + module: ManagedCollisionCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + env: Union[ShardingEnv, Dict[str, ShardingEnv]], + device: torch.device, + embedding_shardings: List[ + EmbeddingSharding[ + EmbeddingShardingContext, + KeyedJaggedTensor, + torch.Tensor, + torch.Tensor, + ] + ], + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__() + self._env: ShardingEnv = ( + env + if not isinstance(env, Dict) + else embedding_shardings[0]._env # pyre-ignore[16] + ) + self._device = device + self.need_preprocess: bool = module.need_preprocess + self._table_name_to_parameter_sharding: Dict[str, ParameterSharding] = ( + copy.deepcopy(table_name_to_parameter_sharding) + ) + # TODO: create a MCSharding type instead of leveraging EmbeddingSharding + self._embedding_shardings = embedding_shardings + + self._embedding_names_per_sharding: List[List[str]] = [] + for sharding in self._embedding_shardings: + # TODO: support TWRW sharding + assert isinstance( + sharding, BaseRwEmbeddingSharding + ), "Only ROW_WISE sharding is supported." + self._embedding_names_per_sharding.append(sharding.embedding_names()) + + self._feature_to_table: Dict[str, str] = module._feature_to_table + self._table_to_features: Dict[str, List[str]] = module._table_to_features + self._has_uninitialized_input_dists: bool = True + self._input_dists: torch.nn.ModuleList = torch.nn.ModuleList([]) + self._managed_collision_modules: nn.ModuleDict = nn.ModuleDict() + self._create_managed_collision_modules(module) + self._features_order: List[int] = [] + + def _create_managed_collision_modules( + self, module: ManagedCollisionCollection + ) -> None: + + self._managed_collision_modules_per_rank: List[torch.nn.ModuleDict] = [ + torch.nn.ModuleDict() for _ in range(self._env.world_size) + ] + self._shard_metadata_per_rank: List[Dict[str, List[int]]] = [ + defaultdict() for _ in range(self._env.world_size) + ] + self._mc_module_name_shard_metadata: DefaultDict[str, List[int]] = defaultdict() + # To map mch output indices from local to global. key: table_name + self._table_to_offset: Dict[str, int] = {} + + # the split sizes of tables belonging to each sharding. outer len is # shardings + self._sharding_per_table_feature_splits: List[List[int]] = [] + self._input_size_per_table_feature_splits: List[List[int]] = [] + # the split sizes of features per sharding. 
len is # shardings + self._sharding_feature_splits: List[int] = [] + # the split sizes of features per table. len is # tables sum over all shardings + self._table_feature_splits: List[int] = [] + self._feature_names: List[str] = [] + + # table names of each sharding + self._sharding_tables: List[List[str]] = [] + self._sharding_features: List[List[str]] = [] + + logger.info(f"_create_managed_collision_modules {self._embedding_shardings=}") + + for sharding in self._embedding_shardings: + assert isinstance(sharding, BaseRwEmbeddingSharding) + self._sharding_tables.append([]) + self._sharding_features.append([]) + self._sharding_per_table_feature_splits.append([]) + self._input_size_per_table_feature_splits.append([]) + + grouped_embedding_configs: List[GroupedEmbeddingConfig] = ( + sharding._grouped_embedding_configs + ) + self._sharding_feature_splits.append(len(sharding.feature_names())) + + num_sharding_features = 0 + for group_config in grouped_embedding_configs: + for table in group_config.embedding_tables: + # pyre-ignore + global_meta_data = table.global_metadata.shards_metadata + output_segments = [ + x.shard_offsets[0] + for x in table.global_metadata.shards_metadata + ] + [table.num_embeddings] + mc_module = module._managed_collision_modules[table.name] + mc_module._is_inference = True + self._managed_collision_modules[table.name] = mc_module + self._sharding_tables[-1].append(table.name) + self._sharding_features[-1].extend(table.feature_names) + self._feature_names.extend(table.feature_names) + logger.info( + f"global_meta_data for table {table} is {global_meta_data}" + ) + + for i in range(self._env.world_size): + new_min_output_id = global_meta_data[i].shard_offsets[0] + new_range_size = global_meta_data[i].shard_sizes[0] + self._managed_collision_modules_per_rank[i][table.name] = ( + mc_module.rebuild_with_output_id_range( + output_id_range=( + new_min_output_id, + new_min_output_id + new_range_size, + ), + output_segments=output_segments, + device=( + torch.device("cpu") + if self._device.type == "cpu" + else torch.device(f"{self._device.type}:{i}") + ), + ) + ) + + self._managed_collision_modules_per_rank[i][ + table.name + ].training = False + self._shard_metadata_per_rank[i][table.name] = [ + new_min_output_id, + new_range_size, + ] + + input_size = self._managed_collision_modules[ + table.name + ].input_size() + + self._table_feature_splits.append(len(table.feature_names)) + self._sharding_per_table_feature_splits[-1].append( + self._table_feature_splits[-1] + ) + self._input_size_per_table_feature_splits[-1].append( + input_size, + ) + num_sharding_features += self._table_feature_splits[-1] + + assert num_sharding_features == len( + sharding.feature_names() + ), f"Shared feature is not supported. 
{num_sharding_features=}, {self._sharding_per_table_feature_splits[-1]=}" + + if self._sharding_features[-1] != sharding.feature_names(): + logger.warn( + "The order of tables of this sharding is altered due to grouping: " + f"{self._sharding_features[-1]=} vs {sharding.feature_names()=}" + ) + + logger.info(f"{self._table_feature_splits=}") + logger.info(f"{self._sharding_per_table_feature_splits=}") + logger.info(f"{self._input_size_per_table_feature_splits=}") + logger.info(f"{self._feature_names=}") + # logger.info(f"{self._table_to_offset=}") + logger.info(f"{self._sharding_tables=}") + logger.info(f"{self._sharding_features=}") + logger.info(f"{self._managed_collision_modules_per_rank=}") + logger.info(f"{self._shard_metadata_per_rank=}") + + def _create_input_dists( + self, + input_feature_names: List[str], + feature_device: Optional[torch.device] = None, + ) -> None: + feature_names: List[str] = [] + for sharding in self._embedding_shardings: + assert isinstance(sharding, BaseRwEmbeddingSharding) + + emb_sharding = [] + sharding_features = [] + for embedding_table_group in sharding._grouped_embedding_configs_per_rank[ + 0 + ]: + for table in embedding_table_group.embedding_tables: + shard_split_offsets = [ + shard.shard_offsets[0] + # pyre-fixme[16]: `Optional` has no attribute `shards_metadata`. + for shard in table.global_metadata.shards_metadata + ] + # pyre-fixme[16]: Optional has no attribute size. + shard_split_offsets.append(table.global_metadata.size[0]) + emb_sharding.extend( + [shard_split_offsets] * len(table.embedding_names) + ) + sharding_features.extend(table.feature_names) + + feature_num_buckets: List[int] = [ + self._managed_collision_modules[self._feature_to_table[f]].buckets() + for f in sharding_features + ] + + input_sizes: List[int] = [ + self._managed_collision_modules[self._feature_to_table[f]].input_size() + for f in sharding_features + ] + + feature_hash_sizes: List[int] = [] + feature_total_num_buckets: List[int] = [] + for input_size, num_buckets in zip( + input_sizes, + feature_num_buckets, + ): + feature_hash_sizes.append(input_size) + feature_total_num_buckets.append(num_buckets) + + input_dist = InferRwSparseFeaturesDist( + world_size=sharding._world_size, + num_features=sharding._get_num_features(), + feature_hash_sizes=feature_hash_sizes, + feature_total_num_buckets=feature_total_num_buckets, + device=self._device, + is_sequence=True, + has_feature_processor=sharding._has_feature_processor, + need_pos=False, + embedding_shard_metadata=emb_sharding, + keep_original_indices=True, + ) + self._input_dists.append(input_dist) + + feature_names.extend(sharding_features) + + for f in feature_names: + self._features_order.append(input_feature_names.index(f)) + self._features_order = ( + [] + if self._features_order == list(range(len(input_feature_names))) + else self._features_order + ) + self.register_buffer( + "_features_order_tensor", + torch.tensor( + self._features_order, device=feature_device, dtype=torch.int32 + ), + persistent=False, + ) + + # pyre-ignore + def input_dist( + self, + ctx: ManagedCollisionCollectionContext, + features: KeyedJaggedTensor, + ) -> ListOfKJTList: + if self._has_uninitialized_input_dists: + self._create_input_dists( + input_feature_names=features.keys(), feature_device=features.device() + ) + self._has_uninitialized_input_dists = False + + with torch.no_grad(): + if self._features_order: + features = features.permute( + self._features_order, + self._features_order_tensor, # pyre-ignore + ) + + feature_splits: 
List[KeyedJaggedTensor] = [] + if self.need_preprocess: + # NOTE: No shared features allowed! + assert ( + len(self._sharding_feature_splits) == 1 + ), "Preprocing only support single sharding type (row-wise)" + table_splits = features.split(self._table_feature_splits) + ti: int = 0 + for i, tables in enumerate(self._sharding_tables): + output: Dict[str, JaggedTensor] = {} + for table in tables: + kjt: KeyedJaggedTensor = table_splits[ti] + mc_module = self._managed_collision_modules[table] + # TODO: change to Dict[str, Tensor] + mc_input: Dict[str, JaggedTensor] = { + table: JaggedTensor( + values=kjt.values(), + lengths=kjt.lengths(), + ) + } + mc_input = mc_module.preprocess(mc_input) + output.update(mc_input) + ti += 1 + shard_kjt = KeyedJaggedTensor( + keys=self._sharding_features[i], + values=torch.cat([jt.values() for jt in output.values()]), + lengths=torch.cat([jt.lengths() for jt in output.values()]), + ) + feature_splits.append(shard_kjt) + else: + feature_splits = features.split(self._sharding_feature_splits) + + input_dist_result_list = [] + for feature_split, input_dist in zip(feature_splits, self._input_dists): + out = input_dist(feature_split) + input_dist_result_list.append(out.features) + ctx.sharding_contexts.append( + InferSequenceShardingContext( + features=out.features, + features_before_input_dist=features, + unbucketize_permute_tensor=( + out.unbucketize_permute_tensor + if isinstance(input_dist, InferRwSparseFeaturesDist) + else None + ), + bucket_mapping_tensor=out.bucket_mapping_tensor, + bucketized_length=out.bucketized_length, + ) + ) + + return ListOfKJTList(input_dist_result_list) + + def create_mcc_remappers(self) -> List[List[ShardedMCCRemapper]]: + ret: List[List[ShardedMCCRemapper]] = [] + # per shard + for table_feature_splits, fns in zip( + self._sharding_per_table_feature_splits, + self._sharding_features, + ): + sharding_ret: List[ShardedMCCRemapper] = [] + for i, mcms in enumerate(self._managed_collision_modules_per_rank): + sharding_ret.append( + ShardedMCCRemapper( + table_feature_splits=table_feature_splits, + fns=fns, + managed_collision_modules=mcms, + shard_metadata=self._shard_metadata_per_rank[i], + ) + ) + ret.append(sharding_ret) + return ret + + def compute( + self, + ctx: ManagedCollisionCollectionContext, + rank: int, + dist_input: KJTList, + ) -> KJTList: + raise NotImplementedError() + + # pyre-ignore + def output_dist( + self, + ctx: ManagedCollisionCollectionContext, + output: KJTList, + ) -> KeyedJaggedTensor: + raise NotImplementedError() + + def create_context(self) -> ManagedCollisionCollectionContext: + return ManagedCollisionCollectionContext(sharding_contexts=[]) + + @property + def unsharded_module_type(self) -> Type[ManagedCollisionCollection]: + return ManagedCollisionCollection + + +class InferManagedCollisionCollectionSharder(ManagedCollisionCollectionSharder): + # pyre-ignore + def shard( + self, + module: ManagedCollisionCollection, + params: Dict[str, ParameterSharding], + env: Union[ShardingEnv, Dict[str, ShardingEnv]], + embedding_shardings: List[ + EmbeddingSharding[ + EmbeddingShardingContext, + KeyedJaggedTensor, + torch.Tensor, + torch.Tensor, + ] + ], + device: Optional[torch.device] = None, + ) -> ShardedQuantManagedCollisionCollection: + + if device is None: + device = torch.device("cpu") + + return ShardedQuantManagedCollisionCollection( + module, + params, + env=env, + device=device, + embedding_shardings=embedding_shardings, + ) diff --git a/torchrec/distributed/model_parallel.py 
b/torchrec/distributed/model_parallel.py index 359a460e5..23ce0165c 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -5,28 +5,39 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import abc -from collections import OrderedDict -from typing import Any, cast, Dict, Iterator, List, Optional, Tuple +import copy +import logging as logger +from collections import defaultdict, OrderedDict +from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Set, Tuple, Type import torch import torch.distributed as dist +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + SplitTableBatchedEmbeddingBagsCodegen, +) from torch import nn +from torch.autograd.profiler import record_function +from torch.distributed.algorithms.ddp_comm_hooks import ( + default_hooks as ddp_default_hooks, +) from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.remote_device import _remote_device +from torch.distributed.tensor import DeviceMesh from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.comm import get_local_size -from torchrec.distributed.planner import ( - EmbeddingShardingPlanner, - sharder_name, - Topology, -) +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology from torchrec.distributed.sharding_plan import get_default_sharders from torchrec.distributed.types import ( + EnumerableShardingSpec, ModuleSharder, ShardedModule, ShardingEnv, + ShardingEnv2D, ShardingPlan, ) from torchrec.distributed.utils import ( @@ -34,10 +45,17 @@ append_prefix, copy_to_device, filter_state_dict, + sharded_model_copy, ) from torchrec.optim.fused import FusedOptimizerModule from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + _DDP_STATE_DICT_PREFIX = "module." @@ -62,34 +80,38 @@ class DefaultDataParallelWrapper(DataParallelWrapper): Default data parallel wrapper, which applies data parallel to all unsharded modules. 
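+
+ Args:
+ bucket_cap_mb (int): DDP gradient bucket size in MiB.
+ static_graph (bool): passed through to DistributedDataParallel.
+ find_unused_parameters (bool): passed through to DistributedDataParallel.
+ allreduce_comm_precision (Optional[str]): if "fp16" or "bf16", registers the
+ corresponding gradient compression comm hook on the wrapped DDP module.
+ params_to_ignore (Optional[List[str]]): extra parameter names to exclude from
+ DDP, in addition to the sharded parameter names.
+ ddp_kwargs (Optional[Dict[str, Any]]): additional keyword arguments forwarded
+ to DistributedDataParallel.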
""" - def wrap( + def __init__( + self, + bucket_cap_mb: int = 25, + static_graph: bool = True, + find_unused_parameters: bool = False, + allreduce_comm_precision: Optional[str] = None, + params_to_ignore: Optional[List[str]] = None, + ddp_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + self._bucket_cap_mb: int = bucket_cap_mb + self._static_graph: bool = static_graph + self._find_unused_parameters: bool = find_unused_parameters + self._allreduce_comm_precision = allreduce_comm_precision + self._additional_params_to_ignore: Set[str] = set(params_to_ignore or []) + self._ddp_kwargs: Dict[str, Any] = ddp_kwargs or {} + + def _ddp_wrap( self, dmp: "DistributedModelParallel", env: ShardingEnv, device: torch.device, + ddp_ignore_param_names: Set[str], ) -> None: - if isinstance(dmp._dmp_wrapped_module, DistributedDataParallel) or isinstance( - dmp._dmp_wrapped_module, FullyShardedDataParallel - ): - return pg = env.process_group if pg is None: raise RuntimeError("Can only init DDP for ProcessGroup-based ShardingEnv") - sharded_parameter_names = { - key - for key in DistributedModelParallel._sharded_parameter_names( - dmp._dmp_wrapped_module - ) - } - all_paramemeter_names = {key for key, _ in dmp.named_parameters()} - if sharded_parameter_names == all_paramemeter_names: + all_parameter_names = set(dict(dmp.named_parameters()).keys()) + if len(all_parameter_names - ddp_ignore_param_names) == 0: return - DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( module=dmp._dmp_wrapped_module, - params_and_buffers_to_ignore=[ - key for key in all_paramemeter_names if key in sharded_parameter_names - ], + params_and_buffers_to_ignore=ddp_ignore_param_names, ) # initialize DDP dmp._dmp_wrapped_module = cast( @@ -100,9 +122,40 @@ def wrap( process_group=pg, gradient_as_bucket_view=True, broadcast_buffers=False, - static_graph=True, + static_graph=self._static_graph, + find_unused_parameters=self._find_unused_parameters, + bucket_cap_mb=self._bucket_cap_mb, + **self._ddp_kwargs, ), ) + if self._allreduce_comm_precision == "fp16": + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + dmp._dmp_wrapped_module.register_comm_hook( + None, ddp_default_hooks.fp16_compress_hook + ) + elif self._allreduce_comm_precision == "bf16": + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. 
+ dmp._dmp_wrapped_module.register_comm_hook( + None, ddp_default_hooks.bf16_compress_hook + ) + + def wrap( + self, + dmp: "DistributedModelParallel", + env: ShardingEnv, + device: torch.device, + ) -> None: + if isinstance(dmp._dmp_wrapped_module, DistributedDataParallel) or isinstance( + dmp._dmp_wrapped_module, FullyShardedDataParallel + ): + return + sharded_parameter_names = set( + DistributedModelParallel._sharded_parameter_names(dmp._dmp_wrapped_module) + ) + params_to_ignore = sharded_parameter_names.union( + self._additional_params_to_ignore + ) + self._ddp_wrap(dmp, env, device, params_to_ignore) def get_unwrapped_module(module: nn.Module) -> nn.Module: @@ -185,6 +238,7 @@ def __init__( torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}") self.init_parameters = init_parameters + self._ddp_wrapped: bool = False if env is None: @@ -199,8 +253,9 @@ def __init__( if sharders is None: sharders = get_default_sharders() - self._sharder_map: Dict[str, ModuleSharder[nn.Module]] = { - sharder_name(sharder.module_type): sharder for sharder in sharders + + self._sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]] = { + sharder.module_type: sharder for sharder in sharders } if data_parallel_wrapper is None: @@ -221,7 +276,6 @@ def __init__( else: plan = planner.plan(module, sharders) self._plan: ShardingPlan = plan - self._dmp_wrapped_module: nn.Module = self._init_dmp(module) self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module) @@ -276,8 +330,15 @@ def copy( `ShardedModule` for inference). """ assert isinstance(device, torch.device) - copy_dmp = copy_to_device(self, self.device, device) - return cast(DistributedModelParallel, copy_dmp) + # dmp code deep copy + with sharded_model_copy(device=None): + copy_dmp = copy.deepcopy(self) + # tensor resident module deep copy + copy_dmp_wrapped_module = copy_to_device( + self._dmp_wrapped_module, self.device, device + ) + copy_dmp._dmp_wrapped_module = copy_dmp_wrapped_module + return copy_dmp def _init_dmp(self, module: nn.Module) -> nn.Module: return self._shard_modules_impl(module) @@ -317,12 +378,13 @@ def _shard_modules_impl( # shardable module module_sharding_plan = self._plan.get_plan_for_module(path) if module_sharding_plan: - sharder_key = sharder_name(type(module)) + sharder_key = type(module) module = self._sharder_map[sharder_key].shard( module, module_sharding_plan, self._env, self.device, + path, ) return module @@ -349,11 +411,11 @@ def init_parameters(module: nn.Module) -> None: has_meta_param = True for name, buffer in module._buffers.items(): if isinstance(buffer, torch.Tensor) and buffer.device.type == "meta": - module._buffers[name] = torch.empty_like(buffer, device=self.device) + module._buffers[name] = torch.zeros_like(buffer, device=self.device) # Init parameters if at least one parameter is over 'meta' device. if has_meta_param and hasattr(module, "reset_parameters"): - # pyre-ignore [29] + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. 
module.reset_parameters() module.apply(init_parameters) @@ -424,7 +486,7 @@ def _load_state_dict( state_dict, prefix ) add_prefix_to_state_dict(state_dict, prefix + _DDP_STATE_DICT_PREFIX) - if isinstance(module, ShardedModule): + if getattr(module, "_FORCE_STATE_DICT_LOAD", False): return module.load_state_dict(state_dict, strict=strict) else: module._load_from_state_dict( @@ -458,13 +520,23 @@ def _named_parameters( yield from module.named_parameters(prefix, recurse=False) for name, child in module.named_children(): yield from self._named_parameters( - child, append_prefix(prefix, name), recurse, strip_ddp + child, + append_prefix(prefix, name), + recurse, + strip_ddp, ) def named_parameters( - self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + self, + prefix: str = "", + recurse: bool = True, + remove_duplicate: bool = True, ) -> Iterator[Tuple[str, torch.nn.Parameter]]: - gen = self._named_parameters(self.module, prefix, recurse) + gen = self._named_parameters( + self.module, + prefix, + recurse, + ) memo = set() for key, param in gen: if param in memo: @@ -474,9 +546,15 @@ def named_parameters( yield key, param def bare_named_parameters( - self, prefix: str = "", recurse: bool = True + self, + prefix: str = "", + recurse: bool = True, ) -> Iterator[Tuple[str, torch.nn.Parameter]]: - gen = self._named_parameters(self.module, prefix, recurse, False) + gen = self._named_parameters( + self.module, + prefix, + recurse, + ) memo = set() for key, param in gen: if param in memo: @@ -521,7 +599,7 @@ def named_buffers( yield key, param @property - def fused_optimizer(self) -> KeyedOptimizer: + def fused_optimizer(self) -> CombinedOptimizer: return self._optim @property @@ -533,3 +611,349 @@ def _reset_parameters(module: nn.Module) -> None: for _, m in module.named_modules(): if hasattr(m, "reset_parameters"): m.reset_parameters() + + +class DMPCollection(DistributedModelParallel): + """ + A wrapper around DistributedModelParallel that allows for multiple DMPs to be created and managed together. + + This class implements a 2D parallelism model where a DMP is sharded over a subset of ranks. + The current implementation shards the model such that, for a given shard, its replicated shards lie on the ranks within the node. + This significantly improves the performance of the all-reduce communication (parameter sync) by utilizing intra-node bandwidth. + + Example Use Case: + Consider a setup with 2 nodes, each with 4 GPUs. The sharding groups could be: + - Group 0, DMP 0: [0, 2, 4, 6] + - Group 1, DMP 1: [1, 3, 5, 7] + + Each group receives an identical sharding plan for their local world size and ranks. + If we have one table sharded in each DMP, with one shard on each rank in the group, + each shard in DMP0 will have a duplicate shard on its corresponding rank in DMP1. + The replication groups would be: [0, 1], [2, 3], [4, 5], [6, 7]. + + Notes: + - DTensor must be used for state dict for checkpointing to work correctly. + - The expected sharding plan should be sharded across sharding_group_size (sharding group world size) + and broadcasted to all ranks (`planner.collective_plan(..)`). + + Args: + module (nn.Module): The module to be sharded. + device (torch.device): The device to use for the sharded module. + plan (ShardingPlan): The sharding plan to use, created for sharding group world size. + sharding_group_size (int): The number of GPUs to model parallel shard the embedding tables over + world_size (int): The total number of GPUs. 
+ global_pg (dist.ProcessGroup): The global process group. + node_group_size (Optional[int]): Specify a logical group size for a node for TWRW/GRID sharding schemes + sharders (Optional[List[ModuleSharder[torch.nn.Module]]]): The sharders to use. + init_data_parallel (bool): Whether to initialize data parallelism. + init_parameters (bool): Whether to initialize parameters. + data_parallel_wrapper (Optional[DataParallelWrapper]): The data parallel wrapper to use. + + Example:: + + @torch.no_grad() + def init_weights(m): + if isinstance(m, nn.Linear): + m.weight.fill_(1.0) + elif isinstance(m, EmbeddingBagCollection): + for param in m.parameters(): + init.kaiming_normal_(param) + + m = MyModel(device='meta') + planner = EmbeddingShardingPlanner( + topology=Topology( + world_size=global_world_size, + local_world_size=sharding_group_size, + ), + constraints=constraints, + ) + plan = planner.collective_plan(m, sharders, global_pg) + m = DMPCollection( + module=m, + sharding_group_size=sharding_group_size, + world_size=global_world_size, + global_pg=global_pg, + plan=plan, + ) + m.apply(init_weights) + """ + + def __init__( + self, + module: nn.Module, + device: torch.device, + plan: ShardingPlan, + world_size: int, + sharding_group_size: int, + global_pg: dist.ProcessGroup, + node_group_size: Optional[int] = None, + sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None, + init_data_parallel: bool = True, + init_parameters: bool = True, + data_parallel_wrapper: Optional[DataParallelWrapper] = None, + use_inter_host_allreduce: bool = False, + custom_all_reduce: Optional[Callable[[List[torch.Tensor]], None]] = None, + ) -> None: + assert device.type == "cuda", "DMPCollection only supports CUDA" + self._device = device + self._pg: dist.ProcessGroup = global_pg + self._plan: ShardingPlan = plan + self._device_mesh: DeviceMesh = None # pyre-ignore[8] + self._sharding_pg: dist.ProcessGroup = None # pyre-ignore[8] + self._replica_pg: dist.ProcessGroup = None # pyre-ignore[8] + self._global_rank: int = dist.get_rank(global_pg) + self._custom_all_reduce = custom_all_reduce + + self._device_mesh, self._sharding_pg, self._replica_pg = ( + self._create_process_groups( + global_rank=self._global_rank, + world_size=world_size, + local_size=sharding_group_size, + use_inter_host_allreduce=use_inter_host_allreduce, + ) + ) + + self._remap_sharding_plan( + plan=plan, + rank=self._global_rank, + step=world_size // sharding_group_size, + sharding_group_size=sharding_group_size, + use_inter_host_allreduce=use_inter_host_allreduce, + ) + super().__init__( + module, + ShardingEnv2D( + global_pg=self._pg, + sharding_pg=self._sharding_pg, + device_mesh=self._device_mesh, + node_group_size=node_group_size, + use_inter_host_allreduce=use_inter_host_allreduce, + ), + device, + plan, + sharders, + init_data_parallel, + init_parameters, + data_parallel_wrapper, + ) + # post DMP init, we group sharded modules for parameter sync + self._modules_to_sync: List[nn.Module] = self._group_sharded_modules() + + def sync(self, include_optimizer_state: bool = True) -> None: + """ + Syncs the DMP weights across the allreduce (inter) process group + + This method is called after each train step to synchronize the weights of the sharded modules. + It uses the `dist.AllreduceCoalescedOptions` to perform an all-reduce operation on the weights, + which averages the weights across all processes in the inter-process group. + + The default CUDA stream is used for the all-reduce operation, and the method does not return any value. 
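+
+ A typical call site (illustrative) is once per train step, after the optimizer
+ update, e.g. ``m.sync(include_optimizer_state=True)`` with the ``m`` constructed
+ in the class-level example.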
+ + Args: + include_optimizer_state (bool): Flag to include optimizer state syncing upon call + """ + assert self._replica_pg is not None, "replica_pg is not initialized!" + all_weights_by_dtype: dict[torch.dtype, List[torch.Tensor]] = defaultdict(list) + + for emb_kernel in self._modules_to_sync: + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + for w in emb_kernel.split_embedding_weights(): + all_weights_by_dtype[w.dtype].append(w) + + opts = None + if self._custom_all_reduce is None: + opts = dist.AllreduceCoalescedOptions() + opts.reduceOp = dist.ReduceOp.AVG + self._allreduce_tensors(all_weights_by_dtype, "## 2d_weight_sync ##", opts) + + if include_optimizer_state: + optimizer_tensors_by_dtype: Dict[torch.dtype, List[torch.Tensor]] = ( + defaultdict(list) + ) + for emb_kernel in self._modules_to_sync: + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + optimizer_states = emb_kernel.get_optimizer_state() + for state in optimizer_states: + opt_tensor = state["sum"] + optimizer_tensors_by_dtype[opt_tensor.dtype].append(opt_tensor) + if optimizer_tensors_by_dtype: + self._allreduce_tensors( + optimizer_tensors_by_dtype, "## 2d_optimizer_sync ##", opts + ) + + def _allreduce_tensors( + self, + tensors_dict: Dict[torch.dtype, List[torch.Tensor]], + annotation: str, + opts: Optional[dist.AllreduceCoalescedOptions] = None, + ) -> None: + """ + Helper to perform all reduce on given tensors, uses custom all reduce function if provided + We perform all reduce per tensor dtype per collective constraints. + """ + + custom_all_reduce = self._custom_all_reduce + if custom_all_reduce is not None: + + def _all_reduce(tensors: List[torch.Tensor]) -> None: + with record_function(f"{annotation}_custom_hook"): + custom_all_reduce(tensors) + + else: + + def _all_reduce(tensors: List[torch.Tensor]) -> None: + with record_function(annotation): + self._replica_pg.allreduce_coalesced(tensors, opts=opts).wait() + + for tensor_list in tensors_dict.values(): + _all_reduce(tensor_list) + + def set_all_reduce_hook( + self, + reduce_hook: Callable[[List[torch.Tensor]], None], + ) -> None: + """ + Replace default all reduce with custom callable. Users can alternatively + pass in the custom all reduce function through the constructor. The hook + expects the user to handle distributed communication call, associated + process group, and stream synchronization. + + Args: + reduce_hook (Callable[[List[torch.Tensor]], torch.Tensor]): The custom all reduce function to use for + embedding weights and optimizer states + """ + if self._custom_all_reduce is not None: + logger.warning( + "[TorchRec 2D Parallel] Custom all reduce function already defined, overriding with new callable" + ) + self._custom_all_reduce = reduce_hook + + def _create_process_groups( + self, + global_rank: int, + world_size: int, + local_size: int, + use_inter_host_allreduce: bool = False, + ) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: + """ + Creates process groups for sharding and replication, the process groups + are created using the DeviceMesh API. + + Args: + global_rank (int): The global rank of the current process. + world_size (int): The total number of ranks. + local_size (int): The number of ranks per sharding group. + + Returns: + Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: A tuple containing the device mesh, + replication process group, and allreduce process group. 
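As an illustration of the resulting rank layout, the peer matrix for `world_size=8` and `local_size=4` can be reproduced with plain Python (no process groups needed):

    world_size, local_size = 8, 4

    # default scheme: a shard's replicas sit on neighboring ranks of a node,
    # so each row (the "shard" dim) strides across hosts and the replica
    # all-reduce stays intra-node
    step = world_size // local_size
    default_mesh = [
        [step * r + group_rank for r in range(local_size)]
        for group_rank in range(step)
    ]
    assert default_mesh == [[0, 2, 4, 6], [1, 3, 5, 7]]

    # use_inter_host_allreduce=True: shard over contiguous ranks instead,
    # pushing the replica all-reduce onto inter-host links
    inter_host_mesh = [
        list(range(i, i + local_size)) for i in range(0, world_size, local_size)
    ]
    assert inter_host_mesh == [[0, 1, 2, 3], [4, 5, 6, 7]]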
+ """ + peer_matrix = [] + mesh, sharding_pg, replica_pg = None, None, None + + logger.warning(f"[2D] Use inter host all reduce: {use_inter_host_allreduce}") + + if use_inter_host_allreduce: + # We shard on continuous set of ranks and nodes. Thereby forcing our all reduce to be inter host. + # Under this scheme sharding types such as TWRW and GRID will now take + # advantage of intra node comms as a result of the continuous set of ranks. + peer_matrix = [ + list(range(i, i + local_size)) for i in range(0, world_size, local_size) + ] + else: + step = world_size // local_size + for group_rank in range(world_size // local_size): + peers = [step * r + group_rank for r in range(local_size)] + peer_matrix.append(peers) + + mesh = DeviceMesh( + device_type=self._device.type, + mesh=peer_matrix, + mesh_dim_names=("replicate", "shard"), + ) + + logger.warning(f"[Connection] 2D Device Mesh created: {mesh}") + sharding_pg = mesh.get_group(mesh_dim="shard") + logger.warning( + f"[Connection] 2D sharding_group: [{global_rank}] -> [{mesh['shard']}]" + ) + replica_pg = mesh.get_group(mesh_dim="replicate") + logger.warning( + f"[Connection] 2D replica_group: [{global_rank}] -> [{mesh['replicate']}]" + ) + + return mesh, sharding_pg, replica_pg + + def _remap_sharding_plan( + self, + plan: ShardingPlan, + rank: int, + step: int, + sharding_group_size: int, + use_inter_host_allreduce: bool = False, + ) -> None: + """ + Remaps the sharding plan to the local replica process group ranks + ShardingPlan is remapped inplace. + + As an example, + ShardingPlan for created for ranks [0, 2, 4, 6] is remapped to ranks [1, 3, 5, 7] + + Args: + plan (ShardingPlan): The original sharding plan. + global_rank (int): The global rank of the current process. + num_nodes (int): The number of nodes. + """ + group_start = rank % step + for key in plan.plan: + # pyre-ignore[16] + for _, param_sharding in plan.plan[key].items(): + new_ranks = [] + if use_inter_host_allreduce: + group = rank // sharding_group_size + new_ranks = [ + shard_rank + (group * sharding_group_size) + for shard_rank in param_sharding.ranks + ] + else: + for shard_rank in param_sharding.ranks: + new_ranks.append(shard_rank * step + group_start) + param_sharding.ranks = new_ranks + + if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec): + shards = param_sharding.sharding_spec.shards + if shards is not None: + for shard in shards: + if use_inter_host_allreduce: + shard_rank = shard.placement._rank + ( + (rank // sharding_group_size) * sharding_group_size + ) + else: + shard_rank = shard.placement._rank * step + group_start + shard.placement = _remote_device( + f"rank:{shard_rank}/cuda:{shard_rank % get_local_size()}" + ) + return + + def _group_sharded_modules( + self, + ) -> List[nn.Module]: + # Post init DMP, save the embedding kernels + sharded_modules: List[nn.Module] = [] + + def _find_sharded_modules( + module: nn.Module, + ) -> None: + if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen): + sharded_modules.append(module) + if hasattr(module, "_lookups"): + # pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` is + # not a function. 
+ for lookup in module._lookups: + _find_sharded_modules(lookup) + return + for _, child in module.named_children(): + _find_sharded_modules(child) + + _find_sharded_modules(self._dmp_wrapped_module) + return sharded_modules diff --git a/torchrec/distributed/object_pool.py b/torchrec/distributed/object_pool.py new file mode 100644 index 000000000..62d5e932c --- /dev/null +++ b/torchrec/distributed/object_pool.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from abc import abstractmethod +from typing import Generic, Type + +import torch +from torch._prims_common import is_integer_dtype +from torchrec.distributed.types import ( + Awaitable, + DistOut, + LazyAwaitable, + Out, + ShardedModule, + ShrdCtx, +) +from torchrec.modules.object_pool import ObjectPool + + +class ShardedObjectPool( + Generic[Out, DistOut, ShrdCtx], + ObjectPool[Out], + ShardedModule[torch.Tensor, DistOut, Out, ShrdCtx], +): + """ + An abstract distributed K-V store supports update and lookup on torch.Tensor and KeyedJaggedTensor. + + To use the update() function, users need to implement _update_preproc(), _ids_dist(), _update_local(), update_value_dist() + To use the lookup() function, users need to implement _ids_dist(), _lookup_local(), lookup_value_dist() + """ + + def __init__(self) -> None: + super().__init__() + + @abstractmethod + def _update_preproc(self, values: Out) -> Out: + """ + Sanity check and preproc input values + """ + ... + + @abstractmethod + def _update_ids_dist( + self, ctx: ShrdCtx, ids: torch.Tensor + ) -> Awaitable[Awaitable[torch.Tensor]]: ... + + @abstractmethod + def _update_local( + self, ctx: ShrdCtx, ids: torch.Tensor, values: DistOut + ) -> None: ... + + @abstractmethod + def _update_values_dist(self, ctx: ShrdCtx, values: Out) -> Awaitable[DistOut]: ... + + @abstractmethod + def _lookup_ids_dist( + self, ctx: ShrdCtx, ids: torch.Tensor + ) -> Awaitable[Awaitable[torch.Tensor]]: ... + + @abstractmethod + def _lookup_local(self, ctx: ShrdCtx, ids: torch.Tensor) -> DistOut: ... + + @abstractmethod + def _lookup_values_dist( + self, ctx: ShrdCtx, values: DistOut + ) -> LazyAwaitable[Out]: ... + + @abstractmethod + def create_context(self) -> ShrdCtx: + pass + + # pyre-ignore override *input/**kwargs + def forward(self, ids: torch.Tensor) -> LazyAwaitable[Out]: + """ + Perform distributed lookup on the pool using `ids` + + It comprises 3 stages: + + 1) IDs received at each rank must be distributed via all2all to the correct ranks. + 2) Each rank receives the correct IDs, and looks up the values locally + 3) Each rank distributes the values from local lookup to other ranks. Note that this step depends on IDs dist because we need to know the batch dimension of tensors to send to all other ranks. + + Refer to docstring for `ShardedTensorPool` and `ShardedKeyedJaggedTensorPool` for examples. 
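As a caller-side sketch (assuming a concrete sharded pool, e.g. a sharded tensor pool, has already been built; `sharded_pool` is an illustrative name, not a variable defined in this module):

    import torch

    ids = torch.tensor([1, 5, 7], dtype=torch.int64)  # ids must be an integer dtype
    lazy_values = sharded_pool.lookup(ids)            # equivalent to sharded_pool(ids)
    values = lazy_values.wait()                       # resolve the LazyAwaitable

    # update: the first dim of `values` must equal the number of ids
    sharded_pool.update(ids, values)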
+ """ + torch._assert(is_integer_dtype(ids.dtype), "ids type must be int") + + ctx = self.create_context() + id_dist = self._lookup_ids_dist(ctx, ids).wait().wait() + local_lookup = self._lookup_local(ctx, id_dist) + dist_values = self._lookup_values_dist(ctx, local_lookup) + return dist_values + + def lookup(self, ids: torch.Tensor) -> LazyAwaitable[Out]: + return self.forward(ids) + + def update(self, ids: torch.Tensor, values: Out) -> None: + """ + Perform distributed update on the pool mapping `ids` to `values` + + Args: + ids (torch.Tensor): 1D tensor containing ids to be updated + values (torch.Tensor): tensor where first dim must equal number of ids + + It comprises 4 stages: + + 1) Optional preproc stage for the values tensor received + 2) Distribute IDs to correct ranks + 3) Distribute value tensor/KJT to correct ranks + 4) Each rank will now have the IDs to update and the corresponding values tensor/KJT, and can update locally + + Refer to docstring for `ShardedTensorPool` and `ShardedKeyedJaggedTensorPool` for examples. + """ + torch._assert(is_integer_dtype(ids.dtype), "ids type must be int") + values = self._update_preproc(values=values) + ctx = self.create_context() + dist_ids = self._update_ids_dist(ctx=ctx, ids=ids).wait().wait() + dist_values = self._update_values_dist(ctx=ctx, values=values).wait() + self._update_local(ctx=ctx, ids=dist_ids, values=dist_values) + + # These below aren't used, instead we have lookup_ids_dist/lookup_local/lookup_values_dist, and corresponding update + def input_dist( + self, + ctx: ShrdCtx, + # pyre-ignore[2] + *input, + # pyre-ignore[2] + **kwargs, + # pyre-fixme[7]: Expected `Awaitable[Awaitable[Tensor]]` but got implicit return + # value of `None`. + ) -> Awaitable[Awaitable[torch.Tensor]]: + pass + + # pyre-fixme[7]: Expected `DistOut` but got implicit return value of `None`. + def compute(self, ctx: ShrdCtx, dist_input: torch.Tensor) -> DistOut: + pass + + # pyre-fixme[7]: Expected `LazyAwaitable[Out]` but got implicit return value of + # `None`. + def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]: + pass + + @property + def unsharded_module_type(self) -> Type[ObjectPool[Out]]: + return ObjectPool[Out] diff --git a/torchrec/distributed/planner/__init__.py b/torchrec/distributed/planner/__init__.py index 95ade7712..efd06bf02 100644 --- a/torchrec/distributed/planner/__init__.py +++ b/torchrec/distributed/planner/__init__.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + """Torchrec Planner The planner provides the specifications necessary for a module to be sharded, diff --git a/torchrec/distributed/planner/benchmark/sparsenn_planner_model.py b/torchrec/distributed/planner/benchmark/sparsenn_planner_model.py deleted file mode 100644 index 6d660adab..000000000 --- a/torchrec/distributed/planner/benchmark/sparsenn_planner_model.py +++ /dev/null @@ -1,156 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import argparse -import logging -from typing import cast, List - -import torch - -from torch import nn - -from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder - -from torchrec.distributed.planner.parallelized_planners import ( - ParallelizedEmbeddingShardingPlanner, -) -from torchrec.distributed.planner.planners import EmbeddingShardingPlanner - -from torchrec.distributed.planner.types import Topology -from torchrec.distributed.test_utils.test_model import TestSparseNN -from torchrec.distributed.types import ModuleSharder -from torchrec.modules.embedding_configs import EmbeddingBagConfig - -parser = argparse.ArgumentParser(description="custom model for running planner") - -parser.add_argument( - "-lws", - "--local_world_size", - type=int, - default=8, - help="local_world_size; local world size used in topolgy. Defaults to 8", - required=False, -) -parser.add_argument( - "-ws", - "--world_size", - type=int, - default=16, - help="world_size; number of ranks used in topology. Defaults to 16", - required=False, -) -parser.add_argument( - "-bs", - "--batch_size", - type=int, - default=32, - help="batch_size; batch_size used in topology. Defaults to 32", - required=False, -) -parser.add_argument( - "-hc", - "--hbm_cap", - type=int, - default=16777216, - help="hbm_cap; maximum storage used in topology. Defaults to 1024 * 1024 * 16", - required=False, -) -parser.add_argument( - "-cd", - "--compute_device", - type=str, - default="cuda", - help="compute_device; compute_device used in topology. Defaults to 'cuda'", - required=False, -) -parser.add_argument( - "-ne", - "--num_embeddings", - type=int, - default=100, - help="num_embeddings, number of embeddings used in creating tables. Defaults to 100", - required=False, -) -parser.add_argument( - "-ed", - "--embedding_dim", - type=int, - default=64, - help="embedding_dim: embedding dimension used in creating tables. Defaults to 64", - required=False, -) -parser.add_argument( - "-nt", - "--num_tables", - type=int, - default=10, - help="num_tables: number of tables used in creating tables. Defaults to 10", - required=False, -) -parser.add_argument( - "-pt", - "--planner_type", - type=str, - default="parallelized", - help="embedding_sharding_planner_type: type of embedding sharding planner used in creating a planner" - "if need to use non_parallelized, type 'non_parallelized', otherwise defaults to parallelized", - required=False, -) - -args: argparse.Namespace = parser.parse_args() - -logging.basicConfig(level=logging.INFO) - - -def main() -> None: - """ - Generates the sharding plan for a SparseNN model. - - Purpose behind this function is to test planners quickly. This can be done by building the function with custom parameters - such as local_world_size, num_embeddings, num_tables and more. - - Program outputs planner summary. 
- """ - topology = Topology( - local_world_size=args.local_world_size, - world_size=args.world_size, - hbm_cap=args.hbm_cap, - compute_device=args.compute_device, - ) - - if args.planner_type == "non_parallelized": - planner = EmbeddingShardingPlanner( - topology=topology, batch_size=args.batch_size - ) - else: - planner = ParallelizedEmbeddingShardingPlanner( - topology=topology, batch_size=args.batch_size - ) - - tables: List[EmbeddingBagConfig] = [ - EmbeddingBagConfig( - num_embeddings=args.num_embeddings, - embedding_dim=args.embedding_dim, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - ) - for i in range(args.num_tables) - ] - model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) - - Sharders: List[ModuleSharder[nn.Module]] = [ - cast(ModuleSharder[nn.Module], EmbeddingBagCollectionSharder()), - ] - - planner.plan( - module=model, - sharders=Sharders, - ) - - -if __name__ == "__main__": - main() diff --git a/torchrec/distributed/planner/constants.py b/torchrec/distributed/planner/constants.py index 2a045e685..7dd86b060 100644 --- a/torchrec/distributed/planner/constants.py +++ b/torchrec/distributed/planner/constants.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Optional from torchrec.distributed.embedding_types import EmbeddingComputeKernel @@ -16,6 +18,7 @@ MIN_CW_DIM: int = 128 POOLING_FACTOR: float = 1.0 +NUM_POOLINGS: float = 1.0 BIGINT_DTYPE: int = 8 @@ -23,6 +26,9 @@ DDR_CAP: int = 128 * 1024 * 1024 * 1024 # 128 GB DDR_MEM_BW: float = 51 * 1024 * 1024 * 1024 / 1000 # bytes/ms HBM_MEM_BW: float = 897 * 1024 * 1024 * 1024 / 1000 # bytes/ms +# This can be smaller than DDR_MEM_BW because the PCI channel maybe shared +# with other devices such as the FE NIC. +HBM_TO_DDR_MEM_BW: float = 32 * 1024 * 1024 * 1024 / 1000 # bytes/ms UVM_CACHING_RATIO: float = 0.2 BATCH_SIZE: int = 512 @@ -31,6 +37,7 @@ HALF_BLOCK_PENALTY: float = 1.15 # empirical studies QUARTER_BLOCK_PENALTY: float = 1.75 # empirical studies BWD_COMPUTE_MULTIPLIER: float = 2 # empirical studies +WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER: float = 1 # empirical studies WEIGHTED_KERNEL_MULTIPLIER: float = 1.1 # empirical studies DP_ELEMENTWISE_KERNELS_PERF_FACTOR: float = 9.22 # empirical studies @@ -38,7 +45,11 @@ def kernel_bw_lookup( compute_device: str, compute_kernel: str, + hbm_mem_bw: float, + ddr_mem_bw: float, + hbm_to_ddr_mem_bw: float, caching_ratio: Optional[float] = None, + prefetch_pipeline: bool = False, ) -> Optional[float]: """ Calculates the device bandwidth based on given compute device, compute kernel, and @@ -47,31 +58,49 @@ def kernel_bw_lookup( Args: compute_kernel (str): compute kernel. compute_device (str): compute device. + hbm_mem_bw (float): the bandwidth of the device HBM. + ddr_mem_bw (float): the bandwidth of the system DDR memory. + hbm_to_ddr_bw (float): the bandwidth between device HBM and system DDR. caching_ratio (Optional[float]): caching ratio used to determine device bandwidth if UVM caching is enabled. + prefetch_pipeline (bool): whether prefetch pipeline is enabled. Returns: - float: the device bandwidth. + Optional[float]: the device bandwidth. 
""" - caching_ratio = caching_ratio if caching_ratio else UVM_CACHING_RATIO + caching_ratio = caching_ratio if caching_ratio is not None else UVM_CACHING_RATIO lookup = { # CPU - ("cpu", EmbeddingComputeKernel.DENSE.value): 0.5 * DDR_MEM_BW, - ("cpu", EmbeddingComputeKernel.FUSED.value): 1 * DDR_MEM_BW, - ("cpu", EmbeddingComputeKernel.QUANT.value): 1 * DDR_MEM_BW, + ("cpu", EmbeddingComputeKernel.DENSE.value): 0.5 * ddr_mem_bw, + ("cpu", EmbeddingComputeKernel.FUSED.value): 1 * ddr_mem_bw, + ("cpu", EmbeddingComputeKernel.QUANT.value): 1 * ddr_mem_bw, + # TODO: Determine the correct value later. MTIA uses values same as CPU's. + # MTIA + ("mtia", EmbeddingComputeKernel.DENSE.value): 0.5 * ddr_mem_bw, + ("mtia", EmbeddingComputeKernel.FUSED.value): 1 * ddr_mem_bw, + ("mtia", EmbeddingComputeKernel.QUANT.value): 1 * ddr_mem_bw, # CUDA - ("cuda", EmbeddingComputeKernel.DENSE.value): 0.5 * HBM_MEM_BW, - ("cuda", EmbeddingComputeKernel.FUSED.value): 1 * HBM_MEM_BW, - ("cuda", EmbeddingComputeKernel.FUSED_UVM.value): DDR_MEM_BW / 10, + ("cuda", EmbeddingComputeKernel.DENSE.value): 0.5 * hbm_mem_bw, + ("cuda", EmbeddingComputeKernel.FUSED.value): 1 * hbm_mem_bw, + ("cuda", EmbeddingComputeKernel.FUSED_UVM.value): hbm_to_ddr_mem_bw / 10, ("cuda", EmbeddingComputeKernel.FUSED_UVM_CACHING.value): ( - caching_ratio * HBM_MEM_BW + (1 - caching_ratio) * DDR_MEM_BW + caching_ratio * hbm_mem_bw + (1 - caching_ratio) * hbm_to_ddr_mem_bw ) / 10, - ("cuda", EmbeddingComputeKernel.QUANT.value): 1 * HBM_MEM_BW, - ("cuda", EmbeddingComputeKernel.QUANT_UVM.value): DDR_MEM_BW / 10, + ("cuda", EmbeddingComputeKernel.QUANT.value): 1 * hbm_mem_bw, + ("cuda", EmbeddingComputeKernel.QUANT_UVM.value): hbm_to_ddr_mem_bw / 10, ("cuda", EmbeddingComputeKernel.QUANT_UVM_CACHING.value): ( - caching_ratio * HBM_MEM_BW + (1 - caching_ratio) * DDR_MEM_BW + caching_ratio * hbm_mem_bw + (1 - caching_ratio) * hbm_to_ddr_mem_bw ) / 10, + ("cuda", EmbeddingComputeKernel.KEY_VALUE.value): hbm_to_ddr_mem_bw, } + + if ( + prefetch_pipeline + and compute_device == "cuda" + and compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value + ): + return lookup.get(("cuda", EmbeddingComputeKernel.FUSED.value)) + return lookup.get((compute_device, compute_kernel)) diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py index 112325e0e..66ea9ee2d 100644 --- a/torchrec/distributed/planner/enumerators.py +++ b/torchrec/distributed/planner/enumerators.py @@ -5,12 +5,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import logging -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Set, Tuple, Union from torch import nn from torchrec.distributed.embedding_types import EmbeddingComputeKernel -from torchrec.distributed.planner.constants import MIN_CW_DIM, POOLING_FACTOR +from torchrec.distributed.planner.constants import POOLING_FACTOR from torchrec.distributed.planner.shard_estimators import ( EmbeddingPerfEstimator, EmbeddingStorageEstimator, @@ -26,12 +28,24 @@ ) from torchrec.distributed.planner.utils import sharder_name from torchrec.distributed.sharding_plan import calculate_shard_sizes_and_offsets -from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.distributed.types import ( + BoundsCheckMode, + CacheParams, + KeyValueParams, + ModuleSharder, + ShardingType, +) +from torchrec.modules.embedding_configs import DataType from torchrec.modules.embedding_tower import EmbeddingTower, EmbeddingTowerCollection logger: logging.Logger = logging.getLogger(__name__) +# compute kernels that should only be used if users specified them +GUARDED_COMPUTE_KERNELS: Set[EmbeddingComputeKernel] = { + EmbeddingComputeKernel.KEY_VALUE +} + class EmbeddingEnumerator(Enumerator): """ @@ -43,6 +57,8 @@ class EmbeddingEnumerator(Enumerator): batch_size (int): batch size. constraints (Optional[Dict[str, ParameterConstraints]]): dict of parameter names to provided ParameterConstraints. + estimator (Optional[Union[ShardEstimator, List[ShardEstimator]]]): shard performance estimators. + use_exact_enumerate_order (bool): whether to enumerate shardable parameters in the exact name_children enumeration order """ def __init__( @@ -51,12 +67,26 @@ def __init__( batch_size: int, constraints: Optional[Dict[str, ParameterConstraints]] = None, estimator: Optional[Union[ShardEstimator, List[ShardEstimator]]] = None, + use_exact_enumerate_order: Optional[bool] = False, ) -> None: self._compute_device: str = topology.compute_device self._world_size: int = topology.world_size self._local_world_size: int = topology.local_world_size self._batch_size: int = batch_size self._constraints = constraints + self._sharder_map: Dict[str, ModuleSharder[nn.Module]] = {} + self._use_exact_enumerate_order: bool = ( + use_exact_enumerate_order if use_exact_enumerate_order else False + ) + memory_type = "hbm_cap" if topology.compute_device == "cuda" else "ddr_cap" + self._device_memory_sizes: Optional[ + List[int] + ] = ( # only used with custom topology where memory is different within a topology + topology._custom_topology_data.get_data(memory_type) + if topology._custom_topology_data + and topology._custom_topology_data.has_data(memory_type) + else None + ) if estimator: self._estimators: List[ShardEstimator] = ( @@ -84,16 +114,19 @@ def enumerate( List[ShardingOption]: valid sharding options with values populated. 
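Example (a sketch; `model` is assumed to be a module containing shardable embedding modules and is not defined here)::

    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
    from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
    from torchrec.distributed.planner.types import ParameterConstraints, Topology
    from torchrec.distributed.types import ShardingType

    enumerator = EmbeddingEnumerator(
        topology=Topology(world_size=2, compute_device="cuda"),
        batch_size=512,
        # constraints narrow the sharding types / compute kernels considered per table
        constraints={
            "table_0": ParameterConstraints(
                sharding_types=[ShardingType.TABLE_WISE.value],
            ),
        },
    )
    options = enumerator.enumerate(
        module=model,
        sharders=[EmbeddingBagCollectionSharder()],
    )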
""" - sharder_map: Dict[str, ModuleSharder[nn.Module]] = { + self._sharder_map = { sharder_name(sharder.module_type): sharder for sharder in sharders } sharding_options: List[ShardingOption] = [] named_modules_queue = [("", module)] while named_modules_queue: - child_path, child_module = named_modules_queue.pop() + if not self._use_exact_enumerate_order: + child_path, child_module = named_modules_queue.pop() + else: + child_path, child_module = named_modules_queue.pop(0) sharder_key = sharder_name(type(child_module)) - sharder = sharder_map.get(sharder_key, None) + sharder = self._sharder_map.get(sharder_key, None) if not sharder: for n, m in child_module.named_children(): if child_path != "": @@ -102,25 +135,43 @@ def enumerate( named_modules_queue.append((n, m)) continue + # Determine the pooling state for all sharding_options using this + # (child_module, child_path). With this optimization, we change enumerate() + # from being O(N^2) with respect to the number of tables to O(N). The + # previous quadratic behavior is because in populate_estimates() invoked below, each + # sharding_option needs to determine its pooling state, which is does via + # an expensive O(N) walk through the list of embedding tables. With this + # change sharding_option.is_pooled becomes O(1). + is_pooled = ShardingOption.module_pooled(child_module, child_path) + for name, param in sharder.shardable_parameters(child_module).items(): + ( + input_lengths, + col_wise_shard_dim, + cache_params, + enforce_hbm, + stochastic_rounding, + bounds_check_mode, + feature_names, + output_dtype, + device_group, + key_value_params, + ) = _extract_constraints_for_param(self._constraints, name) + + # skip for other device groups + if device_group and device_group != self._compute_device: + continue + + sharding_options_per_table: List[ShardingOption] = [] + for sharding_type in self._filter_sharding_types( name, sharder.sharding_types(self._compute_device) ): for compute_kernel in self._filter_compute_kernels( name, sharder.compute_kernels(sharding_type, self._compute_device), + sharding_type, ): - - input_lengths = ( - self._constraints[name].pooling_factors - if self._constraints and self._constraints.get(name) - else [POOLING_FACTOR] - ) - col_wise_shard_dim = ( - self._constraints[name].min_partition - if self._constraints and self._constraints.get(name) - else None - ) ( shard_sizes, shard_offsets, @@ -130,6 +181,7 @@ def enumerate( local_world_size=self._local_world_size, sharding_type=sharding_type, col_wise_shard_dim=col_wise_shard_dim, + device_memory_sizes=self._device_memory_sizes, ) dependency = None if isinstance(child_module, EmbeddingTower): @@ -137,7 +189,7 @@ def enumerate( elif isinstance(child_module, EmbeddingTowerCollection): tower_index = _get_tower_index(name, child_module) dependency = child_path + ".tower_" + str(tower_index) - sharding_options.append( + sharding_options_per_table.append( ShardingOption( name=name, tensor=param, @@ -151,54 +203,94 @@ def enumerate( Shard(size=size, offset=offset) for size, offset in zip(shard_sizes, shard_offsets) ], + cache_params=cache_params, + enforce_hbm=enforce_hbm, + stochastic_rounding=stochastic_rounding, + bounds_check_mode=bounds_check_mode, dependency=dependency, + is_pooled=is_pooled, + feature_names=feature_names, + output_dtype=output_dtype, + key_value_params=key_value_params, ) ) - if not sharding_options: + if not sharding_options_per_table: raise RuntimeError( "No available sharding type and compute kernel combination " - f"after applying user provided 
constraints for {name}" + f"after applying user provided constraints for {name}. " + f"Module: {sharder_key}, sharder: {sharder.__class__.__name__}, compute device: {self._compute_device}. " + f"To debug, search above for warning logs about no available sharding types/compute kernels for table: {name}" ) - for estimator in self._estimators: - estimator.estimate(sharding_options, sharder_map) + sharding_options.extend(sharding_options_per_table) + + self.populate_estimates(sharding_options) return sharding_options - def _filter_sharding_types(self, name: str, sharding_types: List[str]) -> List[str]: + def populate_estimates(self, sharding_options: List[ShardingOption]) -> None: + for estimator in self._estimators: + estimator.estimate(sharding_options, self._sharder_map) + + def _filter_sharding_types( + self, name: str, allowed_sharding_types: List[str] + ) -> List[str]: + # GRID_SHARD is only supported if specified by user in parameter constraints if not self._constraints or not self._constraints.get(name): - return sharding_types + return [ + t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value + ] constraints: ParameterConstraints = self._constraints[name] if not constraints.sharding_types: - return sharding_types + return [ + t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value + ] constrained_sharding_types: List[str] = constraints.sharding_types - sharding_types = list(set(constrained_sharding_types) & set(sharding_types)) - - if not sharding_types: + filtered_sharding_types = list( + set(constrained_sharding_types) & set(allowed_sharding_types) + ) + if not filtered_sharding_types: logger.warn( - f"No available sharding types after applying user provided constraints for {name}" + "No available sharding types after applying user provided " + f"constraints for {name}. Constrained sharding types: " + f"{constrained_sharding_types}, allowed sharding types: " + f"{allowed_sharding_types}, filtered sharding types: " + f"{filtered_sharding_types}. Please check if the constrained " + "sharding types are too restrictive, if the sharder allows the " + "sharding types, or if non-strings are passed in." 
) - return sharding_types + return filtered_sharding_types def _filter_compute_kernels( self, name: str, - compute_kernels: List[str], + allowed_compute_kernels: List[str], + sharding_type: str, ) -> List[str]: - - if not self._constraints or not self._constraints.get(name): - filtered_compute_kernels = compute_kernels + # setup constrained_compute_kernels + if ( + self._constraints + and self._constraints.get(name) + and self._constraints[name].compute_kernels + ): + # pyre-ignore + constrained_compute_kernels: List[str] = self._constraints[ + name + ].compute_kernels else: - constraints: ParameterConstraints = self._constraints[name] - if not constraints.compute_kernels: - filtered_compute_kernels = compute_kernels - else: - constrained_compute_kernels: List[str] = constraints.compute_kernels - filtered_compute_kernels = list( - set(constrained_compute_kernels) & set(compute_kernels) - ) + constrained_compute_kernels: List[str] = [ + compute_kernel.value + for compute_kernel in EmbeddingComputeKernel + if compute_kernel not in GUARDED_COMPUTE_KERNELS + ] + + # setup filtered_compute_kernels + filtered_compute_kernels = list( + set(constrained_compute_kernels) & set(allowed_compute_kernels) + ) + # special rules if EmbeddingComputeKernel.DENSE.value in filtered_compute_kernels: if ( EmbeddingComputeKernel.FUSED.value in filtered_compute_kernels @@ -207,11 +299,68 @@ def _filter_compute_kernels( if not filtered_compute_kernels: logger.warn( - f"No available compute kernels after applying user provided constraints for {name}" + "No available compute kernels after applying user provided " + f"constraints for {name}. Constrained compute kernels: " + f"{constrained_compute_kernels}, allowed compute kernels: " + f"{allowed_compute_kernels}, filtered compute kernels: " + f"{filtered_compute_kernels}, sharding type: {sharding_type}. Please check if the constrained " + "compute kernels are too restrictive, if the sharder allows the " + "compute kernels, or if non-strings are passed in." ) return filtered_compute_kernels +def _extract_constraints_for_param( + constraints: Optional[Dict[str, ParameterConstraints]], name: str +) -> Tuple[ + List[float], + Optional[int], + Optional[CacheParams], + Optional[bool], + Optional[bool], + Optional[BoundsCheckMode], + Optional[List[str]], + Optional[DataType], + Optional[str], + Optional[KeyValueParams], +]: + input_lengths = [POOLING_FACTOR] + col_wise_shard_dim = None + cache_params = None + enforce_hbm = None + stochastic_rounding = None + bounds_check_mode = None + feature_names = None + output_dtype = None + device_group = None + key_value_params = None + + if constraints and constraints.get(name): + input_lengths = constraints[name].pooling_factors + col_wise_shard_dim = constraints[name].min_partition + cache_params = constraints[name].cache_params + enforce_hbm = constraints[name].enforce_hbm + stochastic_rounding = constraints[name].stochastic_rounding + bounds_check_mode = constraints[name].bounds_check_mode + feature_names = constraints[name].feature_names + output_dtype = constraints[name].output_dtype + device_group = constraints[name].device_group + key_value_params = constraints[name].key_value_params + + return ( + input_lengths, + col_wise_shard_dim, + cache_params, + enforce_hbm, + stochastic_rounding, + bounds_check_mode, + feature_names, + output_dtype, + device_group, + key_value_params, + ) + + def get_partition_by_type(sharding_type: str) -> str: """ Gets corresponding partition by type for provided sharding type. 
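A quick check of the mapping added in this change (a sketch; grid sharding falls into the new multi-host bucket while row-wise stays uniform):

    from torchrec.distributed.planner.enumerators import get_partition_by_type
    from torchrec.distributed.planner.types import PartitionByType
    from torchrec.distributed.types import ShardingType

    assert (
        get_partition_by_type(ShardingType.GRID_SHARD.value)
        == PartitionByType.MULTI_HOST.value
    )
    assert (
        get_partition_by_type(ShardingType.ROW_WISE.value)
        == PartitionByType.UNIFORM.value
    )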
@@ -235,6 +384,7 @@ def get_partition_by_type(sharding_type: str) -> str: ShardingType.ROW_WISE.value, ShardingType.DATA_PARALLEL.value, } + multi_host_sharding_types = {ShardingType.GRID_SHARD.value} if sharding_type in device_sharding_types: return PartitionByType.DEVICE.value @@ -242,6 +392,8 @@ def get_partition_by_type(sharding_type: str) -> str: return PartitionByType.HOST.value elif sharding_type in uniform_sharding_types: return PartitionByType.UNIFORM.value + elif sharding_type in multi_host_sharding_types: + return PartitionByType.MULTI_HOST.value raise ValueError( f"Unrecognized or unsupported sharding type provided: {sharding_type}" diff --git a/torchrec/distributed/planner/parallelized_planners.py b/torchrec/distributed/planner/parallelized_planners.py deleted file mode 100644 index 52173ba70..000000000 --- a/torchrec/distributed/planner/parallelized_planners.py +++ /dev/null @@ -1,331 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -import os -from functools import reduce -from multiprocessing.pool import ThreadPool as Pool -from time import perf_counter -from typing import cast, Dict, List, Optional, Tuple, Union - -import numpy - -import torch - -import torch.distributed as dist -from torch import nn -from torchrec.distributed.collective_utils import invoke_on_rank_and_broadcast_result -from torchrec.distributed.planner.constants import BATCH_SIZE, MAX_SIZE -from torchrec.distributed.planner.enumerators import EmbeddingEnumerator -from torchrec.distributed.planner.partitioners import GreedyPerfPartitioner -from torchrec.distributed.planner.perf_models import NoopPerfModel -from torchrec.distributed.planner.proposers import ( - GreedyProposer, - GridSearchProposer, - proposers_to_proposals_list, - UniformProposer, -) -from torchrec.distributed.planner.stats import EmbeddingStats -from torchrec.distributed.planner.storage_reservations import ( - HeuristicalStorageReservation, -) -from torchrec.distributed.planner.types import ( - Enumerator, - ParameterConstraints, - Partitioner, - PerfModel, - PlannerError, - PlannerErrorType, - Proposer, - ShardingOption, - Stats, - Storage, - StorageReservation, - Topology, -) -from torchrec.distributed.types import ( - EnumerableShardingSpec, - ModuleSharder, - ParameterSharding, - ShardingPlan, - ShardingPlanner, - ShardingType, - ShardMetadata, -) - - -def _to_sharding_plan( - sharding_options: List[ShardingOption], - topology: Topology, -) -> ShardingPlan: - def _placement( - compute_device: str, - rank: int, - local_size: int, - ) -> str: - param_device = compute_device - if compute_device == "cuda": - param_device = torch.device("cuda", rank % local_size) - return f"rank:{rank}/{param_device}" - - compute_device = topology.compute_device - local_size = topology.local_world_size - - plan = {} - for sharding_option in sharding_options: - shards = sharding_option.shards - sharding_type = sharding_option.sharding_type - - module_plan = plan.get(sharding_option.path, {}) - module_plan[sharding_option.name] = ParameterSharding( - sharding_spec=None - if sharding_type == ShardingType.DATA_PARALLEL.value - else EnumerableShardingSpec( - [ - ShardMetadata( - shard_sizes=shard.size, - shard_offsets=shard.offset, - placement=_placement( - compute_device, cast(int, shard.rank), local_size - ), - ) - for shard in shards - ] - ), - 
sharding_type=sharding_type, - compute_kernel=sharding_option.compute_kernel, - ranks=[cast(int, shard.rank) for shard in shards], - ) - plan[sharding_option.path] = module_plan - return ShardingPlan(plan) - - -class ParallelizedEmbeddingShardingPlanner(ShardingPlanner): - """ - Provides an optimized sharding plan for a given module with shardable parameters - according to the provided sharders, topology, and constraints using multiprocessing to improve runtime and scalability. - """ - - def __init__( - self, - topology: Topology, - batch_size: Optional[int] = None, - enumerator: Optional[Enumerator] = None, - storage_reservation: Optional[StorageReservation] = None, - proposer: Optional[Union[Proposer, List[Proposer]]] = None, - custom_cpu_count: Optional[int] = None, - partitioner: Optional[Partitioner] = None, - performance_model: Optional[PerfModel] = None, - stats: Optional[Union[Stats, List[Stats]]] = None, - constraints: Optional[Dict[str, ParameterConstraints]] = None, - debug: bool = True, - ) -> None: - self._topology = topology - self._batch_size: int = batch_size if batch_size else BATCH_SIZE - self._constraints = constraints - self._enumerator: Enumerator = ( - enumerator - if enumerator - else EmbeddingEnumerator( - topology=topology, - batch_size=self._batch_size, - constraints=constraints, - ) - ) - self._storage_reservation: StorageReservation = ( - storage_reservation - if storage_reservation - else HeuristicalStorageReservation(percentage=0.15) - ) - self._partitioner: Partitioner = ( - partitioner if partitioner else GreedyPerfPartitioner() - ) - if proposer: - self._proposers: List[Proposer] = ( - [proposer] if not isinstance(proposer, list) else proposer - ) - else: - self._proposers = [ - GridSearchProposer(), - GreedyProposer(), - GreedyProposer(use_depth=False), - UniformProposer(), - ] - self._perf_model: PerfModel = ( - performance_model if performance_model else NoopPerfModel(topology=topology) - ) - if stats: - self._stats: List[Stats] = [stats] if not isinstance(stats, list) else stats - else: - self._stats = [EmbeddingStats()] - - self._debug = debug - self._num_proposals: int = 0 - self._num_plans: int = 0 - self._cpu_count: Optional[int] = ( - custom_cpu_count if custom_cpu_count else os.cpu_count() - ) - - def collective_plan( - self, - module: nn.Module, - sharders: List[ModuleSharder[nn.Module]], - pg: dist.ProcessGroup, - ) -> ShardingPlan: - """ - Call self.plan(...) 
on rank 0 and broadcast - """ - return invoke_on_rank_and_broadcast_result( - pg, - 0, - self.plan, - module, - sharders, - ) - - def plan( - self, - module: nn.Module, - sharders: List[ModuleSharder[nn.Module]], - ) -> ShardingPlan: - - self._num_proposals = 0 - self._num_plans = 0 - start_time = perf_counter() - - storage_constraint: Topology = self._storage_reservation.reserve( - topology=self._topology, - batch_size=self._batch_size, - module=module, - sharders=sharders, - constraints=self._constraints, - ) - - search_space = self._enumerator.enumerate( - module=module, - sharders=sharders, - ) - if not search_space: - # No shardable parameters - return ShardingPlan({}) - - proposals_list = proposers_to_proposals_list( - self._proposers, search_space=search_space - ) - - self._num_proposals = len(proposals_list) - - def get_best_plan( - proposal_group: List[List[ShardingOption]], - ) -> Tuple[Optional[List[ShardingOption]], float, Storage, int]: - - group_plans_num = 0 - lowest_storage = Storage(MAX_SIZE, MAX_SIZE) - best_perf_rating = MAX_SIZE - best_plan = None - - for proposal in proposal_group: - try: - plan = self._partitioner.partition( - proposal=proposal, - storage_constraint=storage_constraint, - ) - group_plans_num += 1 - perf_rating = self._perf_model.rate(plan=plan) - - if perf_rating < best_perf_rating: - best_perf_rating = perf_rating - best_plan = plan - - except PlannerError: - current_storage = cast( - Storage, - reduce( - lambda x, y: x + y, - [ - shard.storage - for option in proposal - for shard in option.shards - ], - ), - ) - if current_storage < lowest_storage: - lowest_storage = current_storage - - return (best_plan, best_perf_rating, lowest_storage, group_plans_num) - - grouped_proposals = numpy.array_split(proposals_list, self._cpu_count) - - pool = Pool(self._cpu_count) - group_best_plans = pool.map(get_best_plan, grouped_proposals) - pool.close() - pool.join() - - lowest_storage = Storage(MAX_SIZE, MAX_SIZE) - best_perf_rating = MAX_SIZE - best_plan = None - - for plan_info in group_best_plans: - current_plan = plan_info[0] - plan_perf_rating = plan_info[1] - plan_storage = plan_info[2] - self._num_plans += plan_info[3] - if plan_perf_rating < best_perf_rating: - best_plan = current_plan - best_perf_rating = plan_perf_rating - if plan_storage < lowest_storage: - lowest_storage = plan_storage - - if best_plan is not None: - sharding_plan = _to_sharding_plan(best_plan, self._topology) - end_time = perf_counter() - for stats in self._stats: - stats.log( - sharding_plan=sharding_plan, - topology=self._topology, - batch_size=self._batch_size, - storage_reservation=self._storage_reservation, - num_proposals=self._num_proposals, - num_plans=self._num_plans, - run_time=end_time - start_time, - best_plan=best_plan, - constraints=self._constraints, - debug=self._debug, - ) - return sharding_plan - else: - global_storage_capacity = reduce( - lambda x, y: x + y, - [device.storage for device in self._topology.devices], - ) - global_storage_constraints = reduce( - lambda x, y: x + y, - [device.storage for device in storage_constraint.devices], - ) - no_plan_solution = ( - f"Planner evaluated {self._num_proposals} proposals." 
- "\nPossible solutions:" - f"\n 1) Increase the number of devices ({self._topology.world_size})" - f"\n 2) Reduce the model size (" - f"\n\t Global storage: {global_storage_capacity.hbm}, " - f"\n\t Available for model parallel: {global_storage_constraints}," - f"\n\t Requirement for model parallel: {lowest_storage})" - f"\n 3) Reduce local batch size ({self._batch_size})" - "\n 4) Remove planner constraints that might be reducing search space or available storage\n" - ) - if global_storage_constraints < lowest_storage: - raise PlannerError( - error_type=PlannerErrorType.INSUFFICIENT_STORAGE, - message="Unable to find a plan for this model because of insufficient storage. \n" - + no_plan_solution, - ) - else: - raise PlannerError( - error_type=PlannerErrorType.STRICT_CONSTRAINTS, - message="Unable to find a plan for this model because of the strict constraints. \n" - + no_plan_solution, - ) diff --git a/torchrec/distributed/planner/partitioners.py b/torchrec/distributed/planner/partitioners.py index 2ad9ff2b6..b397c5064 100644 --- a/torchrec/distributed/planner/partitioners.py +++ b/torchrec/distributed/planner/partitioners.py @@ -5,22 +5,35 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import copy +import heapq +import itertools +import logging from dataclasses import dataclass -from typing import cast, List +from enum import Enum +from typing import cast, Dict, List, Optional + +from torchrec.distributed.planner.perf_models import NoopPerfModel from torchrec.distributed.planner.types import ( DeviceHardware, PartitionByType, Partitioner, + Perf, + PerfModel, PlannerError, PlannerErrorType, ShardingOption, Storage, Topology, ) +from torchrec.distributed.planner.utils import bytes_to_gb, reset_shard_rank from torchrec.distributed.types import ShardingType +logger: logging.Logger = logging.getLogger(__name__) + def _sort_devices_by_perf( devices: List[List[DeviceHardware]], @@ -28,7 +41,7 @@ def _sort_devices_by_perf( def _get_perf_sum(device_list: List[DeviceHardware]) -> float: perf = 0 for device in device_list: - perf += device.perf + perf += device.perf.total return perf return sorted(devices, key=_get_perf_sum) @@ -48,11 +61,30 @@ def _get_uniform_sharding_options( class ShardingOptionGroup: sharding_options: List[ShardingOption] storage_sum: Storage + perf_sum: float + param_count: int + + +class SortBy(Enum): + STORAGE = "storage" + PERF = "perf" def _group_and_sort_non_uniform_sharding_options( sharding_options: List[ShardingOption], + sort_by: SortBy = SortBy.STORAGE, + balance_modules: bool = False, ) -> List[ShardingOptionGroup]: + + # count modules by name + param_count: Dict[str, int] = {} + for sharding_option in sharding_options: + path = sharding_option.path + if path not in param_count: + param_count[path] = 0 + param_count[path] += 1 + logger.debug(f"param_count is {param_count}") + sharding_option_groups_by_dependency = {} for sharding_option in sharding_options: if sharding_option.partition_by == PartitionByType.UNIFORM.value: @@ -61,7 +93,11 @@ def _group_and_sort_non_uniform_sharding_options( group_key = sharding_option.dependency or sharding_option.fqn if group_key not in sharding_option_groups_by_dependency: sharding_option_groups_by_dependency[group_key] = ShardingOptionGroup( - [sharding_option], sharding_option.total_storage + [sharding_option], + sharding_option.total_storage, + sharding_option.total_perf, + # negative value to indicate that smaller modules 
should be sorted first + param_count=-param_count[sharding_option.path], ) else: sharding_option_groups_by_dependency[group_key].sharding_options.append( @@ -70,16 +106,66 @@ def _group_and_sort_non_uniform_sharding_options( sharding_option_groups_by_dependency[ group_key ].storage_sum += sharding_option.total_storage + sharding_option_groups_by_dependency[ + group_key + ].perf_sum += sharding_option.total_perf + sharding_option_groups = list(sharding_option_groups_by_dependency.values()) - sharding_option_groups.sort(key=lambda group: group.storage_sum, reverse=True) + sort_by_attributes: List[str] = [] + if balance_modules: + sort_by_attributes.append("param_count") + + if sort_by == SortBy.STORAGE: + sort_by_attributes.append("storage_sum") + elif sort_by == SortBy.PERF: + sort_by_attributes.append("perf_sum") + else: + raise RuntimeError(f"Unexpected sort_by: {sort_by}") + + sharding_option_groups.sort( + key=lambda group: [getattr(group, attr) for attr in sort_by_attributes], + reverse=True, + ) + return sharding_option_groups +@dataclass +class OrderedDeviceHardware: + device: DeviceHardware + local_world_size: int + + def __lt__(self, other: "OrderedDeviceHardware") -> bool: + # Use local rank as a tie breaker to ensure that we don't overload a single + # host's DDR limit. + return ( + self.device.perf.total, + self.device.rank % self.local_world_size, + self.device.rank, + ) < ( + other.device.perf.total, + other.device.rank % self.local_world_size, + other.device.rank, + ) + + class GreedyPerfPartitioner(Partitioner): + """Greedy Partitioner. + + Args: + sort_by (SortBy): Sort sharding options by storage or perf in + descending order (i.e., large tables will be placed first). + balance_modules (bool): Whether to sort by modules first, where + smaller modules will be sorted first. In effect, this will place + tables in each module in a balanced way. """ - Greedy Partitioner - """ + + def __init__( + self, sort_by: SortBy = SortBy.STORAGE, balance_modules: bool = False + ) -> None: + self._sort_by = sort_by + self._balance_modules = balance_modules def partition( self, @@ -127,50 +213,64 @@ def partition( # First [sharding_options[0] and sharding_options[1]] will be placed on the # topology with the uniform strategy, resulting in - topology.devices[0].perf = (1,2) - topology.devices[1].perf = (1,2) + topology.devices[0].perf.total = (1,2) + topology.devices[1].perf.total = (1,2) # Finally sharding_options[2] and sharding_options[3]] will be placed on the # topology with the device strategy (see docstring of `partition_by_device` for # more details). - topology.devices[0].perf = (1,2) + (3,4) - topology.devices[1].perf = (1,2) + (3,4) + topology.devices[0].perf.total = (1,2) + (3,4) + topology.devices[1].perf.total = (1,2) + (3,4) # The topology updates are done after the end of all the placements (the other # in the example is just for clarity). 
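A usage sketch of the new constructor knobs (`search_space` and `topology` are assumed to come from an enumerator and a storage reservation, as in the planner; they are not defined here):

    from torchrec.distributed.planner.partitioners import GreedyPerfPartitioner, SortBy

    partitioner = GreedyPerfPartitioner(
        sort_by=SortBy.PERF,   # sort sharding option groups by perf instead of storage
        balance_modules=True,  # place smaller modules first to balance tables per module
    )
    plan = partitioner.partition(
        proposal=search_space,
        storage_constraint=topology,
    )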
""" _topology: Topology = copy.deepcopy(storage_constraint) - _host_level_devices = GreedyPerfPartitioner._get_host_level_devices(_topology) + minheap_devices: Optional[List[OrderedDeviceHardware]] = None + _host_level_devices = self._get_host_level_devices(_topology) # first partition the uniform sharding options (RW & DP) uniform_sharding_options = _get_uniform_sharding_options(proposal) - GreedyPerfPartitioner._uniform_partition( - uniform_sharding_options, _topology.devices - ) + self._uniform_partition(uniform_sharding_options, _topology.devices) # group the rest sharding options by colocation type (co-host, co-device, none) # and sort the groups by storage in reverse order - sharding_option_groups = _group_and_sort_non_uniform_sharding_options(proposal) + sharding_option_groups = _group_and_sort_non_uniform_sharding_options( + proposal, sort_by=self._sort_by, balance_modules=self._balance_modules + ) for sharding_option_group in sharding_option_groups: if ( + sharding_option_group.sharding_options[0].partition_by + == PartitionByType.MULTI_HOST.value + ): + self._multi_hosts_partition(sharding_option_group, _host_level_devices) + # _multi_hosts_partition invalidates minheap_devices, force rebuild before using + minheap_devices = None + + elif ( sharding_option_group.sharding_options[0].partition_by == PartitionByType.HOST.value ): - GreedyPerfPartitioner._cohost_partition( - sharding_option_group, _host_level_devices - ) + self._cohost_partition(sharding_option_group, _host_level_devices) + # _cohost_partition invalidates minheap_devices, force rebuild before using + minheap_devices = None elif ( sharding_option_group.sharding_options[0].partition_by == PartitionByType.DEVICE.value ): + if minheap_devices is None: + minheap_devices = self._establish_minheap( + _topology.devices, _topology.local_world_size + ) assert ( len(sharding_option_group.sharding_options) == 1 ), f"Unexpected length for sharding options: {len(sharding_option_group.sharding_options)}" - GreedyPerfPartitioner._device_partition( - sharding_option_group.sharding_options[0], _topology.devices + self._device_partition( + sharding_option_group.sharding_options[0], + minheap_devices, ) else: raise RuntimeError( @@ -180,28 +280,189 @@ def partition( self._topology: Topology = _topology return proposal - @staticmethod + @classmethod + def _establish_minheap( + cls, devices: List[DeviceHardware], local_world_size: int + ) -> List[OrderedDeviceHardware]: + minheap_devices = [ + OrderedDeviceHardware(device, local_world_size) for device in devices + ] + heapq.heapify(minheap_devices) + return minheap_devices + + @classmethod def _device_partition( - sharding_option: ShardingOption, devices: List[DeviceHardware] + cls, + sharding_option: ShardingOption, + minheap_devices: List[OrderedDeviceHardware], + bulk_heapify_threshold: float = 0.25, ) -> None: + pushlimit = len(minheap_devices) * bulk_heapify_threshold for shard in sharding_option.shards: - sorted_devices = sorted(devices, key=lambda device: device.perf) - success = False - for device in sorted_devices: - if cast(Storage, shard.storage).fits_in(device.storage): + tmp_heap = [] + while minheap_devices: + ordered_device = minheap_devices[0] + device = ordered_device.device + storage = cast(Storage, shard.storage) + if storage.fits_in(device.storage): shard.rank = device.rank device.storage -= cast(Storage, shard.storage) - device.perf += cast(float, shard.perf) - success = True + device.perf += cast(Perf, shard.perf) + heapq.heapreplace(minheap_devices, ordered_device) 
break - if not success: + else: + heapq.heappop(minheap_devices) + tmp_heap.append(ordered_device) + else: raise PlannerError( error_type=PlannerErrorType.PARTITION, - message=f"Device partition failed. Couldn't find a rank for shard {shard}, devices: {devices}", + message=( + f"Device partition failed. Couldn't find a rank for shard {shard} of table {sharding_option.name}, " + f"largest device storage: {max(ordered_device.device.storage for ordered_device in tmp_heap)}" + ), + ) + if tmp_heap: + # restablish minheap + if len(tmp_heap) <= pushlimit: + for ordered_device in tmp_heap: + heapq.heappush(minheap_devices, ordered_device) + else: + minheap_devices.extend(tmp_heap) + heapq.heapify(minheap_devices) + + @classmethod + def _multi_hosts_partition( + cls, + sharding_option_group: ShardingOptionGroup, + _host_level_devices: List[List[DeviceHardware]], + ) -> None: + """ + Partition shards on multiple hosts. This is a greedy algorithm trying to complete partitioning on multiple hosts (sorted by perf). + First we do columnwise sharding among hosts, then tablewise-rowwise sharding within each host. + There're two cases depends on the number of hosts needed to partition shards. + + Case one: `num_host_to_allocate >= len(sorted_host_level_devices)` + We'll try to partition only once. Hosts might be selected multiple times in a circular manner. + E.g, we have 3 hosts and `num_host_to_allocate` = 4. We sort all devices on host level. The devices of hosts [0, 1, 2, 0] will be selected for uniform partitioning. + We'll update device information if success, otherwise raise a `PlannerError`. + + Case two: `num_host_to_allocate < len(sorted_host_level_devices)` + We'll try to partition with hosts `[host_index, host_index + num_host_to_allocate]` iteratively with host_index incremented by 1 each time. + 1) We sort all devices on host level. Set `host_index` = 0 + 2) We select hosts`[host_index, host_index + num_host_to_allocate]` if indexes are within range. + 3) We do uniform partitioning over all devices of the selected hosts. If we cannot partition, then we increase `host_index` by 1 and go to 2); Otherwise we go to 4) + 4) Update device information if success, otherwise raise a `PlannerError`. + + Keyword arguments: + sharding_option_group -- grouped sharding options + _host_level_devices -- devices + + Example:: + sharding_option_group.sharding_options = [ + ShardingOption(partition_by="multi_host", + shards=[ + Shards(storage=1, perf=1), + Shards(storage=1, perf=1), + Shards(storage=1, perf=1), + Shards(storage=1, perf=1), + ]), + ] + topology = Topology(world_size=6, local_world_size=2) + + # sharding_options[0] will be placed on host 1 and host 2 with the multi_hosts strategy, resulting in + + topology.devices[0].perf.total = (1,1) + topology.devices[1].perf.total = (1,1) + topology.devices[2].perf.total = (1,1) + topology.devices[3].perf.total = (1,1) + topology.devices[4].perf.total = (0,0) + topology.devices[5].perf.total = (0,0) + + """ + # TODO: for now assume just one option for multi_hosts. + if len(sharding_option_group.sharding_options) != 1: + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message=f"Unexpected length for sharding options: {len(sharding_option_group.sharding_options)}. 
Length needs to be 1", + ) + num_shards = sharding_option_group.sharding_options[0].num_shards + + if _host_level_devices is None: + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message="host level devices is None", + ) + + local_world_size = len(_host_level_devices[0]) + num_host_to_allocate, remainder = divmod(num_shards, local_world_size) + + if remainder > 0: + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message=f"Grid Sharding is unable to place shards equally over hosts without overlapping. {num_shards=} % {local_world_size=} != 0", + ) + + sorted_host_level_devices = _sort_devices_by_perf(_host_level_devices) + host_index = 0 + all_hosts_used = False + while True: + if num_host_to_allocate >= len(sorted_host_level_devices): + # case one: we need to use all hosts + all_hosts_used = True + devices = [] + for i in range(num_host_to_allocate): + devices.extend( + sorted_host_level_devices[i % len(sorted_host_level_devices)] + ) + else: + # case two: we can use some hosts + devices = list( + itertools.chain( + *sorted_host_level_devices[ + host_index : host_index + num_host_to_allocate + ] + ) ) + host_index += 1 # shift to next host + host_devices = copy.deepcopy(devices) + success = True + sharding_option = sharding_option_group.sharding_options[0] + try: + if sharding_option.sharding_type == ShardingType.GRID_SHARD.value: + cls._uniform_partition([sharding_option], host_devices) + else: + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message=f"unexpected multi_host sharding type: {sharding_option.sharding_type}", + ) + except PlannerError: + success = False + if success: + # successfully found some hosts and partitioned on these hosts + # need to update the devices + for device, host_device in zip(devices, host_devices): + # check that devices and host_devices are in the same order + if device.rank != host_device.rank: + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message=f"device rank {device.rank} is not the same as device_copy rank {host_device.rank}", + ) + device.storage = host_device.storage + device.perf = host_device.perf + return + + if ( + host_index + num_host_to_allocate > len(sorted_host_level_devices) + ) or all_hosts_used: + break + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message=f"can't find hosts for sharding option group {sharding_option_group}", + ) - @staticmethod + @classmethod def _cohost_partition( + cls, sharding_option_group: ShardingOptionGroup, _host_level_devices: List[List[DeviceHardware]], ) -> None: @@ -215,25 +476,30 @@ def _cohost_partition( continue success = True + minheap_devices: Optional[List[OrderedDeviceHardware]] = None for sharding_option in sharding_option_group.sharding_options: try: if ( sharding_option.sharding_type == ShardingType.TABLE_ROW_WISE.value ): - GreedyPerfPartitioner._uniform_partition( - [sharding_option], host_devices - ) + cls._uniform_partition([sharding_option], host_devices) + # _uniform_partition invalidates minheap_devices, force rebuild + # before using + minheap_devices = None elif ( sharding_option.sharding_type == ShardingType.TABLE_COLUMN_WISE.value ): - GreedyPerfPartitioner._device_partition( - sharding_option, host_devices - ) + if minheap_devices is None: + minheap_devices = cls._establish_minheap( + host_devices, len(host_devices) + ) + cls._device_partition(sharding_option, minheap_devices) else: - raise RuntimeError( - f"unexpected cohost sharding type: {sharding_option.sharding_type}" + raise PlannerError( + 
error_type=PlannerErrorType.PARTITION, + message=f"unexpected cohost sharding type: {sharding_option.sharding_type}", ) except PlannerError: success = False @@ -241,6 +507,8 @@ def _cohost_partition( if success: # successfully found a host and partitioned on that host # need to update the devices + # resorting host_devices before copying data back + host_devices.sort(key=lambda device: device.rank) for device, device_copy in zip(devices, host_devices): device.storage = device_copy.storage device.perf = device_copy.perf @@ -250,8 +518,8 @@ def _cohost_partition( message=f"can't find a host for sharding option group {sharding_option_group}", ) - @staticmethod - def _get_host_level_devices(_topology: Topology) -> List[List[DeviceHardware]]: + @classmethod + def _get_host_level_devices(cls, _topology: Topology) -> List[List[DeviceHardware]]: num_hosts: int = _topology.world_size // _topology.local_world_size host_level_devices: List[List[DeviceHardware]] = [] for i in range(num_hosts): @@ -261,14 +529,15 @@ def _get_host_level_devices(_topology: Topology) -> List[List[DeviceHardware]]: host_level_devices.append(devices_in_host) return host_level_devices - @staticmethod + @classmethod def _uniform_partition( - sharding_options: List[ShardingOption], devices: List[DeviceHardware] + cls, sharding_options: List[ShardingOption], devices: List[DeviceHardware] ) -> None: for sharding_option in sharding_options: if sharding_option.num_shards != len(devices): - raise RuntimeError( - f"For a uniform partition, the number of shards ({sharding_option.num_shards}) must equal the number of devices ({len(devices)})" + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message=f"For a uniform partition, the number of shards ({sharding_option.num_shards}) must equal the number of devices ({len(devices)})", ) for i in range(len(devices)): storage_needed = cast(Storage, sharding_option.shards[i].storage) @@ -280,4 +549,125 @@ def _uniform_partition( else: sharding_option.shards[i].rank = devices[i].rank devices[i].storage -= storage_needed - devices[i].perf += cast(float, sharding_option.shards[i].perf) + devices[i].perf += cast(Perf, sharding_option.shards[i].perf) + + +class MemoryBalancedPartitioner(Partitioner): + """Memory balanced Partitioner. + + Args: + max_search_count (int): Maximum number of times to call the + GreedyPartitioner. + tolerance (float): The maximum acceptable difference between the + original plan and the new plan. If tolerance is 1, that means a new + plan will be rejected if its perf is 200% of the original plan + (i.e., the plan is 100% worse). + balance_modules (bool): Whether to sort by modules first, where + smaller modules will be sorted first. In effect, this will place + tables in each module in a balanced way. + """ + + def __init__( + self, + max_search_count: int = 10, + tolerance: float = 0.02, + balance_modules: bool = False, + ) -> None: + self._max_search_count: int = max_search_count + self._tolerance: float = tolerance + self._balance_modules: bool = balance_modules + + def partition( + self, + proposal: List[ShardingOption], + storage_constraint: Topology, + ) -> List[ShardingOption]: + """ + Repeatedly calls the GreedyPerfPartitioner to find a plan with perf + within the tolerance of the original plan that uses the least amount + of memory. 
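+
+        The search is a binary search over the per-device HBM cap: the lower
+        bound is the model's total HBM requirement divided evenly across
+        ranks, the upper bound is the per-device HBM used by the default
+        greedy plan, and a candidate cap is kept only when the resulting
+        plan's perf stays within `tolerance` of the default plan.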
+ """ + _perf_model: PerfModel = NoopPerfModel(storage_constraint) + _partitioner = GreedyPerfPartitioner( + sort_by=SortBy.PERF, balance_modules=self._balance_modules + ) + # copying storage_constraint, since we modify it in place + _topology: Topology = copy.deepcopy(storage_constraint) + + # set up default plan to fall back on + default_plan = _partitioner.partition(proposal, _topology) + default_plan = copy.deepcopy(default_plan) + original_plan_perf = _perf_model.rate(default_plan) + + max_hbm_per_device: int = _topology.devices[0].storage.hbm + logger.info( + f"Default plan uses {round(bytes_to_gb(max_hbm_per_device), 3)} GB per device." + ) + + hbm_requirement: int = 0 + for sharding_option in proposal: + for shard in sharding_option.shards: + if shard.storage is not None: + hbm_requirement += shard.storage.hbm + min_hbm_per_device: int = int(hbm_requirement / _topology.world_size) + logger.info( + "Searching in the range (min_hbm_per_device, max_hbm_per_device): " + f"({round(bytes_to_gb(min_hbm_per_device), 3)}, " + f"{round(bytes_to_gb(max_hbm_per_device), 3)})" + ) + + # binary search with (min, max] setting + search_count = 0 + while ( + search_count < self._max_search_count + and min_hbm_per_device + 10 * 1024**2 < max_hbm_per_device # 10MB + ): + search_count += 1 + reset_shard_rank(proposal) + mid_hbm_per_device: int = (max_hbm_per_device + min_hbm_per_device) // 2 + set_hbm_per_device(_topology, mid_hbm_per_device) + try: + new_plan = _partitioner.partition(proposal, _topology) + new_plan_perf = _perf_model.rate(new_plan) + perf_diff = ( + (new_plan_perf - original_plan_perf) / original_plan_perf + if original_plan_perf + else 100 + ) + if new_plan_perf > original_plan_perf * (1 + self._tolerance): + # the new plan is worse than the original one + logger.info( + f"Found a plan with {round(bytes_to_gb(mid_hbm_per_device), 3)} " + f"GB per device for embedding tables, " + f"but its perf is {round(perf_diff * 100, 3)}% worse than the original plan, " + f"which exceeds the {self._tolerance * 100}% tolerance." + ) + min_hbm_per_device = mid_hbm_per_device + else: + # the new plan is better than original one + if perf_diff > 0: + perf_diff_str = ( + f"{round((perf_diff) * 100, 3)}% worse than the original plan, " + f"which is within the {self._tolerance * 100}% tolerance." + ) + else: + perf_diff_str = f"{round((perf_diff) * 100, 3)}% better than the original plan." + logger.info( + f"Found a more memory-balanced plan with {round(bytes_to_gb(mid_hbm_per_device), 3)} " + f"GB per device for embedding tables. The new plan is {perf_diff_str}" + ) + default_plan = copy.deepcopy(new_plan) + max_hbm_per_device = mid_hbm_per_device + except PlannerError: + logger.info( + f"Couldn't find a plan with {round(bytes_to_gb(max_hbm_per_device), 3)} " + f"GB per device for embedding tables." + ) + min_hbm_per_device = mid_hbm_per_device + + return default_plan + + +def set_hbm_per_device(storage_constraint: Topology, hbm_per_device: int) -> None: + for device in storage_constraint.devices: + device.storage.hbm = hbm_per_device diff --git a/torchrec/distributed/planner/perf_models.py b/torchrec/distributed/planner/perf_models.py index 0f30d5d08..c52087379 100644 --- a/torchrec/distributed/planner/perf_models.py +++ b/torchrec/distributed/planner/perf_models.py @@ -5,12 +5,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
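# A minimal, self-contained sketch of the search performed by
# MemoryBalancedPartitioner.partition() above. `try_partition` and
# `baseline_perf` are hypothetical stand-ins for GreedyPerfPartitioner.partition()
# plus PerfModel.rate(); real plans are lists of ShardingOption, not floats.

from typing import Callable, Optional


def memory_balanced_search(
    try_partition: Callable[[int], Optional[float]],  # perf under a given cap, or None if it fails
    min_hbm: int,
    max_hbm: int,
    baseline_perf: float,
    tolerance: float = 0.02,
    max_search_count: int = 10,
) -> int:
    """Binary-search the per-device HBM cap while keeping perf within tolerance."""
    best_cap = max_hbm
    search_count = 0
    # stop once the window is narrower than 10 MiB, mirroring the code above
    while search_count < max_search_count and min_hbm + 10 * 1024**2 < max_hbm:
        search_count += 1
        mid = (min_hbm + max_hbm) // 2
        perf = try_partition(mid)
        if perf is not None and perf <= baseline_perf * (1 + tolerance):
            best_cap = mid  # plan fits and is close enough to baseline: tighten the cap
            max_hbm = mid
        else:
            min_hbm = mid  # plan failed or regressed too much: relax the cap
    return best_cap


if __name__ == "__main__":
    GiB = 1024**3

    def toy_try_partition(cap: int) -> Optional[float]:
        # made-up behavior: plans need at least 6 GiB per device and get
        # slightly slower as the cap shrinks
        return None if cap < 6 * GiB else 100.0 + (16 * GiB - cap) / GiB

    cap = memory_balanced_search(
        toy_try_partition, min_hbm=4 * GiB, max_hbm=16 * GiB, baseline_perf=100.0
    )
    print(f"chosen per-device HBM cap: {cap / GiB:.2f} GiB")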
+# pyre-strict + from typing import cast, List -from torchrec.distributed.planner.types import PerfModel, ShardingOption, Topology +from torchrec.distributed.planner.types import ( + Perf, + PerfModel, + ShardingOption, + Storage, + Topology, +) class NoopPerfModel(PerfModel): + """ + A no-op model that rates a plan by its bottleneck rank, i.e. the maximum + per-rank total perf. Here, no-op means we estimate the performance of a + model without actually running it. + """ + def __init__(self, topology: Topology) -> None: self._topology = topology @@ -19,6 +32,25 @@ def rate(self, plan: List[ShardingOption]) -> float: for sharding_option in plan: for shard in sharding_option.shards: # pyre-ignore [6]: Expected `typing_extensions.SupportsIndex` - perfs[shard.rank] += cast(float, shard.perf) + perfs[shard.rank] += cast(Perf, shard.perf).total return max(perfs) + + +class NoopStorageModel(PerfModel): + """ + A no-op model that rates a plan by the maximum per-rank HBM usage. Here, + no-op means we estimate the storage of a plan without actually running it. + """ + + def __init__(self, topology: Topology) -> None: + self._topology = topology + + def rate(self, plan: List[ShardingOption]) -> float: + hbms = [0] * self._topology.world_size + for sharding_option in plan: + for shard in sharding_option.shards: + # pyre-ignore [6]: Expected `typing_extensions.SupportsIndex` + hbms[shard.rank] += cast(Storage, shard.storage).hbm + + return max(hbms) diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py index d51c28c80..a7cd84794 100644 --- a/torchrec/distributed/planner/planners.py +++ b/torchrec/distributed/planner/planners.py @@ -5,17 +5,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree.
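# A tiny sketch of the rating rule implemented by NoopPerfModel / NoopStorageModel
# above: accumulate a per-rank total and score the plan by its bottleneck rank.
# The (rank, value) pairs are made-up stand-ins for shard.perf.total / shard.storage.hbm.

from typing import Iterable, Tuple


def bottleneck_rating(shard_values: Iterable[Tuple[int, float]], world_size: int) -> float:
    per_rank = [0.0] * world_size
    for rank, value in shard_values:
        per_rank[rank] += value
    return max(per_rank)


if __name__ == "__main__":
    # rank 0 carries 3.5, rank 1 carries 1.0, so the plan is rated 3.5
    print(bottleneck_rating([(0, 1.5), (0, 2.0), (1, 1.0)], world_size=2))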
+# pyre-strict + import copy +import logging +import time from functools import reduce from time import perf_counter -from typing import cast, Dict, List, Optional, Tuple, Union +from typing import Callable, cast, Dict, List, Optional, Tuple, Union + +import torch import torch.distributed as dist from torch import nn from torchrec.distributed.collective_utils import invoke_on_rank_and_broadcast_result +from torchrec.distributed.comm import get_local_size from torchrec.distributed.planner.constants import BATCH_SIZE, MAX_SIZE from torchrec.distributed.planner.enumerators import EmbeddingEnumerator -from torchrec.distributed.planner.partitioners import GreedyPerfPartitioner +from torchrec.distributed.planner.partitioners import ( + GreedyPerfPartitioner, + MemoryBalancedPartitioner, +) from torchrec.distributed.planner.perf_models import NoopPerfModel from torchrec.distributed.planner.proposers import ( GreedyProposer, @@ -40,8 +50,14 @@ StorageReservation, Topology, ) -from torchrec.distributed.sharding_plan import placement +from torchrec.distributed.planner.utils import ( + bytes_to_gb, + reset_shard_rank, + storage_repr_in_gb, +) +from torchrec.distributed.sharding_plan import get_default_sharders, placement from torchrec.distributed.types import ( + EmbeddingModuleShardingPlan, EnumerableShardingSpec, ModuleSharder, ParameterSharding, @@ -50,9 +66,12 @@ ShardingType, ShardMetadata, ) +from torchrec.distributed.utils import none_throws + +logger: logging.Logger = logging.getLogger(__name__) -def _to_sharding_plan( +def to_sharding_plan( sharding_options: List[ShardingOption], topology: Topology, ) -> ShardingPlan: @@ -65,39 +84,95 @@ def _to_sharding_plan( shards = sharding_option.shards sharding_type = sharding_option.sharding_type - module_plan = plan.get(sharding_option.path, {}) + module_plan = plan.get(sharding_option.path, EmbeddingModuleShardingPlan()) module_plan[sharding_option.name] = ParameterSharding( - sharding_spec=None - if sharding_type == ShardingType.DATA_PARALLEL.value - else EnumerableShardingSpec( - [ - ShardMetadata( - shard_sizes=shard.size, - shard_offsets=shard.offset, - placement=placement( - compute_device, cast(int, shard.rank), local_size - ), - ) - for shard in shards - ] + sharding_spec=( + None + if sharding_type == ShardingType.DATA_PARALLEL.value + else EnumerableShardingSpec( + [ + ShardMetadata( + shard_sizes=shard.size, + shard_offsets=shard.offset, + placement=placement( + compute_device, cast(int, shard.rank), local_size + ), + ) + for shard in shards + ] + ) ), sharding_type=sharding_type, compute_kernel=sharding_option.compute_kernel, ranks=[cast(int, shard.rank) for shard in shards], + cache_params=sharding_option.cache_params, + enforce_hbm=sharding_option.enforce_hbm, + stochastic_rounding=sharding_option.stochastic_rounding, + bounds_check_mode=sharding_option.bounds_check_mode, + output_dtype=sharding_option.output_dtype, + key_value_params=sharding_option.key_value_params, ) plan[sharding_option.path] = module_plan return ShardingPlan(plan) +def _merge_plans(best_plans: List[ShardingPlan]) -> ShardingPlan: + if len(best_plans) == 1: + return best_plans[0] + else: + merged_plan = ShardingPlan({}) + for plan in best_plans: + for name, ps in plan.plan.items(): + ps = cast(EmbeddingModuleShardingPlan, ps) + if name not in merged_plan.plan: + merged_plan.plan[name] = ps + else: + for k, v in ps.items(): + cur_plan = cast( + EmbeddingModuleShardingPlan, merged_plan.plan[name] + ) + if k not in cur_plan: + cur_plan[k] = v + else: + raise 
PlannerError( + "table can not be sharded between two device group" + ) + + return merged_plan + + class EmbeddingShardingPlanner(ShardingPlanner): """ Provides an optimized sharding plan for a given module with shardable parameters according to the provided sharders, topology, and constraints. + + Args: + topology (Optional[Topology]): the topology of the current process group. + batch_size (Optional[int]): the batch size of the model. + enumerator (Optional[Enumerator]): the enumerator to use + storage_reservation (Optional[StorageReservation]): the storage reservation to use + proposer (Optional[Union[Proposer, List[Proposer]]]): the proposer(s) to use + partitioner (Optional[Partitioner]): the partitioner to use + performance_model (Optional[PerfModel]): the performance model to use + stats (Optional[Union[Stats, List[Stats]]]): the stats to use + constraints (Optional[Dict[str, ParameterConstraints]]): per table constraints + for sharding. + debug (bool): whether to print debug information. + + Example:: + + ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta")) + planner = EmbeddingShardingPlanner() + plan = planner.plan( + module=ebc, + sharders=[EmbeddingBagCollectionSharder()], + ) + """ def __init__( self, - topology: Topology, + topology: Optional[Topology] = None, batch_size: Optional[int] = None, enumerator: Optional[Enumerator] = None, storage_reservation: Optional[StorageReservation] = None, @@ -107,8 +182,18 @@ def __init__( stats: Optional[Union[Stats, List[Stats]]] = None, constraints: Optional[Dict[str, ParameterConstraints]] = None, debug: bool = True, + callbacks: Optional[ + List[Callable[[List[ShardingOption]], List[ShardingOption]]] + ] = None, + timeout_seconds: Optional[int] = None, ) -> None: - self._topology = topology + if topology is None: + topology = Topology( + local_world_size=get_local_size(), + world_size=dist.get_world_size(), + compute_device="cuda" if torch.cuda.is_available() else "cpu", + ) + self._topology: Topology = topology self._batch_size: int = batch_size if batch_size else BATCH_SIZE self._constraints = constraints self._enumerator: Enumerator = ( @@ -143,7 +228,7 @@ def __init__( performance_model if performance_model else NoopPerfModel(topology=topology) ) - if stats: + if stats is not None: self._stats: List[Stats] = [stats] if not isinstance(stats, list) else stats else: self._stats = [EmbeddingStats()] @@ -152,16 +237,40 @@ def __init__( self._num_proposals: int = 0 self._num_plans: int = 0 self._best_plan: Optional[List[ShardingOption]] = None + self._callbacks: List[ + Callable[[List[ShardingOption]], List[ShardingOption]] + ] = ([] if callbacks is None else callbacks) + if timeout_seconds is not None: + assert timeout_seconds > 0, "Timeout must be positive" + self._timeout_seconds = timeout_seconds def collective_plan( self, module: nn.Module, - sharders: List[ModuleSharder[nn.Module]], - pg: dist.ProcessGroup, + sharders: Optional[List[ModuleSharder[nn.Module]]] = None, + pg: Optional[dist.ProcessGroup] = None, ) -> ShardingPlan: """ Call self.plan(...) on rank 0 and broadcast + + Args: + module (nn.Module): the module to shard. + sharders (Optional[List[ModuleSharder[nn.Module]]]): the sharders to use for sharding + pg (Optional[dist.ProcessGroup]): the process group to use for collective operations + + Returns: + ShardingPlan: the sharding plan for the module. """ + if pg is None: + assert dist.is_initialized(), ( + "The default process group is not yet initialized. 
" + "Please call torch.distributed.init_process_group() first before invoking this. " + "If you are not within a distributed environment, use the single rank version plan() instead." + ) + pg = none_throws(dist.GroupMember.WORLD) + + if sharders is None: + sharders = get_default_sharders() return invoke_on_rank_and_broadcast_result( pg, 0, @@ -175,12 +284,24 @@ def plan( module: nn.Module, sharders: List[ModuleSharder[nn.Module]], ) -> ShardingPlan: + """ + Provides an optimized sharding plan for a given module with shardable parameters + according to the provided sharders, topology, and constraints. + Args: + module (nn.Module): the module to shard. + sharders (List[ModuleSharder[nn.Module]]): the sharders to use for sharding. + + Returns: + ShardingPlan: the sharding plan for the module. + """ self._num_proposals = 0 self._num_plans = 0 start_time = perf_counter() best_plan = None lowest_storage = Storage(MAX_SIZE, MAX_SIZE) + last_planner_error: Optional[PlannerError] = None + last_proposal: List[ShardingOption] = [] best_perf_rating = MAX_SIZE storage_constraint: Topology = self._storage_reservation.reserve( @@ -205,12 +326,21 @@ def plan( ] = {} for proposer in self._proposers: - proposer.load(search_space=search_space) + proposer.load(search_space=search_space, enumerator=self._enumerator) + start = time.time() for proposer in self._proposers: proposal = proposer.propose() while proposal: + end = time.time() + elapsed = end - start + if self._timeout_seconds: + if elapsed > self._timeout_seconds: + logger.info( + f"Exceeded time limit of {self._timeout_seconds}s. Took {elapsed}s" + ) + break proposal_key = tuple(sorted(map(hash, proposal))) if proposal_key in proposal_cache: partitionable, plan, perf_rating = proposal_cache[proposal_key] @@ -218,26 +348,34 @@ def plan( partitionable=partitionable, plan=plan, perf_rating=perf_rating, + storage_constraint=storage_constraint, ) proposal = proposer.propose() continue self._num_proposals += 1 try: + # plan is just proposal where shard.rank is populated plan = self._partitioner.partition( - proposal=copy.deepcopy(proposal), + proposal=proposal, storage_constraint=storage_constraint, ) self._num_plans += 1 perf_rating = self._perf_model.rate(plan=plan) if perf_rating < best_perf_rating: best_perf_rating = perf_rating - best_plan = plan + best_plan = copy.deepcopy(plan) proposal_cache[proposal_key] = (True, plan, perf_rating) proposer.feedback( - partitionable=True, plan=plan, perf_rating=perf_rating + partitionable=True, + plan=plan, + perf_rating=perf_rating, + storage_constraint=storage_constraint, ) - except PlannerError: + except PlannerError as planner_error: + last_planner_error = planner_error + # shallow copy of the proposal + last_proposal: List[ShardingOption] = copy.copy(proposal) current_storage = cast( Storage, reduce( @@ -251,14 +389,23 @@ def plan( ) if current_storage < lowest_storage: lowest_storage = current_storage - proposal_cache[proposal_key] = (False, None, None) - proposer.feedback(partitionable=False) + proposal_cache[proposal_key] = (False, proposal, None) + proposer.feedback( + partitionable=False, + plan=proposal, + storage_constraint=storage_constraint, + ) + # clear shard.rank for each sharding_option + reset_shard_rank(proposal) proposal = proposer.propose() if best_plan: + for callback in self._callbacks: + best_plan = callback(best_plan) + self._best_plan = best_plan - sharding_plan = _to_sharding_plan(best_plan, self._topology) + sharding_plan = to_sharding_plan(best_plan, self._topology) end_time = 
perf_counter() for stats in self._stats: @@ -272,6 +419,7 @@ def plan( run_time=end_time - start_time, best_plan=best_plan, constraints=self._constraints, + sharders=sharders, debug=self._debug, ) return sharding_plan @@ -284,26 +432,422 @@ def plan( lambda x, y: x + y, [device.storage for device in storage_constraint.devices], ) + storage_reservation_solution = ( + ( + f"\n\t Storage reservation percentage: {self._storage_reservation._percentage}, " + f"\n\t Per rank reservation for dense storage: {storage_repr_in_gb(self._storage_reservation._dense_storage)}, " + f"\n\t Per rank reservation for kjt storage: {storage_repr_in_gb(self._storage_reservation._kjt_storage)}, " # pyre-ignore[16] + ) + if isinstance(self._storage_reservation, HeuristicalStorageReservation) + else f"\n\t Storage reservation percentage: {self._storage_reservation._percentage}, " # pyre-ignore[16] + ) no_plan_solution = ( f"Planner evaluated {self._num_proposals} proposals." "\nPossible solutions:" f"\n 1) Increase the number of devices ({self._topology.world_size})" f"\n 2) Reduce the model size (" - f"\n\t Global storage: {global_storage_capacity.hbm}, " - f"\n\t Available for model parallel: {global_storage_constraints}," - f"\n\t Requirement for model parallel: {lowest_storage})" + f"\n\t Global storage: {round(bytes_to_gb(global_storage_capacity.hbm), 3)} GB, " + f"\n\t Per rank hardware memory: {storage_repr_in_gb(self._topology.devices[0].storage)}, " + f"{storage_reservation_solution}" + f"\n\t Global storage available for model parallel: {storage_repr_in_gb(global_storage_constraints)}, " + f"\n\t Global storage requirement for model parallel: {storage_repr_in_gb(lowest_storage)})" f"\n 3) Reduce local batch size ({self._batch_size})" "\n 4) Remove planner constraints that might be reducing search space or available storage\n" ) + last_planner_error_info = f"Last planner error: \n\t{last_planner_error}\n" + + # printout stats for no plan situation + end_time = perf_counter() + sharding_plan = ShardingPlan(plan={}) + # force all shards to have rank= -1 + for sharding_option in last_proposal: + for shard in sharding_option.shards: + shard.rank = -1 + + for stats in self._stats: + stats.log( + sharding_plan=sharding_plan, + topology=self._topology, + batch_size=self._batch_size, + storage_reservation=self._storage_reservation, + num_proposals=self._num_proposals, + num_plans=self._num_plans, + run_time=end_time - start_time, + best_plan=last_proposal, + constraints=self._constraints, + sharders=sharders, + debug=self._debug, + ) + if not lowest_storage.fits_in(global_storage_constraints): raise PlannerError( error_type=PlannerErrorType.INSUFFICIENT_STORAGE, message="Unable to find a plan for this model because of insufficient storage. \n" - + no_plan_solution, + + no_plan_solution + + last_planner_error_info, ) else: raise PlannerError( error_type=PlannerErrorType.STRICT_CONSTRAINTS, message="Unable to find a plan for this model because of the strict constraints. \n" - + no_plan_solution, + + no_plan_solution + + last_planner_error_info, + ) + + +class HeteroEmbeddingShardingPlanner(ShardingPlanner): + """ + Provides an optimized sharding plan for a given module with shardable parameters + according to the provided sharders, topology, and constraints. 
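+
+    Unlike `EmbeddingShardingPlanner`, this planner takes one topology per
+    device group, plans each group independently (tables are assigned to a
+    group via `ParameterConstraints.device_group`), and merges the per-group
+    plans into a single `ShardingPlan`.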
+ """ + + def __init__( + self, + topology_groups: Optional[Dict[str, Topology]] = None, + batch_size: Optional[int] = None, + enumerators: Optional[Dict[str, Enumerator]] = None, + storage_reservations: Optional[Dict[str, StorageReservation]] = None, + proposers: Optional[Dict[str, Union[Proposer, List[Proposer]]]] = None, + partitioners: Optional[Dict[str, Partitioner]] = None, + performance_models: Optional[Dict[str, PerfModel]] = None, + stats: Optional[Dict[str, Union[Stats, List[Stats]]]] = None, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + debug: bool = True, + callbacks: Optional[ + List[Callable[[List[ShardingOption]], List[ShardingOption]]] + ] = None, + ) -> None: + default_device = "cuda" if torch.cuda.is_available() else "cpu" + if topology_groups is None: + topology_groups = { + default_device: Topology( + local_world_size=get_local_size(), + world_size=dist.get_world_size(), + compute_device=default_device, + ) + } + self._topology_groups: Dict[str, Topology] = topology_groups + self._batch_size: int = batch_size if batch_size else BATCH_SIZE + self._constraints = constraints + # pyre-ignore + self._enumerators: Dict[str, Enumerator] = ( + enumerators + if enumerators + else { + group: EmbeddingEnumerator( + topology=self._topology_groups[group], + batch_size=self._batch_size, + constraints=constraints, + use_exact_enumerate_order=True, + ) + for group in self._topology_groups.keys() + } + ) + # pyre-ignore + self._storage_reservations: Dict[str, StorageReservation] = ( + storage_reservations + if storage_reservations + else { + group: HeuristicalStorageReservation(percentage=0.15) + for group in self._topology_groups.keys() + } + ) + + # pyre-ignore + self._partitioners: Dict[str, Partitioner] = ( + partitioners + if partitioners + else { + group: MemoryBalancedPartitioner() + for group in self._topology_groups.keys() + } + ) + + if proposers: + # pyre-ignore + self._proposers: Dict[str, List[Proposer]] = proposers + else: + # pyre-ignore + self._proposers = { + group: [ + GridSearchProposer(), + GreedyProposer(), + GreedyProposer(use_depth=False), + UniformProposer(), + ] + for group in self._topology_groups.keys() + } + + # pyre-ignore + self._perf_models: Dict[str, PerfModel] = ( + performance_models + if performance_models + else { + group: NoopPerfModel(topology=self._topology_groups[group]) + for group in self._topology_groups + } + ) + + self._stats: Dict[str, List[Stats]] = {} + + if stats is not None: + # pyre-ignore [8] + self._stats = stats + else: + # pyre-ignore [8] + self._stats = { + group: [EmbeddingStats()] for group in self._topology_groups.keys() + } + + self._debug = debug + self._num_proposals: int = 0 + self._num_plans: int = 0 + self._best_plan: Optional[List[ShardingOption]] = None + self._callbacks: List[ + Callable[[List[ShardingOption]], List[ShardingOption]] + ] = ([] if callbacks is None else callbacks) + + def collective_plan( + self, + module: nn.Module, + sharders: Optional[List[ModuleSharder[nn.Module]]] = None, + pg: Optional[dist.ProcessGroup] = dist.GroupMember.WORLD, + ) -> ShardingPlan: + """ + Call self.plan(...) on rank 0 and broadcast + """ + if pg is None: + assert dist.is_initialized(), ( + "The default process group is not yet initialized. " + "Please call torch.distributed.init_process_group() first before invoking this. " + "If you are not within a distributed environment, use the single rank version plan() instead." 
+ ) + pg = none_throws(dist.GroupMember.WORLD) + assert len(self._topology_groups) == 1, "Only single topology is supported" + + if sharders is None: + sharders = get_default_sharders() + return invoke_on_rank_and_broadcast_result( + pg, + 0, + self.plan, + module, + sharders, + ) + + def plan( + self, + module: nn.Module, + sharders: List[ModuleSharder[nn.Module]], + ) -> ShardingPlan: + best_plans: List[ShardingPlan] = [] + for group, topology in self._topology_groups.items(): + self._num_proposals = 0 + self._num_plans = 0 + start_time = perf_counter() + best_plan = None + lowest_storage = Storage(MAX_SIZE, MAX_SIZE) + last_planner_error: Optional[PlannerError] = None + last_proposal: List[ShardingOption] = [] + best_perf_rating = MAX_SIZE + + storage_constraint: Topology = self._storage_reservations[group].reserve( + topology=topology, + batch_size=self._batch_size, + module=module, + sharders=sharders, + constraints=self._constraints, + ) + + search_space = self._enumerators[group].enumerate( + module=module, + sharders=sharders, + ) + + # filter by device group + search_space = [ + s_o + for s_o in search_space + # pyre-ignore [16] + if self._constraints[s_o.name].device_group == group + ] + + if not search_space: + # No shardable parameters + continue + + proposal_cache: Dict[ + Tuple[int, ...], + Tuple[bool, Optional[List[ShardingOption]], Optional[float]], + ] = {} + + for proposer in self._proposers[group]: + proposer.load( + search_space=search_space, enumerator=self._enumerators[group] ) + + for proposer in self._proposers[group]: + proposal = proposer.propose() + + while proposal: + proposal_key = tuple(sorted(map(hash, proposal))) + if proposal_key in proposal_cache: + partitionable, plan, perf_rating = proposal_cache[proposal_key] + proposer.feedback( + partitionable=partitionable, + plan=plan, + perf_rating=perf_rating, + storage_constraint=storage_constraint, + ) + proposal = proposer.propose() + continue + + self._num_proposals += 1 + try: + # plan is just proposal where shard.rank is populated + plan = self._partitioners[group].partition( + proposal=proposal, + storage_constraint=storage_constraint, + ) + self._num_plans += 1 + perf_rating = self._perf_models[group].rate(plan=plan) + if perf_rating < best_perf_rating: + best_perf_rating = perf_rating + best_plan = copy.deepcopy(plan) + proposal_cache[proposal_key] = (True, plan, perf_rating) + proposer.feedback( + partitionable=True, + plan=plan, + perf_rating=perf_rating, + storage_constraint=storage_constraint, + ) + except PlannerError as planner_error: + last_planner_error = planner_error + # shallow copy of the proposal + last_proposal: List[ShardingOption] = copy.copy(proposal) + current_storage = cast( + Storage, + reduce( + lambda x, y: x + y, + [ + shard.storage + for option in proposal + for shard in option.shards + ], + ), + ) + if current_storage < lowest_storage: + lowest_storage = current_storage + proposal_cache[proposal_key] = (False, proposal, None) + proposer.feedback( + partitionable=False, + plan=proposal, + storage_constraint=storage_constraint, + ) + + # clear shard.rank for each sharding_option + reset_shard_rank(proposal) + proposal = proposer.propose() + + if best_plan: + for callback in self._callbacks: + best_plan = callback(best_plan) + + self._best_plan = best_plan + sharding_plan = to_sharding_plan( + best_plan, self._topology_groups[group] + ) + best_plans.append(sharding_plan) + + end_time = perf_counter() + for stats in self._stats[group]: + stats.log( + sharding_plan=sharding_plan, + 
topology=self._topology_groups[group], + batch_size=self._batch_size, + storage_reservation=self._storage_reservations[group], + num_proposals=self._num_proposals, + num_plans=self._num_plans, + run_time=end_time - start_time, + best_plan=best_plan, + constraints=self._constraints, + sharders=sharders, + debug=self._debug, + ) + else: + global_storage_capacity = reduce( + lambda x, y: x + y, + [device.storage for device in self._topology_groups[group].devices], + ) + global_storage_constraints = reduce( + lambda x, y: x + y, + [device.storage for device in storage_constraint.devices], + ) + storage_reservation_solution = ( + ( + # pyre-ignore [16] + f"\n\t Storage reservation percentage: {self._storage_reservations[group]._percentage}, " + f"\n\t Per rank reservation for dense storage: {storage_repr_in_gb(self._storage_reservations[group]._dense_storage)}, " + f"\n\t Per rank reservation for kjt storage: {storage_repr_in_gb(self._storage_reservations[group]._kjt_storage)}, " # pyre-ignore[16] + ) + if isinstance( + self._storage_reservations[group], HeuristicalStorageReservation + ) + else f"\n\t Storage reservation percentage: {self._storage_reservations[group]._percentage}, " + ) + no_plan_solution = ( + f"Planner evaluated {self._num_proposals} proposals for device group {group}." + "\nPossible solutions:" + f"\n 1) Increase the number of devices ({self._topology_groups[group].world_size})" + f"\n 2) Reduce the model size (" + f"\n\t Global storage: {round(bytes_to_gb(global_storage_capacity.hbm), 3)} GB, " + f"\n\t Per rank hardware memory: {storage_repr_in_gb(self._topology_groups[group].devices[0].storage)}, " + f"{storage_reservation_solution}" + f"\n\t Global storage available for model parallel: {storage_repr_in_gb(global_storage_constraints)}, " + f"\n\t Global storage requirement for model parallel: {storage_repr_in_gb(lowest_storage)})" + f"\n 3) Reduce local batch size ({self._batch_size})" + "\n 4) Remove planner constraints that might be reducing search space or available storage\n" + ) + last_planner_error_info = ( + f"Last planner error: \n\t{last_planner_error}\n" + ) + + # printout stats for no plan situation + end_time = perf_counter() + sharding_plan = ShardingPlan(plan={}) + # force all shards to have rank= -1 + for sharding_option in last_proposal: + for shard in sharding_option.shards: + shard.rank = -1 + + for stats in self._stats[group]: + stats.log( + sharding_plan=sharding_plan, + topology=self._topology_groups[group], + batch_size=self._batch_size, + storage_reservation=self._storage_reservations[group], + num_proposals=self._num_proposals, + num_plans=self._num_plans, + run_time=end_time - start_time, + best_plan=last_proposal, + constraints=self._constraints, + sharders=sharders, + debug=self._debug, + ) + + if not lowest_storage.fits_in(global_storage_constraints): + raise PlannerError( + error_type=PlannerErrorType.INSUFFICIENT_STORAGE, + message=f"Unable to find a plan for this model in device_group {group} because of insufficient storage. \n" + + no_plan_solution + + last_planner_error_info, + ) + else: + raise PlannerError( + error_type=PlannerErrorType.STRICT_CONSTRAINTS, + message=f"Unable to find a plan for this model in device_group {group} because of the strict constraints. 
\n" + + no_plan_solution + + last_planner_error_info, + ) + + return _merge_plans(best_plans) diff --git a/torchrec/distributed/planner/proposers.py b/torchrec/distributed/planner/proposers.py index 177f16b87..69166b871 100644 --- a/torchrec/distributed/planner/proposers.py +++ b/torchrec/distributed/planner/proposers.py @@ -5,13 +5,28 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + +import copy import itertools import logging +from collections import OrderedDict from decimal import Decimal -from typing import cast, Dict, List, Optional, Set, Tuple +from typing import Callable, cast, Dict, List, Optional, Set, Tuple, TypeVar, Union + +import torch + +from torchrec.distributed.embedding_types import EmbeddingComputeKernel -from torchrec.distributed.planner.types import Proposer, ShardingOption -from torchrec.distributed.planner.utils import prod +from torchrec.distributed.planner.types import ( + Enumerator, + Perf, + Proposer, + ShardingOption, + Topology, +) +from torchrec.distributed.planner.utils import bytes_to_gb, LuusJaakolaSearch, prod +from torchrec.distributed.types import CacheAlgorithm logger: logging.Logger = logging.getLogger(__name__) @@ -28,7 +43,8 @@ class GreedyProposer(Proposer): Args: use_depth (bool): When enabled, sharding_options of a fqn are sorted based on - `max(shard.perf)`, otherwise sharding_options are sorted by `sum(shard.perf)`. + `max(shard.perf.total)`, otherwise sharding_options are sorted by + `sum(shard.perf.total)`. threshold (Optional[int]): Threshold for early stopping. When specified, the proposer stops proposing when the proposals have consecutive worse perf_rating than best_perf_rating. @@ -42,7 +58,11 @@ def __init__(self, use_depth: bool = True, threshold: Optional[int] = None) -> N self._best_perf_rating: float = float("inf") self._num_inferior_perf: int = 0 - def load(self, search_space: List[ShardingOption]) -> None: + def load( + self, + search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, + ) -> None: self._reset() for sharding_option in search_space: fqn = sharding_option.fqn @@ -77,6 +97,7 @@ def feedback( partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, ) -> None: # When threshold is passed, observe the perf_rating trend. If the perf_rating # of the newly proposed plans have worse perf_rating, stop proposing. 
@@ -125,12 +146,16 @@ def __init__(self, use_depth: bool = True) -> None: self._grouped_sharding_options: List[List[ShardingOption]] = [] self._proposal_index: int = 0 - def load(self, search_space: List[ShardingOption]) -> None: + def load( + self, + search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, + ) -> None: self._reset() all_fqns = set() - sharding_options_by_type_and_fqn: Dict[ - str, Dict[str, List[ShardingOption]] - ] = {} + sharding_options_by_type_and_fqn: Dict[str, Dict[str, List[ShardingOption]]] = ( + {} + ) for sharding_option in search_space: sharding_type = sharding_option.sharding_type @@ -174,6 +199,7 @@ def feedback( partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, ) -> None: # static strategy, ignore feedback and just provide next proposal self._proposal_index += 1 @@ -186,7 +212,11 @@ def __init__(self, max_proposals: int = MAX_PROPOSALS) -> None: self._proposal_index: int = 0 self._proposals: List[List[int]] = [] - def load(self, search_space: List[ShardingOption]) -> None: + def load( + self, + search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, + ) -> None: self._reset() for sharding_option in search_space: fqn = sharding_option.fqn @@ -219,6 +249,8 @@ def load(self, search_space: List[ShardingOption]) -> None: range(len(sharding_options)) for sharding_options in self._sharding_options_by_fqn.values() ] + # pyre-fixme[8]: Attribute has type `List[List[int]]`; used as + # `List[Tuple[int]]`. self._proposals = list(itertools.product(*sharding_options_by_fqn_indices)) def _reset(self) -> None: @@ -243,18 +275,679 @@ def feedback( partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, ) -> None: # static strategy, ignore feedback and just provide next proposal self._proposal_index += 1 +def _bytes_to_float_bin(num_bytes: Union[float, int], bin_size: float) -> float: + return float(num_bytes) / bin_size + + +class DynamicProgrammingProposer(Proposer): + r"""Proposes sharding plans in dynamic programming fashion. + + The problem of the Embedding Sharding Plan can be framed as follows: Given + :math:`M` tables and their corresponding :math:`N` Sharding Options, we need to + select one sharding option for each table such that the total performance is + minimized, while keeping the overall HBM constraint :math:`K` in check. This can + be abstracted into the following mathematical formulation: + + Given a matrix :math:`A` of dimensions :math:`(M, N)` and another matrix :math:`B` + of the same dimensions, let the elements of matrix :math:`A` be denoted as + :math:`a_{i,j}` and the elements of matrix :math:`B` as :math:`b_{i,j}`. We aim + to find a set of column indices :math:`\{ j_0, j_1, \ldots, j_{M-1} \}` such that + the following conditions are satisfied: + + 1. :math:`\sum_{i=0}^{M-1} a_{i,j_i} \leq K`, where :math:`K` is a float. + 2. :math:`\sum_{i=0}^{M-1} b_{i,j_i}` is minimized. + + This problem can be tackled using dynamic programming. First, discretize :math:`K` + into :math:`K_i`, and denote the discretization function as :math:`f`. + + Define the state :math:`dp[i][f(k)]` to represent the minimum value of :math:`B` + when considering the first :math:`i` rows and the total sum of :math:`A` is equal to + the discretized value :math:`k`. + + The state transition can then be represented as: + + .. 
math:: + dp[i][f(k)] = \min_{j=0}^{N-1} \left( dp[i-1][f(k - A[i][j])] + B[i][j] \right) + + Since :math:`K` is the sum allocated across all HBM, simply satisfying that the + total HBM in the plan equals :math:`K` does not guarantee that the allocation will + fit on all cards. Therefore, it is essential to maintain all the states of the last + layer of :math:`dp`. This allows us to propose different plans under varying total + HBM constraints. + + Args: + hbm_bins_per_device (int): hdm bins for dynamic programming precision. + """ + + def __init__(self, hbm_bins_per_device: int = 100) -> None: + self._inited: bool = False + self._hbm_bins_per_device: int = max(hbm_bins_per_device, 1) + self._sharding_options_by_fqn: OrderedDict[str, List[ShardingOption]] = ( + OrderedDict() + ) + # list of proposals with different total_hbm, a proposal is a list of indices of sharding_options + self._proposal_list: List[List[int]] = [] + self._current_proposal: int = -1 + + def load( + self, + search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, + ) -> None: + """Load search space.""" + self._reset() + # order the sharding_option by total_storage.hbm from low to high + for sharding_option in sorted(search_space, key=lambda x: x.total_storage.hbm): + fqn = sharding_option.fqn + if fqn not in self._sharding_options_by_fqn: + self._sharding_options_by_fqn[fqn] = [] + self._sharding_options_by_fqn[fqn].append(sharding_option) + + def _reset(self) -> None: + self._sharding_options_by_fqn = OrderedDict() + self._proposal_list = [] + self._current_proposal = -1 + + def propose(self) -> Optional[List[ShardingOption]]: + """Propose a sharding plan.""" + if not self._inited: + return [ + sharding_options[0] + for sharding_options in self._sharding_options_by_fqn.values() + ] + elif self._current_proposal >= 0: + proposal_index = self._proposal_list[self._current_proposal] + return [ + self._sharding_options_by_fqn[fqn][index] + for fqn, index in zip( + self._sharding_options_by_fqn.keys(), proposal_index + ) + ] + else: + return None + + def feedback( + self, + partitionable: bool, + plan: Optional[List[ShardingOption]] = None, + perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, + ) -> None: + """Feedback last proposed plan.""" + if not self._inited: + self._inited = True + table_count = len(self._sharding_options_by_fqn) + option_count = max([len(x) for x in self._sharding_options_by_fqn.values()]) + + assert storage_constraint is not None + # are we assuming the table will be evenly sharded on all devices? 
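+            # Discretize the global HBM budget: bin_size = hbm_total /
+            # (hbm_bins_per_device * num_devices). With made-up numbers,
+            # 8 devices of 80 GB each and the default 100 bins per device
+            # give bin_size = 640 GB / 800 = 0.8 GB per bin.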
+ hbm_total = sum([x.storage.hbm for x in storage_constraint.devices]) + bin_count = self._hbm_bins_per_device * len(storage_constraint.devices) + bin_size = float(hbm_total) / bin_count + + dp = [ + [(float("inf"), float("inf"))] * bin_count for _ in range(table_count) + ] # [table_id][hbm_bin][perf, hbm] + + backtrack = [ + [(-1, -1)] * bin_count for _ in range(table_count) + ] # [table_id][hbm_bin][opt_id, prev_hbm_bin] + + hbm_by_fqn = [ + [float("inf") for _ in range(option_count)] for _ in range(table_count) + ] # memory constraint lookup table: [table_id][sharding_option_id] + perf_by_fqn = [ + [float("inf") for _ in range(option_count)] for _ in range(table_count) + ] # performance metrics lookup table: [table_id][sharding_option_id] + + # populate hbm and perf for each sharding option and table: A[table_id][sharding_option_id] + for table_id, sharding_options in enumerate( + self._sharding_options_by_fqn.values() + ): + for opt_id, sharding_option in enumerate(sharding_options): + hbm_by_fqn[table_id][opt_id] = _bytes_to_float_bin( + sharding_option.total_storage.hbm, bin_size + ) + perf_by_fqn[table_id][opt_id] = sharding_option.total_perf + + table_0 = 0 + for opt_j in range(option_count): + if hbm_by_fqn[0][opt_j] < bin_count: + hbm_i = int(hbm_by_fqn[0][opt_j]) + # options are ordered in increasing order of hbm, we only want to consider + # a sharding option that has higher hbm and better perf (the smaller the better) + if dp[table_0][hbm_i][0] > perf_by_fqn[table_0][opt_j]: + dp[table_0][hbm_i] = ( + perf_by_fqn[table_0][opt_j], + hbm_by_fqn[table_0][opt_j], + ) + backtrack[table_0][hbm_i] = (opt_j, -1) + + # dp: table_count x option_count x bin_count + for table_i in range(1, table_count): + for opt_j in range(option_count): + for hbm in range(bin_count): + prev_perf, perv_hbm = dp[table_i - 1][hbm] + if prev_perf < float("inf"): + new_hbm = perv_hbm + hbm_by_fqn[table_i][opt_j] + if new_hbm < bin_count: + new_hbm_i = int(new_hbm) + new_perf = prev_perf + perf_by_fqn[table_i][opt_j] + if dp[table_i][new_hbm_i][0] > new_perf: + dp[table_i][new_hbm_i] = (new_perf, new_hbm) + backtrack[table_i][new_hbm_i] = (opt_j, hbm) + self._proposal_list = [] + # fill in all the proposals, starting from highest hbm to lowest hbm + for c in range(bin_count - 1, -1, -1): + cur_opt_idx, cur_hbm_idx = backtrack[table_count - 1][c] + if cur_opt_idx >= 0: + proposal_indices = [-1] * table_count + proposal_indices[table_count - 1] = cur_opt_idx + for i in range(table_count - 2, -1, -1): + proposal_indices[i], cur_hbm_idx = backtrack[i][cur_hbm_idx] + self._proposal_list.append(proposal_indices) + if len(self._proposal_list) > 0: + self._current_proposal = 0 + else: + self._current_proposal += 1 + if self._current_proposal >= len(self._proposal_list): + self._current_proposal = -1 + + +_T = TypeVar("_T") + + +def _none_throws(x: Optional[_T]) -> _T: + if x is None: + raise AssertionError("unexpected None") + return x + + +class EmbeddingOffloadScaleupProposer(Proposer): + def __init__(self, use_depth: bool = True) -> None: + self.use_depth: bool = use_depth + self.enumerator: Optional[Enumerator] = None + self.starting_proposal: List[ShardingOption] = [] + self.proposal: Optional[List[ShardingOption]] = None + self.search: Optional[LuusJaakolaSearch] = None + self.best_perf_rating: float = 1e99 + + def _build_proposal_from_sharding_options( + self, + sharding_options_by_fqn: Dict[str, List[ShardingOption]], + ) -> List[ShardingOption]: + """ + Given a list of sharding options for each embedding 
table, selects which to include in the proposal. + """ + # TODO(T206831945): Currently only uses 1 sharding option for proposal. + # There is room for potentially exploring multiple options if done + # carefully, e.g. traversing like GreedyProposer. + proposal: List[ShardingOption] = [] + for table_sharding_options in sharding_options_by_fqn.values(): + if len(table_sharding_options) > 1: + logger.warning( + f"EmbeddingOffloadScaleupProposer - ignored {len(table_sharding_options) - 1} sharding options for table {table_sharding_options[0].name} in proposal" + ) + + selected_option = next( + ( + sharding_option + for sharding_option in table_sharding_options + if sharding_option.compute_kernel + == EmbeddingComputeKernel.FUSED_UVM_CACHING.value + ), + table_sharding_options[0], + ) + proposal.append(selected_option) + + # Miss-ratio curves used for stats are modeled for an LRU cache, LFU cache would not work well with ScaleupProposer. + if ( + selected_option.cache_params is not None + and selected_option.cache_params.algorithm is not None + and selected_option.cache_params.algorithm != CacheAlgorithm.LRU + ): + logger.error( + f"EmbeddingOffloadScaleupProposer - proposer only supports LRU cache algorithm, but {selected_option.cache_params.algorithm} is used for {selected_option}" + ) + + return proposal + + def load( + self, + search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, + ) -> None: + self.enumerator = enumerator + sharding_options_by_fqn: Dict[str, List[ShardingOption]] = {} + for sharding_option in search_space: + sharding_options_by_fqn.setdefault(sharding_option.fqn, []).append( + sharding_option + ) + for sharding_options in sharding_options_by_fqn.values(): + sharding_options.sort( + key=lambda x: _sharding_option_score(x, self.use_depth) + ) + + proposal: List[ShardingOption] = self._build_proposal_from_sharding_options( + sharding_options_by_fqn + ) + + # deepcopy so it won't affect other proposers + self.starting_proposal = copy.deepcopy(proposal) + self.promote_high_prefetch_overheaad_table_to_hbm( + self.enumerator, self.starting_proposal + ) + self.proposal = copy.deepcopy(self.starting_proposal) + + @staticmethod + def get_hbm_ceiling( + starting_proposal: List[ShardingOption], enumerator: Enumerator + ) -> int: + """returns total amount of memory scaleup could use.""" + proposal = copy.deepcopy(starting_proposal) + cache_tables = EmbeddingOffloadScaleupProposer.get_scalable_sharding_options( + proposal + ) + for sharding_option in cache_tables: + if ( + sharding_option.compute_kernel + == EmbeddingComputeKernel.FUSED_UVM_CACHING.value + ): + assert sharding_option.cache_params # appease pyre + sharding_option.cache_params.load_factor = None + sharding_option.compute_kernel = EmbeddingComputeKernel.FUSED.value + enumerator.populate_estimates(cache_tables) + return sum(sharding_option.total_storage.hbm for sharding_option in proposal) + + @staticmethod + def promote_high_prefetch_overheaad_table_to_hbm( + enumerator: Optional[Enumerator], proposal: List[ShardingOption] + ) -> None: + """ + Prefetch overhead is related to IO. When it's larger than saved memory from + embedding offloading, we'd undo offloading and promote to HBM for better + memory efficiency. + + This function will end up updating proposal. 
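+
+        Promotion clears `cache_params.load_factor` and switches the table's
+        compute kernel to `EmbeddingComputeKernel.FUSED`, based on comparing
+        the offloaded table's HBM estimate against a what-if estimate of the
+        same table placed entirely in HBM.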
+ """ + if not enumerator: + return + what_if_hbm_proposal = copy.deepcopy(proposal) + what_if_hbm_cached_tables = [ + sharding_option + for sharding_option in what_if_hbm_proposal + if sharding_option.compute_kernel + == EmbeddingComputeKernel.FUSED_UVM_CACHING.value + ] + original_cached_tables = [ + sharding_option + for sharding_option in proposal + if sharding_option.compute_kernel + == EmbeddingComputeKernel.FUSED_UVM_CACHING.value + ] + + # Modify all cached tables in what_if_proposal to be HBM only + for sharding_option in what_if_hbm_cached_tables: + assert sharding_option.cache_params # appease pyre + sharding_option.cache_params.load_factor = None + sharding_option.compute_kernel = EmbeddingComputeKernel.FUSED.value + + # appease pyre + assert enumerator + enumerator.populate_estimates(what_if_hbm_cached_tables) + + # Now what_if_hbm_proposal contain estimated storage for all HBM case. If + # it's even smaller than offloaded case, we promote it to HBM + promoted_count = 0 + saved_hbm = 0 + for so, original_so in zip(what_if_hbm_cached_tables, original_cached_tables): + if so.total_storage.hbm < original_so.total_storage.hbm: + promoted_count += 1 + saved_hbm += original_so.total_storage.hbm - so.total_storage.hbm + assert original_so.cache_params # appease pyre + original_so.cache_params.load_factor = None + original_so.compute_kernel = EmbeddingComputeKernel.FUSED.value + + if promoted_count > 0: + logger.info( + f"EmbeddingOffloadScaleupProposer - promoted {promoted_count} tables to HBM, because their IO size is larger than the table size itself, saving {saved_hbm // 1024 // 1024}MiB HBM" + ) + + # In the end, update the storage cost for new proposal + + # appease pyre + assert enumerator + enumerator.populate_estimates(original_cached_tables) + + def propose(self) -> Optional[List[ShardingOption]]: + return self.proposal + + def feedback( + self, + partitionable: bool, + plan: Optional[List[ShardingOption]] = None, + perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, + ) -> None: + if not self.enumerator or plan is None: + self.proposal = None + return + + hbm_used_previously = sum( + sharding_option.total_storage.hbm for sharding_option in plan + ) + + if self.search is None: + if not partitionable or storage_constraint is None: + self.proposal = None + return + + hbm_available = EmbeddingOffloadScaleupProposer.get_budget( + plan, storage_constraint + ) + # max scale up + peak_budget_need = ( + EmbeddingOffloadScaleupProposer.get_hbm_ceiling( + plan, _none_throws(self.enumerator) + ) + - hbm_used_previously + ) + search_budget = min(hbm_available, peak_budget_need) + + logger.info( + f"EmbeddingOffloadScaleupProposer - unscaled plan={round(bytes_to_gb(hbm_used_previously),2)} GB, cache scale up budget={round(bytes_to_gb(hbm_available), 2)} GB, peak scale up budget need={round(bytes_to_gb(peak_budget_need),2)} GB, exploring plans of size [{round(bytes_to_gb(hbm_used_previously), 2)}, {round(bytes_to_gb(hbm_used_previously + search_budget), 2)}] GB" + ) + self.search = LuusJaakolaSearch( + 0, search_budget, max_iterations=16, left_cost=perf_rating + ) + + best = False + if perf_rating is not None and perf_rating < self.best_perf_rating: + self.best_perf_rating = perf_rating + best = True + + logger.info( + f"EmbeddingOffloadScaleupProposer - proposed size={bytes_to_gb(hbm_used_previously):.2f} GB, score={perf_rating}{' BEST' if best else ''}" + ) + + if not partitionable: + # Focus our search on smaller plans by assuming plans larger than 
this + # proposal will also fail to partition. + starting_size = sum( + sharding_option.total_storage.hbm + for sharding_option in self.starting_proposal + ) + new_budget = hbm_used_previously - starting_size + self.search.shrink_right(new_budget) # pyre-ignore + + assert self.search is not None # keep pyre happy + budget = self.search.next(perf_rating or 1e99) + if budget is not None: + budget = int(budget) + self.proposal = EmbeddingOffloadScaleupProposer.next_plan( + self.starting_proposal, budget, self.enumerator + ) + + @staticmethod + def get_budget(proposal: List[ShardingOption], storage_constraint: Topology) -> int: + """returns additional HBM budget available for GPU caches.""" + available_hbm = sum(device.storage.hbm for device in storage_constraint.devices) + used_hbm = sum( + sharding_option.total_storage.hbm for sharding_option in proposal + ) + return available_hbm - used_hbm + + @staticmethod + def get_scalable_sharding_options( + proposal: List[ShardingOption], + ) -> List[ShardingOption]: + """Return the subset of tables that we can scale.""" + + def none_to_zero(x: Optional[float]) -> float: + return x if x is not None else 0.0 + + return [ + sharding_option + for sharding_option in proposal + if sharding_option.compute_kernel + == EmbeddingComputeKernel.FUSED_UVM_CACHING.value + and none_to_zero( + EmbeddingOffloadScaleupProposer.get_cacheability(sharding_option) + ) + * none_to_zero( + EmbeddingOffloadScaleupProposer.get_expected_lookups(sharding_option) + ) + * none_to_zero(sharding_option.cache_load_factor) + > 0 + ] + + # Given an available budget of additional memory, and a provisional sharding plan, + # attempt to use the budget wisely to scale up caches that would most benefit from it. + @staticmethod + def next_plan( + starting_proposal: List[ShardingOption], + budget: Optional[int], + enumerator: Optional[Enumerator], + ) -> Optional[List[ShardingOption]]: + if budget is None or enumerator is None: + return None + + def none_to_zero(x: Optional[float]) -> float: + return x if x is not None else 0.0 + + proposal = copy.deepcopy(starting_proposal) + # This is the subset of tables that we can scale + cache_tables = EmbeddingOffloadScaleupProposer.get_scalable_sharding_options( + proposal + ) + # Nothing to scale + if len(cache_tables) == 0: + return None + + size_model, fused_hbm_ceiling = ( + EmbeddingOffloadScaleupProposer.build_affine_storage_model( + cache_tables, enumerator + ) + ) + clfs = torch.tensor( + [sharding_option.cache_load_factor for sharding_option in cache_tables] + ) + # cooked_cacheability is cacheability scaled by the expected number of cache + # lookups. 
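+        # The product below is used as the allocation priority: tables whose
+        # lookups are both frequent (expected_lookups) and cache-friendly
+        # (cacheability) get a proportionally larger share of the extra HBM.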
+ + cooked_cacheability = torch.tensor( + [ + none_to_zero( + EmbeddingOffloadScaleupProposer.get_cacheability(sharding_option) + ) + * none_to_zero( + EmbeddingOffloadScaleupProposer.get_expected_lookups( + sharding_option + ) + ) + for sharding_option in cache_tables + ] + ) + new_clfs = EmbeddingOffloadScaleupProposer.allocate_budget( + model=size_model, + fused_hbm_ceiling=fused_hbm_ceiling, + clfs=clfs, + budget=budget, + allocation_priority=cooked_cacheability, + ) + + num_promoted = 0 + # apply new_clfs, promoting tables that made it to 1.0 + for sharding_option, clf in zip(cache_tables, new_clfs): + clf = clf.item() # tensor scalar -> scalar + assert sharding_option.cache_params # appease pyre + sharding_option.cache_params.load_factor = clf + if clf > 0.9999: # tolerate float roundoff + assert sharding_option.cache_params # appease pyre + sharding_option.cache_params.load_factor = None + sharding_option.compute_kernel = EmbeddingComputeKernel.FUSED.value + num_promoted += 1 + if num_promoted > 0: + logger.info( + f"EmbeddingOffloadScaleupProposer - Promoted {num_promoted} tables to HBM because cache size is similar to table size." + ) + # recalculate cost estimates of modified tables + enumerator.populate_estimates(cache_tables) + return proposal + + @staticmethod + def get_cacheability(sharding_option: ShardingOption) -> Optional[float]: + # helper to appease pyre type checker, as cache_params is Optional it maybe None + if ( + sharding_option.cache_params is None + or sharding_option.cache_params.stats is None + ): + return None + return sharding_option.cache_params.stats.cacheability + + @staticmethod + def get_expected_lookups(sharding_option: ShardingOption) -> Optional[float]: + # helper to appease pyre type checker, as cache_params is Optional it maybe None + if ( + sharding_option.cache_params is None + or sharding_option.cache_params.stats is None + ): + return None + return sharding_option.cache_params.stats.expected_lookups + + # The relationship between clf and shard memory usage is non-linear due to non-clf + # overheads like optimization stats and input/output storage. We model it as an + # affine relationship: bytes = clf * A + B where B is fixed overhead independent of + # CLF (e.g. input / output IO sizes and A is per cache-row overhead. 
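+    # For example (made-up numbers): if a table's HBM estimate is 6 GiB at
+    # clf=0.1 and 22 GiB at clf=0.9, then A = (22 - 6) / (0.9 - 0.1) = 20 GiB
+    # per unit of CLF and B = 6 - 20 * 0.1 = 4 GiB of CLF-independent overhead;
+    # build_affine_storage_model below recovers (A, B) exactly this way.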
+ @staticmethod + def build_affine_storage_model( + uvm_caching_sharding_options: List[ShardingOption], enumerator: Enumerator + ) -> Tuple[torch.Tensor, torch.Tensor]: + plan: List[ShardingOption] = copy.deepcopy(uvm_caching_sharding_options) + + def set_clf(sharding_option: ShardingOption, clf: float) -> None: + assert sharding_option.cache_params # appease pyre + sharding_option.cache_params.load_factor = clf + + def set_fused(sharding_option: ShardingOption) -> None: + assert sharding_option.cache_params # appease pyre + sharding_option.cache_params.load_factor = None + sharding_option.compute_kernel = EmbeddingComputeKernel.FUSED.value + + def compute_hbm_sizes(f: Callable[[ShardingOption], None]) -> torch.Tensor: + for sharding_option in plan: + f(sharding_option) + enumerator.populate_estimates(plan) + return torch.tensor( + [sharding_option.total_storage.hbm for sharding_option in plan] + ) + + low_clf, high_clf = 0.1, 0.9 + low_hbms = compute_hbm_sizes(lambda so: set_clf(so, low_clf)) + high_hbms = compute_hbm_sizes(lambda so: set_clf(so, high_clf)) + fused_hbms = compute_hbm_sizes(set_fused) + + A = (high_hbms - low_hbms) / (high_clf - low_clf) + B = low_hbms - A * low_clf + caching_model = torch.stack((A, B), dim=1) # Nx2 (a,b) + return caching_model, fused_hbms + + @staticmethod + def clf_to_bytes( + model: torch.Tensor, clfs: Union[float, torch.Tensor] + ) -> torch.Tensor: + # evaluate affine model AX + B + return (model[:, 0] * clfs + model[:, 1]).to(torch.float64) + + # Given a model of an affine system, an existing configuration (clfs), available + # budget, and an allocation policy, return new configuration that best uses the + # available budget. We only add additional budget, we assume the existing + # configuration is specifying a floor or minimum size. + @staticmethod + def allocate_budget( + model: torch.Tensor, + fused_hbm_ceiling: torch.Tensor, + clfs: torch.Tensor, + budget: int, + allocation_priority: torch.Tensor, + ) -> torch.Tensor: + # min size is size of table at 0 CLF + min_size_bytes = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 0) + max_size_bytes = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 1) + table_size_bytes = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, clfs) + cache_size_bytes = table_size_bytes - min_size_bytes + max_cache_size_bytes = max_size_bytes - min_size_bytes + + # We have budget bytes to share across the tables in. We want to increase the + # cache_size_bytes of each table in proportion to their allocation priority + # fraction. If we raise the cache_size_bytes to beyond max_cache_size_bytes, + # this is equivalent to reaching CLF=1.0, so we clip the memory to 1.0, and + # reassign the released budget in a subsequent pass. + num_pass = 0 + while budget > 1 and num_pass < 128: + num_pass += 1 + # mask is False for tables at >= max_size, and True otherwise. This allows + # us to remove tables that have already reached full size in one round from + # being dealt more budget in subsequent rounds. + mask = (min_size_bytes + cache_size_bytes) < max_size_bytes + if mask.sum() == 0: + break + + logger.debug( + f"[allocate_budget] pass={num_pass}, budget={budget}, #cache_tables={mask.sum()}" + ) + + # switch to 64bit float to avoid rounding errors, as table cache sizes can + # easily be > 2^24. 
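+            # Each pass deals the remaining budget to still-growable tables in
+            # proportion to allocation_priority, clips any table at its full
+            # (CLF=1.0) cache size, and carries the unused portion into the
+            # next pass. Example with made-up numbers: a budget of 12 GiB over
+            # priorities (1, 2, 3) proposes increases of 2, 4 and 6 GiB; if
+            # the third table can only grow by 3 GiB, the remaining 3 GiB is
+            # redistributed on the next pass.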
+ masked_priority = (mask * allocation_priority).to(torch.float64) + increase_ratio = masked_priority / torch.sum(masked_priority) + proposed_increase_bytes = budget * increase_ratio + new_cache_size_bytes = torch.minimum( + cache_size_bytes + proposed_increase_bytes, max_cache_size_bytes + ) + actual_increase_bytes = new_cache_size_bytes - cache_size_bytes + + budget -= torch.sum(actual_increase_bytes).item() # pyre-ignore[58] + cache_size_bytes = new_cache_size_bytes + # TODO: consider trade off of using remaining budget to push >0.95 tables + # to HBM vs spending that budget on improving hit rate on other tables in + # next pass. + + # Is any table over the size we'd get if we promoted to HBM? (promotion can + # be smaller if input size is large when using prefetch). If so, mark for + # promotion and reclaim budget to use on remaining tables. + promotes = mask & (min_size_bytes + cache_size_bytes > fused_hbm_ceiling) + if promotes.sum() > 0: + budget_reclaimed = torch.sum( + ((min_size_bytes + cache_size_bytes) - fused_hbm_ceiling)[promotes] + ).item() + logger.debug( + f"[allocate_budget] {promotes.sum()} tables exceeded ceiling, promoting to save {budget_reclaimed} bytes" + ) + budget += budget_reclaimed # pyre-ignore[58] + # force these tables to 1.0 to ensure promotion + cache_size_bytes[promotes] = max_cache_size_bytes[promotes] + + # cache_size_bytes are the new cache sizes we want to use. We convert them back + # to clfs by dividing by max_cache_size_bytes, which has isolated the clf + # portion of the table size from the fixed overheads. + # convert 64bit values back to original clf precision + return (cache_size_bytes / max_cache_size_bytes).to(clfs.dtype) + + def _sharding_option_score( sharding_option: ShardingOption, use_depth: bool = True ) -> float: return ( - max([cast(float, shard.perf) for shard in sharding_option.shards]) + max([cast(Perf, shard.perf).total for shard in sharding_option.shards]) if use_depth - else sum([cast(float, shard.perf) for shard in sharding_option.shards]) + else sum([cast(Perf, shard.perf).total for shard in sharding_option.shards]) ) diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index 0b559b1e1..ce6fbd6be 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -5,6 +5,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
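For reference while reading the estimator changes below: the updated _sharding_option_score above ranks sharding options by Perf.total instead of a bare float. A simplified stand-in for that breakdown (field names follow this diff; that total simply sums the components is an assumption of the sketch, not a guarantee about the real Perf type):

from dataclasses import dataclass

@dataclass
class PerfSketch:
    # Per-shard wall-time components, mirroring the Perf objects constructed below.
    fwd_compute: float
    fwd_comms: float
    bwd_compute: float
    bwd_comms: float
    prefetch_compute: float = 0.0

    @property
    def total(self) -> float:
        # Assumed aggregation for this sketch: one shard's end-to-end wall time.
        return (
            self.fwd_compute
            + self.fwd_comms
            + self.bwd_compute
            + self.bwd_comms
            + self.prefetch_compute
        )

shard_perfs = [PerfSketch(1.0, 0.5, 2.0, 0.5), PerfSketch(1.2, 0.4, 2.4, 0.6, 0.1)]
score_by_critical_shard = max(p.total for p in shard_perfs)  # use_depth=True
score_by_total_work = sum(p.total for p in shard_perfs)      # use_depth=False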
+# pyre-strict + +import logging import math from typing import cast, Dict, List, Optional, Tuple, Type @@ -15,7 +18,6 @@ from torchrec.distributed.planner.constants import ( BATCHED_COPY_PERF_FACTOR, BIGINT_DTYPE, - BWD_COMPUTE_MULTIPLIER, DP_ELEMENTWISE_KERNELS_PERF_FACTOR, FULL_BLOCK_EMB_DIM, HALF_BLOCK_PENALTY, @@ -25,7 +27,10 @@ WEIGHTED_KERNEL_MULTIPLIER, ) from torchrec.distributed.planner.types import ( + CollectiveType, + GeneralizedCommsBandwidth, ParameterConstraints, + Perf, PlannerError, ShardEstimator, ShardingOption, @@ -33,14 +38,47 @@ Topology, ) from torchrec.distributed.planner.utils import prod, sharder_name -from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.distributed.types import ( + CacheStatistics, + CommOp, + ModuleSharder, + PipelineType, + ShardingType, +) +from torchrec.modules.embedding_configs import DATA_TYPE_NUM_BITS from torchrec.modules.embedding_modules import EmbeddingBagCollectionInterface +logger: logging.Logger = logging.getLogger(__name__) + + +def _is_prefetch_pipelined( + sharding_option: ShardingOption, sharder: ModuleSharder[nn.Module] +) -> bool: + prefetch_pipeline = ( + sharding_option.cache_params.prefetch_pipeline + if sharding_option.cache_params + else None + ) + # TODO: remove after deprecating fused_params in sharder + if not prefetch_pipeline: + prefetch_pipeline = ( + sharder.fused_params.get("prefetch_pipeline", False) # pyre-ignore[16] + if hasattr(sharder, "fused_params") and sharder.fused_params + else False + ) + return prefetch_pipeline + class EmbeddingPerfEstimator(ShardEstimator): """ - Embedding Wall Time Perf Estimator + Embedding Wall Time Perf Estimator. This estimator estimates the wall time + of a given sharding option. + + Args: + topology (Topology): device topology. + constraints (Optional[Dict[str, ParameterConstraints]]): parameter constraints. + is_inference (bool): whether or not the estimator is used for inference. """ def __init__( @@ -58,6 +96,13 @@ def estimate( sharding_options: List[ShardingOption], sharder_map: Optional[Dict[str, ModuleSharder[nn.Module]]] = None, ) -> None: + """ + Estimates the wall time of a given sharding option. + + Args: + sharding_options (List[ShardingOption]): list of sharding options. + sharder_map (Optional[Dict[str, ModuleSharder[nn.Module]]]): sharder map. 
+ """ if not sharder_map: assert not sharding_options, "sharder_map not provided for sharding_options" return @@ -65,11 +110,16 @@ def estimate( for sharding_option in sharding_options: sharder_key = sharder_name(type(sharding_option.module[1])) sharder = sharder_map[sharder_key] - caching_ratio = ( - sharder.fused_params.get("cache_load_factor") # pyre-ignore[16] - if hasattr(sharder, "fused_params") and sharder.fused_params - else None - ) + + caching_ratio = sharding_option.cache_load_factor + # TODO: remove after deprecating fused_params in sharder + if caching_ratio is None: + caching_ratio = ( + sharder.fused_params.get("cache_load_factor") # pyre-ignore[16] + if hasattr(sharder, "fused_params") and sharder.fused_params + else None + ) + num_poolings = ( cast(List[float], self._constraints[sharding_option.name].num_poolings) if self._constraints @@ -94,9 +144,23 @@ def estimate( module = sharding_option.module[1] # TODO remove this hack once feature processor is disaggregated - has_feature_processor = ( - True if getattr(module, "feature_processor", None) else False - ) + has_feature_processor = False + if ( + hasattr(module, "_feature_processor") + and hasattr(module._feature_processor, "feature_processor_modules") + and isinstance( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `feature_processor_modules`. + module._feature_processor.feature_processor_modules, + nn.ModuleDict, + ) + and sharding_option.name + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `feature_processor_modules`. + in module._feature_processor.feature_processor_modules.keys() + ): + has_feature_processor = True + logger.info(f"Table {sharding_option.name} has feature processor.") if isinstance(module, EmbeddingBagCollectionInterface): is_weighted = module.is_weighted() @@ -109,7 +173,47 @@ def estimate( else: is_weighted = False - shard_perfs = perf_func_emb_wall_time( + # TODO remove this once migrate away from PEA + is_weighted = is_weighted or has_feature_processor + sharding_option.is_weighted = is_weighted + + table_data_type_size = sharding_option.tensor.element_size() + ( + fwd_a2a_comm_data_type_size, + bwd_a2a_comm_data_type_size, + fwd_sr_comm_data_type_size, + bwd_sr_comm_data_type_size, + ) = _extract_comm_data_type_size(sharder, sharding_option) + + prefetch_pipeline = _is_prefetch_pipelined(sharding_option, sharder) + + # hardcoded as 8 bytes + # input indices can be of int32, but in TBE they get converted to int64 anyway + input_data_type_size = BIGINT_DTYPE + output_data_type_size: float = ( + DATA_TYPE_NUM_BITS[sharding_option.output_dtype] / 8 + if sharding_option.output_dtype + else sharding_option.tensor.element_size() + ) + + expected_cache_fetches = 0 + if ( + caching_ratio is not None + and sharding_option.cache_params is not None + and sharding_option.cache_params.stats is not None + and sharding_option.compute_kernel + == EmbeddingComputeKernel.FUSED_UVM_CACHING.value + ): + _stats = sharding_option.cache_params.stats + expected_cache_fetches = ( + _stats.expected_miss_rate(caching_ratio) * _stats.expected_lookups + ) + # Note, the above is not correct for data-parallel. stats.expected_lookups is + # calculated by estimating the cardinality of a global batch size worth of data. + # But for data-parallel, we need the calculate the cardinality of the local + # input batch. For now, we don't use cache stats with data parallel. 
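+            # Rough sense of scale (hypothetical numbers): with ~1e6 expected unique ids per
+            # global batch and a 15% expected miss rate at the chosen caching_ratio, this is
+            # ~150k fetched rows per iteration; _get_expected_cache_prefetch_time below turns
+            # that into a DDR->HBM transfer time of
+            # expected_cache_fetches * emb_dim * table_data_type_size / hbm_to_ddr_mem_bw.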
+ + shard_perfs = self.perf_func_emb_wall_time( shard_sizes=[shard.size for shard in sharding_option.shards], compute_kernel=sharding_option.compute_kernel, compute_device=self._topology.compute_device, @@ -118,462 +222,764 @@ def estimate( world_size=self._topology.world_size, local_world_size=self._topology.local_world_size, input_lengths=sharding_option.input_lengths, - input_data_type_size=BIGINT_DTYPE, - output_data_type_size=sharding_option.tensor.element_size(), + input_data_type_size=input_data_type_size, + table_data_type_size=table_data_type_size, + output_data_type_size=output_data_type_size, + fwd_a2a_comm_data_type_size=fwd_a2a_comm_data_type_size, + bwd_a2a_comm_data_type_size=bwd_a2a_comm_data_type_size, + fwd_sr_comm_data_type_size=fwd_sr_comm_data_type_size, + bwd_sr_comm_data_type_size=bwd_sr_comm_data_type_size, num_poolings=num_poolings, - bw_intra_host=self._topology.intra_host_bw, - bw_inter_host=self._topology.inter_host_bw, + hbm_mem_bw=self._topology.hbm_mem_bw, + ddr_mem_bw=self._topology.ddr_mem_bw, + hbm_to_ddr_mem_bw=self._topology.hbm_to_ddr_mem_bw, + comms_bandwidths=self._topology.comms_bandwidths, + bwd_compute_multiplier=self._topology.bwd_compute_multiplier, + weighted_feature_bwd_compute_multiplier=self._topology.weighted_feature_bwd_compute_multiplier, is_pooled=sharding_option.is_pooled, is_weighted=is_weighted, is_inference=self._is_inference, - has_feature_processor=has_feature_processor, caching_ratio=caching_ratio, + prefetch_pipeline=prefetch_pipeline, + expected_cache_fetches=expected_cache_fetches, + uneven_sharding_perf_multiplier=self._topology.uneven_sharding_perf_multiplier, ) for shard, perf in zip(sharding_option.shards, shard_perfs): shard.perf = perf + @classmethod + def perf_func_emb_wall_time( + cls, + shard_sizes: List[List[int]], + compute_kernel: str, + compute_device: str, + sharding_type: str, + batch_sizes: List[int], + world_size: int, + local_world_size: int, + input_lengths: List[float], + input_data_type_size: float, + table_data_type_size: float, + output_data_type_size: float, + fwd_a2a_comm_data_type_size: float, + bwd_a2a_comm_data_type_size: float, + fwd_sr_comm_data_type_size: float, + bwd_sr_comm_data_type_size: float, + num_poolings: List[float], + hbm_mem_bw: float, + ddr_mem_bw: float, + hbm_to_ddr_mem_bw: float, + comms_bandwidths: GeneralizedCommsBandwidth, + bwd_compute_multiplier: float, + weighted_feature_bwd_compute_multiplier: float, + is_pooled: bool, + is_weighted: bool = False, + caching_ratio: Optional[float] = None, + is_inference: bool = False, + prefetch_pipeline: bool = False, + expected_cache_fetches: float = 0, + uneven_sharding_perf_multiplier: float = 1.0, + ) -> List[Perf]: + """ + Attempts to model perfs as a function of relative wall times. + + Args: + shard_sizes (List[List[int]]): the list of (local_rows, local_cols) of each + shard. + compute_kernel (str): compute kernel. + compute_device (str): compute device. + sharding_type (str): tw, rw, cw, twrw, dp. + batch_sizes (List[int]): batch size for each input feature. + world_size (int): the number of devices for all hosts. + local_world_size (int): the number of the device for each host. + input_lengths (List[float]): the list of the average number of lookups of each + input query feature. + input_data_type_size (float): the data type size of the distributed + data_parallel input. + table_data_type_size (float): the data type size of the table. + output_data_type_size (float): the data type size of the output embeddings. 
+ fwd_comm_data_type_size (float): the data type size of the distributed + data_parallel input during forward communication. + bwd_comm_data_type_size (float): the data type size of the distributed + data_parallel input during backward communication. + num_poolings (List[float]): number of poolings per sample, typically 1.0. + hbm_mem_bw (float): the bandwidth of the device HBM. + ddr_mem_bw (float): the bandwidth of the system DDR memory. + hbm_to_ddr_bw (float): the bandwidth between device HBM and system DDR. + intra_host_bw (float): the bandwidth within a single host like multiple threads. + inter_host_bw (float): the bandwidth between two hosts like multiple machines. + is_pooled (bool): True if embedding output is pooled (ie. `EmbeddingBag`), False + if unpooled/sequential (ie. `Embedding`). + is_weighted (bool = False): if the module is an EBC and is weighted, typically + signifying an id score list feature. + is_inference (bool = False): if planning for inference. + caching_ratio (Optional[float] = None): cache ratio to determine the bandwidth + of device. + prefetch_pipeline (bool = False): whether prefetch pipeline is enabled. + expected_cache_fetches (float): number of expected cache fetches across global batch + uneven_sharding_perf_multiplier (float = 1.0): multiplier to account for uneven sharding perf + + Returns: + List[float]: the list of perf for each shard. + """ + + shard_perfs = [] + device_bw = kernel_bw_lookup( + compute_device, + compute_kernel, + hbm_mem_bw, + ddr_mem_bw, + hbm_to_ddr_mem_bw, + caching_ratio, + prefetch_pipeline, + ) + if device_bw is None: + raise PlannerError( + f"No kernel bandwidth exists for this combo of compute device: {compute_device}, compute kernel: {compute_kernel}" + ) -def perf_func_emb_wall_time( - shard_sizes: List[List[int]], - compute_kernel: str, - compute_device: str, - sharding_type: str, - batch_sizes: List[int], - world_size: int, - local_world_size: int, - input_lengths: List[float], - input_data_type_size: float, - output_data_type_size: float, - num_poolings: List[float], - bw_intra_host: float, - bw_inter_host: float, - is_pooled: bool, - is_weighted: bool = False, - has_feature_processor: bool = False, - caching_ratio: Optional[float] = None, - is_inference: bool = False, -) -> List[float]: - """ - Attempts to model perfs as a function of relative wall times. 
+ for hash_size, emb_dim in shard_sizes: + if ( + sharding_type == ShardingType.TABLE_WISE.value + or sharding_type == ShardingType.COLUMN_WISE.value + or sharding_type == ShardingType.TABLE_COLUMN_WISE.value + ): + shard_perf = cls._get_tw_sharding_perf( + batch_sizes=batch_sizes, + world_size=world_size, + local_world_size=local_world_size, + input_lengths=input_lengths, + emb_dim=emb_dim, + input_data_type_size=input_data_type_size, + table_data_type_size=table_data_type_size, + output_data_type_size=output_data_type_size, + fwd_a2a_comm_data_type_size=fwd_a2a_comm_data_type_size, + bwd_a2a_comm_data_type_size=bwd_a2a_comm_data_type_size, + num_poolings=num_poolings, + hbm_to_ddr_mem_bw=hbm_to_ddr_mem_bw, + device_bw=device_bw, + comms_bandwidths=comms_bandwidths, + bwd_compute_multiplier=bwd_compute_multiplier, + weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier, + is_pooled=is_pooled, + is_weighted=is_weighted, + is_inference=is_inference, + expected_cache_fetches=expected_cache_fetches, + ) + elif sharding_type == ShardingType.ROW_WISE.value: + shard_perf = cls._get_rw_sharding_perf( + batch_sizes=batch_sizes, + world_size=world_size, + local_world_size=local_world_size, + input_lengths=input_lengths, + emb_dim=emb_dim, + input_data_type_size=input_data_type_size, + table_data_type_size=table_data_type_size, + output_data_type_size=output_data_type_size, + fwd_a2a_comm_data_type_size=fwd_a2a_comm_data_type_size, + bwd_a2a_comm_data_type_size=bwd_a2a_comm_data_type_size, + fwd_sr_comm_data_type_size=fwd_sr_comm_data_type_size, + bwd_sr_comm_data_type_size=bwd_sr_comm_data_type_size, + num_poolings=num_poolings, + hbm_to_ddr_mem_bw=hbm_to_ddr_mem_bw, + device_bw=device_bw, + comms_bandwidths=comms_bandwidths, + bwd_compute_multiplier=bwd_compute_multiplier, + weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier, + is_pooled=is_pooled, + is_weighted=is_weighted, + expected_cache_fetches=expected_cache_fetches, + is_inference=is_inference, + ) + elif ( + sharding_type == ShardingType.TABLE_ROW_WISE.value + or sharding_type == ShardingType.GRID_SHARD.value + ): + shard_perf = cls._get_twrw_sharding_perf( + batch_sizes=batch_sizes, + world_size=world_size, + local_world_size=local_world_size, + input_lengths=input_lengths, + emb_dim=emb_dim, + input_data_type_size=input_data_type_size, + table_data_type_size=table_data_type_size, + output_data_type_size=output_data_type_size, + fwd_a2a_comm_data_type_size=fwd_a2a_comm_data_type_size, + bwd_a2a_comm_data_type_size=bwd_a2a_comm_data_type_size, + fwd_sr_comm_data_type_size=fwd_sr_comm_data_type_size, + bwd_sr_comm_data_type_size=bwd_sr_comm_data_type_size, + num_poolings=num_poolings, + hbm_to_ddr_mem_bw=hbm_to_ddr_mem_bw, + device_bw=device_bw, + comms_bandwidths=comms_bandwidths, + bwd_compute_multiplier=bwd_compute_multiplier, + weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier, + is_pooled=is_pooled, + is_weighted=is_weighted, + expected_cache_fetches=expected_cache_fetches, + ) + elif sharding_type == ShardingType.DATA_PARALLEL.value: + shard_perf = cls._get_dp_sharding_perf( + batch_sizes=batch_sizes, + world_size=world_size, + local_world_size=local_world_size, + input_lengths=input_lengths, + grad_num_elem=hash_size * emb_dim, + emb_dim=emb_dim, + input_data_type_size=input_data_type_size, + table_data_type_size=table_data_type_size, + output_data_type_size=output_data_type_size, + num_poolings=num_poolings, + device_bw=device_bw, + 
comms_bandwidths=comms_bandwidths, + bwd_compute_multiplier=bwd_compute_multiplier, + weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier, + is_pooled=is_pooled, + is_weighted=is_weighted, + ) + else: + raise ValueError( + f"Unrecognized or unsupported sharding type provided: {sharding_type}" + ) + shard_perfs.append(shard_perf) + + return shard_perfs + + @classmethod + def _get_expected_cache_prefetch_time( + cls, + hbm_to_ddr_mem_bw: float, + expected_cache_fetches: float, + emb_dim: int, + table_data_type_size: float, + ) -> float: + # TODO: validate cost model with empirical test + prefetch_bytes = expected_cache_fetches * emb_dim * table_data_type_size + return prefetch_bytes / hbm_to_ddr_mem_bw + + @classmethod + def _get_tw_sharding_perf( + cls, + batch_sizes: List[int], + world_size: int, + local_world_size: int, + input_lengths: List[float], + emb_dim: int, + input_data_type_size: float, + table_data_type_size: float, + output_data_type_size: float, + fwd_a2a_comm_data_type_size: float, + bwd_a2a_comm_data_type_size: float, + num_poolings: List[float], + hbm_to_ddr_mem_bw: float, + device_bw: float, + comms_bandwidths: GeneralizedCommsBandwidth, + bwd_compute_multiplier: float, + weighted_feature_bwd_compute_multiplier: float, + is_pooled: bool, + is_weighted: bool = False, + is_inference: bool = False, + expected_cache_fetches: float = 0, + ) -> Perf: - Args: - shard_sizes (List[List[int]]): the list of (local_rows, local_cols) of each - shard. - compute_kernel (str): compute kernel. - compute_device (str): compute device. - sharding_type (str): tw, rw, cw, twrw, dp. - batch_sizes (List[int]): batch size for each input feature. - world_size (int): the number of devices for all hosts. - local_world_size (int): the number of the device for each host. - input_lengths (List[float]): the list of the average number of lookups of each - input query feature. - input_data_type_size (float): the data type size of the distributed - data_parallel input. - output_data_type_size (float): the data type size of the distributed - data_parallel output. - num_poolings (List[float]): number of poolings per sample, typically 1.0. - bw_intra_host (float): the bandwidth within a single host like multiple threads. - bw_inter_host (float): the bandwidth between two hosts like multiple machines. - is_pooled (bool): True if embedding output is pooled (ie. EmbeddingBag), False - if unpooled/sequential (ie. Embedding). - is_weighted (bool = False): if the module is an EBC and is weighted, typically - signifying an id score list feature. - is_inference (bool = False): if planning for inference. - has_feature_processor (bool = False): if the module has a feature processor. - caching_ratio (Optional[float] = None): cache ratio to determine the bandwidth - of device. + batch_inputs = sum( + [x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)] + ) + batch_outputs = ( + sum([x * y for x, y in zip(num_poolings, batch_sizes)]) + if is_pooled + else batch_inputs + ) - Returns: - List[float]: the list of perf for each shard. 
- """ + input_read_size = math.ceil(batch_inputs * world_size * input_data_type_size) + if is_weighted: + input_read_size *= 2 - shard_perfs = [] - device_bw = kernel_bw_lookup(compute_device, compute_kernel, caching_ratio) - if device_bw is None: - raise PlannerError( - f"No kernel bandwidth exists for this combo of compute device: {compute_device}, compute kernel: {compute_kernel}" + # minimum embedding dim is set to 32 due to kernel usage + embedding_lookup_size = ( + batch_inputs * world_size * max(emb_dim, 32) * table_data_type_size ) - for hash_size, emb_dim in shard_sizes: - if ( - sharding_type == ShardingType.TABLE_WISE.value - or sharding_type == ShardingType.COLUMN_WISE.value - or sharding_type == ShardingType.TABLE_COLUMN_WISE.value - ): - shard_perf = _get_tw_sharding_perf( - batch_sizes=batch_sizes, - world_size=world_size, - local_world_size=local_world_size, - input_lengths=input_lengths, - emb_dim=emb_dim, - input_data_type_size=input_data_type_size, - output_data_type_size=output_data_type_size, - num_poolings=num_poolings, - device_bw=device_bw, - bw_inter_host=bw_inter_host, - bw_intra_host=bw_intra_host, - is_pooled=is_pooled, - is_weighted=is_weighted, - is_inference=is_inference, - has_feature_processor=has_feature_processor, - ) - elif sharding_type == ShardingType.ROW_WISE.value: - shard_perf = _get_rw_sharding_perf( - batch_sizes=batch_sizes, - world_size=world_size, - local_world_size=local_world_size, - input_lengths=input_lengths, - emb_dim=emb_dim, - input_data_type_size=input_data_type_size, - output_data_type_size=output_data_type_size, - num_poolings=num_poolings, - device_bw=device_bw, - bw_inter_host=bw_inter_host, - bw_intra_host=bw_intra_host, - is_pooled=is_pooled, - is_weighted=is_weighted, - has_feature_processor=has_feature_processor, - ) - elif sharding_type == ShardingType.TABLE_ROW_WISE.value: - shard_perf = _get_twrw_sharding_perf( - batch_sizes=batch_sizes, - world_size=world_size, - local_world_size=local_world_size, - input_lengths=input_lengths, - emb_dim=emb_dim, - input_data_type_size=input_data_type_size, - output_data_type_size=output_data_type_size, - num_poolings=num_poolings, - device_bw=device_bw, - bw_inter_host=bw_inter_host, - bw_intra_host=bw_intra_host, - is_pooled=is_pooled, - is_weighted=is_weighted, - has_feature_processor=has_feature_processor, - ) - elif sharding_type == ShardingType.DATA_PARALLEL.value: - shard_perf = _get_dp_sharding_perf( - batch_sizes=batch_sizes, - world_size=world_size, - local_world_size=local_world_size, - input_lengths=input_lengths, - grad_num_elem=hash_size * emb_dim, - emb_dim=emb_dim, - input_data_type_size=output_data_type_size, - output_data_type_size=output_data_type_size, - num_poolings=num_poolings, - device_bw=device_bw, - bw_inter_host=bw_inter_host, - is_pooled=is_pooled, - is_weighted=is_weighted, - has_feature_processor=has_feature_processor, - ) - else: - raise ValueError( - f"Unrecognized or unsupported sharding type provided: {sharding_type}" - ) - shard_perfs.append(shard_perf) - - return shard_perfs + fwd_output_write_size = ( + batch_outputs * world_size * emb_dim * fwd_a2a_comm_data_type_size + ) + bwd_output_write_size = ( + batch_outputs * world_size * emb_dim * bwd_a2a_comm_data_type_size + ) + # embedding dim below 128 will reduce kernel efficency + block_usage_penalty = 1 + if emb_dim < FULL_BLOCK_EMB_DIM: + if emb_dim >= 64: + block_usage_penalty = HALF_BLOCK_PENALTY + else: # emb_dim >= 32 + block_usage_penalty = QUARTER_BLOCK_PENALTY + comms_bw = 
comms_bandwidths.get_bw( + world_size=world_size, + local_world_size=local_world_size, + collective_type=CollectiveType.ALL_TO_ALL, + ) + fwd_comms = fwd_output_write_size / comms_bw -def _get_tw_sharding_perf( - batch_sizes: List[int], - world_size: int, - local_world_size: int, - input_lengths: List[float], - emb_dim: int, - input_data_type_size: float, - output_data_type_size: float, - num_poolings: List[float], - device_bw: float, - bw_inter_host: float, - bw_intra_host: float, - is_pooled: bool, - is_weighted: bool = False, - is_inference: bool = False, - has_feature_processor: bool = False, -) -> float: - batch_inputs = sum( - [x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)] - ) - batch_outputs = ( - sum([x * y for x, y in zip(num_poolings, batch_sizes)]) - if is_pooled - else batch_inputs - ) + fwd_compute = ( + (input_read_size + embedding_lookup_size + fwd_output_write_size) + * block_usage_penalty + / device_bw + ) + if is_inference: + # only consider forward compute and comms for inference + return Perf( + fwd_compute=fwd_compute, fwd_comms=fwd_comms, bwd_compute=0, bwd_comms=0 + ) - input_read_size = math.ceil(batch_inputs * world_size * input_data_type_size) - if is_weighted or has_feature_processor: - input_read_size *= 2 + bwd_comms = bwd_output_write_size / comms_bw - # minimum embedding dim is set to 32 due to kernel usage - embedding_lookup_size = ( - batch_inputs * world_size * max(emb_dim, 32) * output_data_type_size - ) + bwd_grad_indice_weights_kernel = ( + fwd_compute * WEIGHTED_KERNEL_MULTIPLIER if is_weighted else 0 + ) - output_write_size = batch_outputs * world_size * emb_dim * output_data_type_size + # includes fused optimizers + bwd_compute = fwd_compute * bwd_compute_multiplier + if is_weighted: + bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier - # embedding dim below 128 will reduce kernel efficency - block_usage_penalty = 1 - if emb_dim < FULL_BLOCK_EMB_DIM: - if emb_dim >= 64: - block_usage_penalty = HALF_BLOCK_PENALTY - else: # emb_dim >= 32 - block_usage_penalty = QUARTER_BLOCK_PENALTY + prefetch_compute = cls._get_expected_cache_prefetch_time( + hbm_to_ddr_mem_bw, expected_cache_fetches, emb_dim, table_data_type_size + ) - comms_bw = bw_inter_host if world_size > local_world_size else bw_intra_host - fwd_comms = output_write_size / comms_bw + # in order of model parallel execution, starting with: + # BWD DP -> BWD MP ... 
FWD MP -> FWD DP + return Perf( + fwd_compute=fwd_compute, + fwd_comms=fwd_comms, + bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel, + bwd_comms=bwd_comms, + prefetch_compute=prefetch_compute, + ) - fwd_compute = ( - (input_read_size + embedding_lookup_size + output_write_size) - * block_usage_penalty - / device_bw - ) - if is_inference: - # only consider forward compute and comms for inference - return fwd_compute + fwd_comms + @classmethod + def _get_rw_sharding_perf( + cls, + batch_sizes: List[int], + world_size: int, + local_world_size: int, + input_lengths: List[float], + emb_dim: int, + input_data_type_size: float, + table_data_type_size: float, + output_data_type_size: float, + fwd_a2a_comm_data_type_size: float, + bwd_a2a_comm_data_type_size: float, + fwd_sr_comm_data_type_size: float, + bwd_sr_comm_data_type_size: float, + num_poolings: List[float], + hbm_to_ddr_mem_bw: float, + device_bw: float, + comms_bandwidths: GeneralizedCommsBandwidth, + bwd_compute_multiplier: float, + weighted_feature_bwd_compute_multiplier: float, + is_pooled: bool, + is_weighted: bool = False, + expected_cache_fetches: float = 0, + is_inference: bool = False, + ) -> Perf: + batch_inputs = ( + sum( + [x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)] + ) + / world_size + ) + batch_outputs = ( + sum([x * y for x, y in zip(num_poolings, batch_sizes)]) + if is_pooled + else batch_inputs + ) - bwd_comms = fwd_comms + input_read_size = math.ceil(batch_inputs * world_size * input_data_type_size) + if is_weighted: + input_read_size *= 2 - bwd_grad_indice_weights_kernel = ( - fwd_compute * WEIGHTED_KERNEL_MULTIPLIER - if is_weighted or has_feature_processor - else 0 - ) + embedding_lookup_size = ( + batch_inputs * world_size * emb_dim * table_data_type_size + ) - # includes fused optimizers - bwd_compute = fwd_compute * BWD_COMPUTE_MULTIPLIER + fwd_output_write_size = ( + batch_outputs * world_size * emb_dim * fwd_sr_comm_data_type_size + if is_pooled + else batch_outputs * world_size * emb_dim * fwd_a2a_comm_data_type_size + ) + bwd_output_write_size = ( + batch_outputs * world_size * emb_dim * bwd_sr_comm_data_type_size + if is_pooled + else batch_outputs * world_size * emb_dim * bwd_a2a_comm_data_type_size + ) + comms_bw = comms_bandwidths.get_bw( + world_size=world_size, + local_world_size=local_world_size, + collective_type=CollectiveType.REDUCE_SCATTER, + ) - # in order of model parallel execution, starting with: - # BWD DP -> BWD MP ... 
FWD MP -> FWD DP - return ( - bwd_comms - + bwd_grad_indice_weights_kernel - + bwd_compute - + fwd_compute - + fwd_comms - ) + fwd_comms = fwd_output_write_size / comms_bw + fwd_compute = ( + input_read_size + embedding_lookup_size + fwd_output_write_size + ) / device_bw -def _get_rw_sharding_perf( - batch_sizes: List[int], - world_size: int, - local_world_size: int, - input_lengths: List[float], - emb_dim: int, - input_data_type_size: float, - output_data_type_size: float, - num_poolings: List[float], - device_bw: float, - bw_inter_host: float, - bw_intra_host: float, - is_pooled: bool, - is_weighted: bool = False, - has_feature_processor: bool = False, -) -> float: - batch_inputs = ( - sum([x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)]) - / world_size - ) - batch_outputs = ( - sum([x * y for x, y in zip(num_poolings, batch_sizes)]) - if is_pooled - else batch_inputs - ) + if is_inference: + # only consider forward compute and comms for inference + return Perf( + fwd_compute=fwd_compute, fwd_comms=fwd_comms, bwd_compute=0, bwd_comms=0 + ) + comms_bw = comms_bandwidths.get_bw( + world_size=world_size, + local_world_size=local_world_size, + collective_type=CollectiveType.ALL_GATHER, + ) + bwd_comms = bwd_output_write_size / comms_bw - input_read_size = math.ceil(batch_inputs * world_size * input_data_type_size) - if is_weighted or has_feature_processor: - input_read_size *= 2 + bwd_batched_copy = bwd_output_write_size * BATCHED_COPY_PERF_FACTOR / device_bw - embedding_lookup_size = batch_inputs * world_size * emb_dim * output_data_type_size + bwd_grad_indice_weights_kernel = ( + fwd_compute * WEIGHTED_KERNEL_MULTIPLIER if is_weighted else 0 + ) - output_write_size = batch_outputs * world_size * emb_dim * output_data_type_size + bwd_compute = fwd_compute * bwd_compute_multiplier + if is_weighted: + bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier - comms_bw = bw_inter_host if world_size > local_world_size else bw_intra_host - fwd_comms = output_write_size / comms_bw + # for row-wise, expected_cache_fetches per shard is / world_size + prefetch_compute = cls._get_expected_cache_prefetch_time( + hbm_to_ddr_mem_bw, + expected_cache_fetches / world_size, + emb_dim, + table_data_type_size, + ) - fwd_compute = ( - input_read_size + embedding_lookup_size + output_write_size - ) / device_bw + return Perf( + fwd_compute=fwd_compute, + fwd_comms=fwd_comms, + bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel, + bwd_comms=bwd_comms + bwd_batched_copy, + prefetch_compute=prefetch_compute, + ) - bwd_comms = fwd_comms + @classmethod + def _get_twrw_sharding_perf( + cls, + batch_sizes: List[int], + world_size: int, + local_world_size: int, + input_lengths: List[float], + emb_dim: int, + input_data_type_size: float, + table_data_type_size: float, + output_data_type_size: float, + fwd_a2a_comm_data_type_size: float, + bwd_a2a_comm_data_type_size: float, + fwd_sr_comm_data_type_size: float, + bwd_sr_comm_data_type_size: float, + num_poolings: List[float], + hbm_to_ddr_mem_bw: float, + device_bw: float, + comms_bandwidths: GeneralizedCommsBandwidth, + bwd_compute_multiplier: float, + weighted_feature_bwd_compute_multiplier: float, + is_pooled: bool, + is_weighted: bool = False, + expected_cache_fetches: float = 0, + ) -> Perf: + batch_inputs = ( + sum( + [x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)] + ) + / local_world_size + ) + batch_outputs = ( + sum([x * y for x, y in zip(num_poolings, batch_sizes)]) + if is_pooled + else 
batch_inputs + ) - bwd_batched_copy = output_write_size * BATCHED_COPY_PERF_FACTOR / device_bw + input_read_size = math.ceil(batch_inputs * world_size * input_data_type_size) + if is_weighted: + input_read_size *= 2 - bwd_grad_indice_weights_kernel = ( - fwd_compute * WEIGHTED_KERNEL_MULTIPLIER - if is_weighted or has_feature_processor - else 0 - ) + embedding_lookup_size = ( + batch_inputs * world_size * emb_dim * table_data_type_size + ) - bwd_compute = fwd_compute * BWD_COMPUTE_MULTIPLIER + fwd_output_write_size = ( + batch_outputs * world_size * emb_dim * fwd_sr_comm_data_type_size + ) + bwd_output_write_size = ( + batch_outputs * world_size * emb_dim * bwd_sr_comm_data_type_size + ) + comms_bw = comms_bandwidths.get_bw( + world_size=local_world_size, + local_world_size=local_world_size, + collective_type=CollectiveType.REDUCE_SCATTER, + ) - return ( - bwd_comms - + bwd_batched_copy - + bwd_grad_indice_weights_kernel - + bwd_compute - + fwd_compute - + fwd_comms - ) + # intra host comm + fwd_comms = fwd_output_write_size / comms_bw + + # inter host comm + if world_size > local_world_size: + inter_host_fwd_output_write_size = ( + batch_outputs + * ( + world_size / local_world_size + ) # this is the size of the procees group. + * emb_dim + * fwd_a2a_comm_data_type_size + ) + comms_bw = comms_bandwidths.get_bw( + world_size=int(world_size / local_world_size), + local_world_size=1, + collective_type=CollectiveType.ALL_TO_ALL, + ) + fwd_comms += inter_host_fwd_output_write_size / comms_bw + fwd_compute = ( + input_read_size + embedding_lookup_size + fwd_output_write_size + ) / device_bw -def _get_twrw_sharding_perf( - batch_sizes: List[int], - world_size: int, - local_world_size: int, - input_lengths: List[float], - emb_dim: int, - input_data_type_size: float, - output_data_type_size: float, - num_poolings: List[float], - device_bw: float, - bw_inter_host: float, - bw_intra_host: float, - is_pooled: bool, - is_weighted: bool = False, - has_feature_processor: bool = False, -) -> float: - batch_inputs = ( - sum([x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)]) - / local_world_size - ) - batch_outputs = ( - sum([x * y for x, y in zip(num_poolings, batch_sizes)]) - if is_pooled - else batch_inputs - ) + # intra host comm (i.e. all gather) + comms_bw = comms_bandwidths.get_bw( + world_size=local_world_size, + local_world_size=local_world_size, + collective_type=CollectiveType.ALL_GATHER, + ) + bwd_comms = bwd_output_write_size / comms_bw + + # inter host comm (i.e. all to all) + if world_size > local_world_size: + inter_host_bwd_output_write_size = ( + batch_outputs + * ( + world_size / local_world_size + ) # this is the size of the procress group. 
+ * emb_dim + * bwd_a2a_comm_data_type_size + ) + comms_bw = comms_bandwidths.get_bw( + world_size=int(world_size / local_world_size), + local_world_size=1, + collective_type=CollectiveType.ALL_TO_ALL, + ) + bwd_comms += inter_host_bwd_output_write_size / comms_bw - input_read_size = math.ceil(batch_inputs * world_size * input_data_type_size) - if is_weighted or has_feature_processor: - input_read_size *= 2 + bwd_grad_indice_weights_kernel = ( + fwd_compute * WEIGHTED_KERNEL_MULTIPLIER if is_weighted else 0 + ) - embedding_lookup_size = batch_inputs * world_size * emb_dim * output_data_type_size + bwd_batched_copy = bwd_output_write_size * BATCHED_COPY_PERF_FACTOR / device_bw - output_write_size = batch_outputs * world_size * emb_dim * output_data_type_size + bwd_compute = fwd_compute * bwd_compute_multiplier + if is_weighted: + bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier - fwd_comms = output_write_size / bw_intra_host + # for table-wise-row-wise or grid_shard, expected_cache_fetches per shard is / local_world_size + prefetch_compute = cls._get_expected_cache_prefetch_time( + hbm_to_ddr_mem_bw, + expected_cache_fetches / local_world_size, + emb_dim, + table_data_type_size, + ) - if world_size > local_world_size: - fwd_comms += output_write_size * (local_world_size / world_size) / bw_inter_host + return Perf( + fwd_compute=fwd_compute, + fwd_comms=fwd_comms, + bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel, + bwd_comms=bwd_comms + bwd_batched_copy, + prefetch_compute=prefetch_compute, + ) - fwd_compute = ( - input_read_size + embedding_lookup_size + output_write_size - ) / device_bw + @classmethod + def _get_dp_sharding_perf( + cls, + batch_sizes: List[int], + world_size: int, + local_world_size: int, + input_lengths: List[float], + grad_num_elem: int, + emb_dim: int, + input_data_type_size: float, + table_data_type_size: float, + output_data_type_size: float, + num_poolings: List[float], + device_bw: float, + comms_bandwidths: GeneralizedCommsBandwidth, + bwd_compute_multiplier: float, + weighted_feature_bwd_compute_multiplier: float, + is_pooled: bool, + is_weighted: bool = False, + ) -> Perf: + batch_inputs = sum( + [x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)] + ) + batch_outputs = ( + sum([x * y for x, y in zip(num_poolings, batch_sizes)]) + if is_pooled + else batch_inputs + ) - bwd_comms = fwd_comms + input_read_size = math.ceil(batch_inputs * input_data_type_size) + if is_weighted: + input_read_size *= 2 - bwd_grad_indice_weights_kernel = ( - fwd_compute * WEIGHTED_KERNEL_MULTIPLIER - if is_weighted or has_feature_processor - else 0 - ) + embedding_lookup_size = batch_inputs * emb_dim * table_data_type_size - bwd_batched_copy = output_write_size * BATCHED_COPY_PERF_FACTOR / device_bw + output_write_size = batch_outputs * emb_dim * table_data_type_size + table_size = grad_num_elem * table_data_type_size - bwd_compute = fwd_compute * BWD_COMPUTE_MULTIPLIER + fwd_compute = ( + input_read_size + embedding_lookup_size + output_write_size + ) / device_bw - return ( - bwd_comms - + bwd_batched_copy - + bwd_grad_indice_weights_kernel - + bwd_compute - + fwd_compute - + fwd_comms - ) + num_nodes = min(world_size / local_world_size, 2) + # all-reduce data transfer: https://images.nvidia.com/events/sc15/pdfs/NCCL-Woolley.pdf + comms_bw = comms_bandwidths.get_bw( + world_size=world_size, + local_world_size=local_world_size, + collective_type=CollectiveType.ALL_REDUCE, + ) + all_reduce = table_size * (2 * num_nodes - 1) / num_nodes / 
comms_bw + # inter host communication constraint + if world_size > 2 * local_world_size: + all_reduce *= 2 -def _get_dp_sharding_perf( - batch_sizes: List[int], - world_size: int, - local_world_size: int, - input_lengths: List[float], - grad_num_elem: int, - emb_dim: int, - input_data_type_size: float, - output_data_type_size: float, - num_poolings: List[float], - device_bw: float, - bw_inter_host: float, - is_pooled: bool, - is_weighted: bool = False, - has_feature_processor: bool = False, -) -> float: - batch_inputs = sum( - [x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)] - ) - batch_outputs = ( - sum([x * y for x, y in zip(num_poolings, batch_sizes)]) - if is_pooled - else batch_inputs - ) + # SGD + Fill + BUnary + optimizer_kernels = table_size * DP_ELEMENTWISE_KERNELS_PERF_FACTOR / device_bw - input_read_size = math.ceil(batch_inputs * input_data_type_size) - if is_weighted or has_feature_processor: - input_read_size *= 2 + bwd_compute = fwd_compute * bwd_compute_multiplier + if is_weighted: + bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier - embedding_lookup_size = batch_inputs * emb_dim * output_data_type_size + bwd_grad_indice_weights_kernel = ( + fwd_compute * WEIGHTED_KERNEL_MULTIPLIER if is_weighted else 0 + ) - output_write_size = batch_outputs * emb_dim * output_data_type_size - table_size = grad_num_elem * output_data_type_size + # TODO(T170641643): we don't model prefetch_compute for data parallel yet, see + # comment in perf_func_emb_wall_time() regarding expected_cache_fetches calculation. + return Perf( + fwd_compute=fwd_compute, + fwd_comms=0, + bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel, + bwd_comms=all_reduce + optimizer_kernels, + ) - fwd_compute = ( - input_read_size + embedding_lookup_size + output_write_size - ) / device_bw - num_nodes = min(world_size / local_world_size, 2) +def _extract_comm_data_type_size( + sharder: ModuleSharder[nn.Module], sharding_option: ShardingOption +) -> Tuple[float, float, float, float]: + table_data_type_size = sharding_option.tensor.element_size() - # all-reduce data transfer: https://images.nvidia.com/events/sc15/pdfs/NCCL-Woolley.pdf - all_reduce = ( - table_size - * (2 * num_nodes - 1) - / num_nodes - / (bw_inter_host * local_world_size) # 1 NIC per GPU - ) - # inter host communication constraint - if world_size > 2 * local_world_size: - all_reduce *= 2 + fwd_a2a_comm_data_type_size = table_data_type_size + bwd_a2a_comm_data_type_size = table_data_type_size + fwd_sr_comm_data_type_size = table_data_type_size + bwd_sr_comm_data_type_size = table_data_type_size - # SGD + Fill + BUnary - optimizer_kernels = table_size * DP_ELEMENTWISE_KERNELS_PERF_FACTOR / device_bw + if sharder.qcomm_codecs_registry is not None: + qcomm_codecs_registry = sharder.qcomm_codecs_registry + if ( + sharding_option.is_pooled + and CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name in qcomm_codecs_registry + ): + codecs = sharder.qcomm_codecs_registry[ + CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name + ] + fwd_a2a_comm_data_type_size = torch.tensor( + [], dtype=codecs.forward.quantized_dtype + ).element_size() + bwd_a2a_comm_data_type_size = torch.tensor( + [], dtype=codecs.backward.quantized_dtype + ).element_size() - bwd_compute = fwd_compute * BWD_COMPUTE_MULTIPLIER + if ( + not sharding_option.is_pooled + and CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name in qcomm_codecs_registry + ): + codecs = qcomm_codecs_registry[CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name] + fwd_a2a_comm_data_type_size = torch.tensor( + [], 
dtype=codecs.forward.quantized_dtype + ).element_size() + bwd_a2a_comm_data_type_size = torch.tensor( + [], dtype=codecs.backward.quantized_dtype + ).element_size() - bwd_grad_indice_weights_kernel = ( - fwd_compute * WEIGHTED_KERNEL_MULTIPLIER - if is_weighted or has_feature_processor - else 0 - ) + if ( + sharding_option.is_pooled + and CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER.name in qcomm_codecs_registry + ): + codecs = qcomm_codecs_registry[CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER.name] + fwd_sr_comm_data_type_size = torch.tensor( + [], dtype=codecs.forward.quantized_dtype + ).element_size() + bwd_sr_comm_data_type_size = torch.tensor( + [], dtype=codecs.backward.quantized_dtype + ).element_size() return ( - all_reduce - + optimizer_kernels - + bwd_grad_indice_weights_kernel - + bwd_compute - + fwd_compute + fwd_a2a_comm_data_type_size, + bwd_a2a_comm_data_type_size, + fwd_sr_comm_data_type_size, + bwd_sr_comm_data_type_size, ) class EmbeddingStorageEstimator(ShardEstimator): """ Embedding Storage Usage Estimator + + Args: + topology (Topology): device topology. + constraints (Optional[Dict[str, ParameterConstraints]]): parameter constraints. + pipeline_type (PipelineType): The type of pipeline, if any. Will determine the + input replication factor during memory estimation. + run_embedding_at_peak_memory (bool): If the embedding fwd/bwd will be execute when HBM + usage is at peak. When set to TRUE, any temporary memory allocation during + embedding forward/backward, as long as output sizes before output_dist will + be counted towards HBM storage cost. Otherwise they won't since they'll be + "hidden" by the real memory peak. + + Only take effect if pipeline_type is set for backward compatibility (not affecting + models using old pipeline-agnostic formula) + + Default to false because this is typically false for RecSys since memory + peak happens at the end of dense forwrad / beginning of dense backward instead. + is_inference (bool): If the model is inference model. Default to False. """ def __init__( self, topology: Topology, constraints: Optional[Dict[str, ParameterConstraints]] = None, + pipeline_type: PipelineType = PipelineType.NONE, + run_embedding_at_peak_memory: bool = False, + is_inference: bool = False, ) -> None: self._topology = topology self._constraints = constraints + self._pipeline_type = pipeline_type + self._run_embedding_at_peak_memory = run_embedding_at_peak_memory + self._is_inference = is_inference def estimate( self, sharding_options: List[ShardingOption], sharder_map: Optional[Dict[str, ModuleSharder[nn.Module]]] = None, ) -> None: + """ + Estimate the storage cost of each sharding option. + + Args: + sharding_options (List[ShardingOption]): list of sharding options. + sharder_map (Optional[Dict[str, ModuleSharder[nn.Module]]]): map from module + type to sharder. 
+ """ if not sharder_map: assert not sharding_options, "sharder_map not provided for sharding_options" return @@ -581,11 +987,16 @@ def estimate( for sharding_option in sharding_options: sharder_key = sharder_name(type(sharding_option.module[1])) sharder = sharder_map[sharder_key] - caching_ratio = ( - sharder.fused_params.get("cache_load_factor") # pyre-ignore[16] - if hasattr(sharder, "fused_params") and sharder.fused_params - else None - ) + + caching_ratio = sharding_option.cache_load_factor + # TODO: remove after deprecating fused_params in sharder + if caching_ratio is None: + caching_ratio = ( + sharder.fused_params.get("cache_load_factor") # pyre-ignore[16] + if hasattr(sharder, "fused_params") and sharder.fused_params + else None + ) + num_poolings = ( cast(List[float], self._constraints[sharding_option.name].num_poolings) if self._constraints @@ -602,6 +1013,28 @@ def estimate( else [sharding_option.batch_size] * sharding_option.num_inputs ) + # hardcoded as 8 bytes + # input indices can be of int32, but in TBE they get converted to int64 anyway + input_data_type_size = BIGINT_DTYPE + + output_data_type_size: float = ( + DATA_TYPE_NUM_BITS[sharding_option.output_dtype] / 8 + if sharding_option.output_dtype + else sharding_option.tensor.element_size() + ) + + mpp_conf = ( + sharding_option.cache_params.multipass_prefetch_config + if sharding_option.cache_params + else None + ) + # TODO: remove after deprecating fused_params in sharder + if mpp_conf is None: + mpp_conf = ( + sharder.fused_params.get("multipass_prefetch_config", None) + if hasattr(sharder, "fused_params") and sharder.fused_params + else None + ) shard_storages = calculate_shard_storages( sharder=sharder, sharding_type=sharding_option.sharding_type, @@ -616,12 +1049,60 @@ def estimate( num_poolings=num_poolings, caching_ratio=caching_ratio if caching_ratio else UVM_CACHING_RATIO, is_pooled=sharding_option.is_pooled, + input_data_type_size=input_data_type_size, + output_data_type_size=output_data_type_size, + pipeline_type=self._pipeline_type, + count_ephemeral_storage_cost=self._run_embedding_at_peak_memory, + is_inference=self._is_inference, + multipass_prefetch_max_pass=mpp_conf.num_passes if mpp_conf else None, ) - for shard, storage in zip(sharding_option.shards, shard_storages): shard.storage = storage +def calculate_pipeline_io_cost( + input_size: int, + output_size: int, + prefetch_size: int, + pipeline_type: PipelineType, + multipass_prefetch_max_pass: Optional[int], + count_ephemeral_storage_cost: bool = False, + is_inference: bool = False, +) -> int: + # These magical number comes from heuristical analysis of memory snapshot during + # pipelining, and are subject to the actual implementation. + # + # Now it's static to make memory estimation more sane for UVM offloading, + # we need to make this estimation more blackbox-based. + if is_inference: + return 0 + + # Output size is considered ephemeral storage cost since they are temporarily + # only during all2all and won't last long (e.g. 
from fwd to bwd) + output_contribition_to_peak_memory = ( + output_size if count_ephemeral_storage_cost else 0 + ) + + if pipeline_type == PipelineType.TRAIN_SPARSE_DIST: + pipelining_hbm_input_factor = 2 + return ( + pipelining_hbm_input_factor * input_size + + output_contribition_to_peak_memory + ) + if pipeline_type == PipelineType.TRAIN_PREFETCH_SPARSE_DIST: + multipass_prefetch_max_pass = multipass_prefetch_max_pass or 1 + pipelining_hbm_input_factor = 3 + prefetch_bursty_hbm_input_factor = 1 + 6 / multipass_prefetch_max_pass + return ( + pipelining_hbm_input_factor * input_size + + int(prefetch_bursty_hbm_input_factor * prefetch_size) + + output_contribition_to_peak_memory + ) + + # Catch all case, for backward compatibility + return input_size + output_size + + def calculate_shard_storages( sharder: ModuleSharder[nn.Module], sharding_type: str, @@ -636,6 +1117,12 @@ def calculate_shard_storages( num_poolings: List[float], caching_ratio: float, is_pooled: bool, + input_data_type_size: float, + output_data_type_size: float, + pipeline_type: PipelineType = PipelineType.NONE, + count_ephemeral_storage_cost: bool = False, + is_inference: bool = False, + multipass_prefetch_max_pass: Optional[int] = None, ) -> List[Storage]: """ Calculates estimated storage sizes for each sharded tensor, comprised of input, @@ -656,16 +1143,16 @@ def calculate_shard_storages( num_poolings (List[float]): average number of poolings per sample (typically 1.0). caching_ratio (float): ratio of HBM to DDR memory for UVM caching. - is_pooled (bool): True if embedding output is pooled (ie. EmbeddingBag), False - if unpooled/sequential (ie. Embedding). + is_pooled (bool): True if embedding output is pooled (ie. `EmbeddingBag`), False + if unpooled/sequential (ie. `Embedding`). + input_data_type_size (int): number of bytes of input data type. + output_data_type_size (int): number of bytes of output data type. + pipeline_type: PipelineType: pipeline type if for training. + is_inference: bool, whether the model is for inference. Returns: List[Storage]: storage object for each device in topology. 
""" - - input_data_type_size = BIGINT_DTYPE - output_data_type_size = tensor.element_size() - input_sizes, output_sizes = _calculate_shard_io_sizes( sharding_type=sharding_type, batch_sizes=batch_sizes, @@ -684,13 +1171,18 @@ def calculate_shard_storages( hbm_storage: int = tensor_storage.get("hbm", 0) ddr_storage: int = tensor_storage.get("ddr", 0) + table_cached: bool = False if compute_kernel in { EmbeddingComputeKernel.FUSED_UVM_CACHING.value, EmbeddingComputeKernel.QUANT_UVM_CACHING.value, + EmbeddingComputeKernel.KEY_VALUE.value, }: hbm_storage = round(ddr_storage * caching_ratio) + table_cached = True + if compute_kernel in {EmbeddingComputeKernel.KEY_VALUE.value}: + ddr_storage = 0 - optimizer_class = getattr(tensor, "_optimizer_class", None) + optimizer_class = getattr(tensor, "_optimizer_classes", [None])[0] hbm_specific_sizes: List[int] = _calculate_storage_specific_sizes( storage=hbm_storage, @@ -698,6 +1190,8 @@ def calculate_shard_storages( shard_sizes=shard_sizes, sharding_type=sharding_type, optimizer_class=optimizer_class, + is_inference=is_inference, + clf=caching_ratio if table_cached else None, ) ddr_specific_sizes: List[int] = _calculate_storage_specific_sizes( storage=ddr_storage, @@ -705,10 +1199,24 @@ def calculate_shard_storages( shard_sizes=shard_sizes, sharding_type=sharding_type, optimizer_class=optimizer_class, + is_inference=is_inference, ) hbm_sizes: List[int] = [ - input_size + output_size + hbm_specific_size if compute_device == "cuda" else 0 + ( + hbm_specific_size + + calculate_pipeline_io_cost( + input_size=input_size, + output_size=output_size, + prefetch_size=input_size if table_cached else 0, + pipeline_type=pipeline_type, + multipass_prefetch_max_pass=multipass_prefetch_max_pass, + count_ephemeral_storage_cost=count_ephemeral_storage_cost, + is_inference=is_inference, + ) + if compute_device == "cuda" + else 0 + ) for input_size, output_size, hbm_specific_size in zip( input_sizes, output_sizes, @@ -716,9 +1224,11 @@ def calculate_shard_storages( ) ] ddr_sizes: List[int] = [ - input_size + output_size + ddr_specific_size - if compute_device == "cpu" - else ddr_specific_size + ( + input_size + output_size + ddr_specific_size + if compute_device in {"cpu", "mtia"} and not is_inference + else ddr_specific_size + ) for input_size, output_size, ddr_specific_size in zip( input_sizes, output_sizes, @@ -743,8 +1253,8 @@ def _calculate_shard_io_sizes( input_lengths: List[float], emb_dim: int, shard_sizes: List[List[int]], - input_data_type_size: int, - output_data_type_size: int, + input_data_type_size: float, + output_data_type_size: float, num_poolings: List[float], is_pooled: bool, ) -> Tuple[List[int], List[int]]: @@ -795,7 +1305,10 @@ def _calculate_shard_io_sizes( num_poolings=num_poolings, is_pooled=is_pooled, ) - elif sharding_type == ShardingType.TABLE_ROW_WISE.value: + elif ( + sharding_type == ShardingType.TABLE_ROW_WISE.value + or sharding_type == ShardingType.GRID_SHARD.value # same as table row wise + ): return _calculate_twrw_shard_io_sizes( batch_sizes=batch_sizes, world_size=world_size, @@ -818,8 +1331,8 @@ def _calculate_dp_shard_io_sizes( input_lengths: List[float], emb_dim: int, num_shards: int, - input_data_type_size: int, - output_data_type_size: int, + input_data_type_size: float, + output_data_type_size: float, num_poolings: List[float], is_pooled: bool, ) -> Tuple[List[int], List[int]]: @@ -845,8 +1358,8 @@ def _calculate_tw_shard_io_sizes( world_size: int, input_lengths: List[float], emb_dim: int, - input_data_type_size: int, - 
output_data_type_size: int, + input_data_type_size: float, + output_data_type_size: float, num_poolings: List[float], is_pooled: bool, ) -> Tuple[List[int], List[int]]: @@ -872,8 +1385,8 @@ def _calculate_cw_shard_io_sizes( world_size: int, input_lengths: List[float], shard_sizes: List[List[int]], - input_data_type_size: int, - output_data_type_size: int, + input_data_type_size: float, + output_data_type_size: float, num_poolings: List[float], is_pooled: bool, ) -> Tuple[List[int], List[int]]: @@ -904,8 +1417,8 @@ def _calculate_rw_shard_io_sizes( world_size: int, input_lengths: List[float], shard_sizes: List[List[int]], - input_data_type_size: int, - output_data_type_size: int, + input_data_type_size: float, + output_data_type_size: float, num_poolings: List[float], is_pooled: bool, ) -> Tuple[List[int], List[int]]: @@ -920,17 +1433,21 @@ def _calculate_rw_shard_io_sizes( ) input_sizes = [ - math.ceil(batch_inputs * world_size * input_data_type_size) - if prod(shard) != 0 - else 0 + ( + math.ceil(batch_inputs * world_size * input_data_type_size) + if prod(shard) != 0 + else 0 + ) for shard in shard_sizes ] output_sizes = [ - math.ceil( - batch_outputs * world_size * shard_sizes[i][1] * output_data_type_size + ( + math.ceil( + batch_outputs * world_size * shard_sizes[i][1] * output_data_type_size + ) + if prod(shard) != 0 + else 0 ) - if prod(shard) != 0 - else 0 for i, shard in enumerate(shard_sizes) ] @@ -943,8 +1460,8 @@ def _calculate_twrw_shard_io_sizes( local_world_size: int, input_lengths: List[float], shard_sizes: List[List[int]], - input_data_type_size: int, - output_data_type_size: int, + input_data_type_size: float, + output_data_type_size: float, num_poolings: List[float], is_pooled: bool, ) -> Tuple[List[int], List[int]]: @@ -959,17 +1476,21 @@ def _calculate_twrw_shard_io_sizes( ) input_sizes = [ - math.ceil(batch_inputs * world_size * input_data_type_size) - if prod(shard) != 0 - else 0 + ( + math.ceil(batch_inputs * world_size * input_data_type_size) + if prod(shard) != 0 + else 0 + ) for shard in shard_sizes ] output_sizes = [ - math.ceil( - batch_outputs * world_size * shard_sizes[i][1] * output_data_type_size + ( + math.ceil( + batch_outputs * world_size * shard_sizes[i][1] * output_data_type_size + ) + if prod(shard) != 0 + else 0 ) - if prod(shard) != 0 - else 0 for i, shard in enumerate(shard_sizes) ] @@ -982,11 +1503,15 @@ def _calculate_storage_specific_sizes( shard_sizes: List[List[int]], sharding_type: str, optimizer_class: Optional[Type[torch.optim.Optimizer]] = None, + is_inference: bool = False, + clf: Optional[float] = None, ) -> List[int]: tensor_sizes: List[int] = [ - math.ceil(storage * prod(size) / prod(shape)) - if sharding_type != ShardingType.DATA_PARALLEL.value - else storage + ( + math.ceil(storage * prod(size) / prod(shape)) + if sharding_type != ShardingType.DATA_PARALLEL.value + else storage + ) for size in shard_sizes ] optimizer_multipler: float = _get_optimizer_multipler(optimizer_class, shape) @@ -995,9 +1520,24 @@ def _calculate_storage_specific_sizes( math.ceil(tensor_size * optimizer_multipler) for tensor_size in tensor_sizes ] + # If a table has turned on UVM caching (meaning clf is not None), there'll be + # 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to + # cache aux state (note that this is not the cache content itself) + cache_aux_state_sizes: List[int] = ( + [0] * len(shard_sizes) + if clf is None + else [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes] + ) + return [ - tensor_size + 
optimizer_size - for tensor_size, optimizer_size in zip(tensor_sizes, optimizer_sizes) + ( + cache_state_size + tensor_size + optimizer_size + if not is_inference + else tensor_size + ) + for cache_state_size, tensor_size, optimizer_size in zip( + cache_aux_state_sizes, tensor_sizes, optimizer_sizes + ) ] @@ -1015,3 +1555,85 @@ def _get_optimizer_multipler( return 1 / shape[-1] else: return 1 + + +class EmbeddingOffloadStats(CacheStatistics): + """Computes cache statistics for uvm_fused_cache tables. + + Args: + + cachebility (float): + The area-under-the-curve of miss-ratio curve. + expected_lookups (float): + The expected number of unique embedding ids per global batch. + mrc_hist_counts (torch.Tensor): + A 1d tensor (size n) holding a histogram of LRU miss ratio curve. Each bin + represents 1/nth of possible LRU cache sizes (from load_factor 0 to load_factor + 1.0). The bin contains the number of expected LRU operations that could be + handled without a cache miss if the LRU load_factor was at least that size. + height (int): + The height (num_embeddings) of the embedding table. + """ + + def __init__( + self, + cacheability: float, + expected_lookups: int, + mrc_hist_counts: torch.Tensor, + height: int, + ) -> None: + self._cacheability = cacheability + self._expected_lookups = expected_lookups + self.height = height + + if mrc_hist_counts.dim() != 1: + raise ValueError(f"expected 1d tensor, got {mrc_hist_counts.dim()}d") + if mrc_hist_counts.size()[0] == 0: + raise ValueError("expected non-empty tensor") + + self.hist: torch.Tensor = mrc_hist_counts + self.bins: torch.Tensor = torch.linspace(0, height, len(mrc_hist_counts) + 1) + + @property + def expected_lookups(self) -> int: + return self._expected_lookups + + def expected_miss_rate(self, clf: float) -> float: + cache_size = torch.tensor(clf * self.height) + miss_rate = EmbeddingOffloadStats.estimate_cache_miss_rate( + cache_sizes=cache_size, hist=self.hist, bins=self.bins + ) + return miss_rate.item() + + @property + def cacheability(self) -> float: + return self._cacheability + + @staticmethod + def estimate_cache_miss_rate( + cache_sizes: torch.Tensor, hist: torch.Tensor, bins: torch.Tensor + ) -> torch.Tensor: + """Calculate estimated cache miss ratio for the proposed cache_sizes, given the MRC + histogram. + """ + ys = hist.cumsum(dim=0) + if ys[-1] == 0: + # feature has no usage data -> no cache misses + return torch.zeros_like(cache_sizes, dtype=torch.float32) + ys = ys / ys[-1] # rescale [0,1] + ys = 1 - ys # make miss-ratio, not hit-ratio + + # torch.bucketize has slightly different semantics to np.digitize, + # and np.digitize has a complex interface, read the docs carefully! + # we're trying to reverse the ops of np.histogram, indices are one larger than + # the insert positions, since with right=True, index returned such that x < + # bins[index], so x 'lives' in hist[index-1] + # A cache size of k will get hits for all stack distances of upto k-1 inclusive. + larger_bin_indices = torch.bucketize(cache_sizes - 1, bins, right=True) + # Augment ys to deal with torch.bucketize boundary conditions: + # values outside of bins range map to 0, or len(bins). + # So we extend ys to populate sentinel values for these cases. With the twist that + # the left-hand sentinel we put on the right side of the array, as larger_bin_indices - 1 + # maps 0 -> -1, which pytorch maps to most right hand value. 
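# [Editor's illustration, not part of this patch] A minimal worked example of the
# bucketize logic above, assuming hist = [50, 30, 15, 5] and height = 100:
# bins = [0, 25, 50, 75, 100] and the miss-ratio curve ys = [0.5, 0.2, 0.05, 0.0].
# A proposed cache of 25 rows (load factor 0.25) buckets into the first bin and is
# estimated to miss ~50% of lookups; 75 rows misses ~5%; 0 rows lands on the
# right-hand sentinel appended below and reports a 100% miss rate.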
+ ys = torch.cat((ys, torch.tensor([0.0, 1.0]))) + return ys[larger_bin_indices - 1] diff --git a/torchrec/distributed/planner/stats.py b/torchrec/distributed/planner/stats.py index af64e0856..430e5c916 100644 --- a/torchrec/distributed/planner/stats.py +++ b/torchrec/distributed/planner/stats.py @@ -5,28 +5,58 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + +import copy import logging +import math +import statistics from collections import defaultdict -from typing import Any, cast, Dict, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) + +from torch import nn -from torchrec.distributed.planner.constants import BIGINT_DTYPE +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.planner.constants import BIGINT_DTYPE, NUM_POOLINGS from torchrec.distributed.planner.shard_estimators import _calculate_shard_io_sizes from torchrec.distributed.planner.storage_reservations import ( - FixedPercentageReservation, + FixedPercentageStorageReservation, HeuristicalStorageReservation, InferenceStorageReservation, ) from torchrec.distributed.planner.types import ( + CriticalPathEstimate, ParameterConstraints, + Perf, ShardingOption, Stats, Storage, StorageReservation, Topology, ) -from torchrec.distributed.planner.utils import bytes_to_gb, bytes_to_mb -from torchrec.distributed.types import ParameterSharding, ShardingPlan, ShardingType - +from torchrec.distributed.planner.utils import ( + _find_imbalance_tables, + bytes_to_gb, + bytes_to_mb, + sharder_name as get_sharder_name, +) +from torchrec.distributed.types import ( + ModuleSharder, + ParameterSharding, + ShardingPlan, + ShardingType, +) logger: logging.Logger = logging.getLogger(__name__) @@ -34,6 +64,78 @@ MIN_WIDTH = 90 +def _normalize_float(p: List[float]) -> List[float]: + p_total = sum(p) + assert p_total > 0 + return [p_i / p_total for p_i in p] + + +def _normalize_int(p: List[int]) -> List[float]: + p_total = sum(p) + assert p_total > 0 + return [p_i * 1.0 / p_total for p_i in p] + + +def _total_variation(p: List[float]) -> float: + k = len(p) + assert k > 0 + return max(abs(pi - 1.0 / k) for pi in p) + + +def _total_distance(p: List[float]) -> float: + k = len(p) + assert k > 0 + return sum(abs(pi - 1.0 / k) for pi in p) + + +def _chi_sq_divergence(p: List[float]) -> float: + k = len(p) + assert k > 0 + return sum(abs(pi - 1.0 / k) ** 2.0 * k for pi in p) + + +def _kl_divergence(p: List[float]) -> float: + k = len(p) + assert k > 0 + return sum(pi * math.log(k * pi) for pi in p if pi > 0) + + +def _calc_max_chi_sq_divergence(N: int) -> float: + # Upper bound for chi-sq divergence in our case given sample size of distribution (N) + assert N > 0 + return (((N - 1) / N) ** 2.0) * N + (N - 1) * (1 / N) + + +def _calc_max_kl_divergence(N: int) -> float: + # Upper bound for KL divergence in our case given sample size of distribution (N) + assert N > 0 + return math.log(N) + + +def _normalized_kl_divergence(p: List[float]) -> float: + k = len(p) + assert k > 0 + # Max val can be 0 if world size is 1 (e.g. local run) + max_val = _calc_max_kl_divergence(k) + return _kl_divergence(p) / max_val if max_val > 0 else 0.0 + + +def _normalized_chi_sq_divergence(p: List[float]) -> float: + k = len(p) + assert k > 0 + # Max val can be 0 if world size is 1 (e.g. 
local run) + max_val = _calc_max_chi_sq_divergence(k) + return _chi_sq_divergence(p) / max_val if max_val > 0 else 0.0 + + +IMBALANCE_STAT_MEASURE: Dict[str, Tuple[Callable[..., float], Dict[str, Any]]] = { + "Total Variation": (_total_variation, {}), + "Total Distance": (_total_distance, {}), + "Chi Divergence": (_normalized_chi_sq_divergence, {}), + "KL Divergence": (_normalized_kl_divergence, {}), +} + + class EmbeddingStats(Stats): """ Stats for a sharding planner execution. @@ -54,6 +156,7 @@ def log( run_time: float, best_plan: List[ShardingOption], constraints: Optional[Dict[str, ParameterConstraints]] = None, + sharders: Optional[List[ModuleSharder[nn.Module]]] = None, debug: bool = True, ) -> None: """ @@ -66,7 +169,6 @@ def log( sharding_plan (ShardingPlan): sharding plan chosen by the planner. topology (Topology): device topology. batch_size (int): batch size. - storage_constraint (Topology): available storage after storage reservation. storage_reservation (StorageReservation): reserves storage for unsharded parts of the model num_proposals (int): number of proposals evaluated. @@ -81,6 +183,7 @@ def log( shard_by_fqn = { module_name + "." + param_name: value for module_name, param_dict in sharding_plan.plan.items() + # pyre-ignore - this is a EmbeddingShardingPlan below for param_name, value in param_dict.items() } stats: Dict[int, Dict[str, Any]] = { @@ -90,43 +193,23 @@ def log( used_sharding_types = set() compute_kernels_to_count = defaultdict(int) + compute_kernels_to_storage = defaultdict(lambda: Storage(0, 0)) - reserved_percent = ( - storage_reservation._percentage - if isinstance( - storage_reservation, - ( - FixedPercentageReservation, - HeuristicalStorageReservation, - InferenceStorageReservation, - ), - ) - else 0.0 - ) - dense_storage = ( - storage_reservation._dense_storage - if isinstance( - storage_reservation, - (HeuristicalStorageReservation, InferenceStorageReservation), - ) - and storage_reservation._dense_storage is not None - else Storage(0, 0) - ) - assert dense_storage - kjt_storage = ( - storage_reservation._kjt_storage - if isinstance( - storage_reservation, - (HeuristicalStorageReservation, InferenceStorageReservation), - ) - and storage_reservation._kjt_storage - else Storage(0, 0) + reserved_hbm_percent, dense_storage, kjt_storage = _compute_storage( + storage_reservation=storage_reservation ) - assert kjt_storage for sharding_option in best_plan: fqn = sharding_option.fqn + compute_kernels_to_count[sharding_option.compute_kernel] += 1 + compute_kernels_to_storage[ + sharding_option.compute_kernel + ] += sharding_option.total_storage + + # for shard in sharding_option.shards: + # compute_kernels_to_storage[sharding_option.compute_kernel] += shard.hbm + if shard_by_fqn.get(fqn) is None: continue shard: ParameterSharding = shard_by_fqn[fqn] @@ -140,7 +223,6 @@ def log( ) sharding_type_abbr = _get_sharding_type_abbr(shard.sharding_type) used_sharding_types.add(sharding_type_abbr) - compute_kernels_to_count[sharding_option.compute_kernel] += 1 for i, rank in enumerate(ranks): count = stats[rank]["type"].get(sharding_type_abbr, 0) @@ -148,135 +230,30 @@ def log( stats[rank]["input_sizes"] += input_sizes[i] stats[rank]["output_sizes"] += output_sizes[i] - used_hbm = [0] * topology.world_size - used_ddr = [0] * topology.world_size - perf = [0.0] * topology.world_size - for sharding_option in best_plan: - for shard in sharding_option.shards: - shard_storage = cast(Storage, shard.storage) - rank = cast(int, shard.rank) - used_hbm[rank] += shard_storage.hbm - 
used_ddr[rank] += shard_storage.ddr - perf[rank] += cast(float, shard.perf) - - used_hbm = [hbm + dense_storage.hbm + kjt_storage.hbm for hbm in used_hbm] - used_ddr = [ddr + dense_storage.ddr + kjt_storage.ddr for ddr in used_ddr] - - table: List[List[Union[str, int]]] = [ - [ - "Rank", - "HBM (GB)", - "DDR (GB)", - "Perf (ms)", - "Input (MB)", - "Output (MB)", - "Shards", - ], - [ - "------", - "----------", - "----------", - "-----------", - "------------", - "-------------", - "--------", - ], - ] - - for rank, device in enumerate(topology.devices): - used_hbm_gb = bytes_to_gb(used_hbm[rank]) - used_hbm_ratio = ( - used_hbm[rank] / ((1 - reserved_percent) * device.storage.hbm) - if topology.compute_device == "cuda" - else 0 - ) - used_ddr_gb = bytes_to_gb(used_ddr[rank]) - used_ddr_ratio = ( - used_ddr[rank] / ((1 - reserved_percent) * device.storage.ddr) - if device.storage.ddr > 0 - else 0 - ) - for sharding_type in used_sharding_types: - if sharding_type not in stats[rank]["type"]: - stats[rank]["type"][sharding_type] = 0 + used_hbm, used_ddr, perf = _compute_mem_usage_and_perf( + topology=topology, + best_plan=best_plan, + dense_storage=dense_storage, + kjt_storage=kjt_storage, + ) - rank_hbm = f"{round(used_hbm_gb, 1)} ({used_hbm_ratio:.0%})" - rank_ddr = f"{round(used_ddr_gb, 1)} ({used_ddr_ratio:.0%})" - rank_perf = f"{round(perf[rank], 3)}" - rank_input = f"{round(stats[rank]['input_sizes'], 2)}" - rank_output = f"{round(stats[rank]['output_sizes'], 2)}" - rank_shards = " ".join( - f"{sharding_type}: {num_tables}" - for sharding_type, num_tables in sorted(stats[rank]["type"].items()) - ) - table.append( - [ - rank, - rank_hbm, - rank_ddr, - rank_perf, - rank_input, - rank_output, - rank_shards, - ] - ) - formatted_table = _format_table(table) - self._width = max(self._width, len(formatted_table[0]) + 8) + formatted_table = self._log_rank_mem_usage_and_perf( + topology=topology, + used_hbm=used_hbm, + used_ddr=used_ddr, + perf=perf, + stats=stats, + used_sharding_types=used_sharding_types, + reserved_hbm_percent=reserved_hbm_percent, + ) if debug: - param_table: List[List[Union[str, int]]] = [ - [ - "FQN", - "Sharding", - "Compute Kernel", - "Perf (ms)", - "Pooling Factor", - "Output", - "Features", - "Emb Dim", - "Hash Size", - "Ranks", - ], - [ - "-----", - "----------", - "----------------", - "-----------", - "----------------", - "--------", - "----------", - "--------", - "-----------", - "-------", - ], - ] - for so in best_plan: - ranks = sorted([cast(int, shard.rank) for shard in so.shards]) - ranks = _collapse_consecutive_ranks(ranks) - shard_perfs = str( - round(sum([cast(float, shard.perf) for shard in so.shards]), 3) - ) - pooling_factor = str(round(sum(so.input_lengths), 3)) - output = "pooled" if so.is_pooled else "sequence" - num_features = len(so.input_lengths) - embedding_dim = so.tensor.shape[1] - hash_size = so.tensor.shape[0] - param_table.append( - [ - so.fqn, - _get_sharding_type_abbr(so.sharding_type), - so.compute_kernel, - shard_perfs, - pooling_factor, - output, - num_features, - embedding_dim, - hash_size, - ",".join(ranks), - ] - ) - formatted_param_table = _format_table(param_table) - self._width = max(self._width, len(formatted_param_table[0]) + 6) + formatted_param_table = self._log_sharding_plan( + best_plan=best_plan, + sharding_plan=sharding_plan, + sharders=sharders, + constraints=constraints, + ) self._stats_table.clear() self._stats_table.append("#" * self._width) @@ -284,7 +261,7 @@ def log( self._stats_table.append(f"#{header_text: 
^{self._width-2}}#") iter_text = ( - f"--- Evalulated {num_proposals} proposal(s), " + f"--- Evaluated {num_proposals} proposal(s), " f"found {num_plans} possible plan(s), " f"ran for {run_time:.2f}s ---" ) @@ -293,14 +270,19 @@ def log( divider = "-" * (self._width - 4) self._stats_table.append(f"#{divider: ^{self._width-2}}#") - for row in formatted_table: - self._stats_table.append(f"# {row: <{self._width-3}}#") + if sharding_plan.plan: + for row in formatted_table: + self._stats_table.append(f"# {row: <{self._width-3}}#") - legend = "Input: MB/iteration, Output: MB/iteration, Shards: number of tables" - hbm_info = "HBM: estimated peak memory usage for shards, dense tensors, and features (KJT)" - self._stats_table.append(f"#{'' : ^{self._width-2}}#") - self._stats_table.append(f"# {legend: <{self._width-3}}#") - self._stats_table.append(f"# {hbm_info: <{self._width-3}}#") + perf_breakdown = "Perf: Total perf (Forward compute, Forward comms, Backward compute, Backward comms, Prefetch compute)" + legend = ( + "Input: MB/iteration, Output: MB/iteration, Shards: number of tables" + ) + hbm_info = "HBM: estimated peak memory usage for shards, dense tensors, and features (KJT)" + self._stats_table.append(f"#{'' : ^{self._width-2}}#") + self._stats_table.append(f"# {perf_breakdown: <{self._width-3}}#") + self._stats_table.append(f"# {legend: <{self._width-3}}#") + self._stats_table.append(f"# {hbm_info: <{self._width-3}}#") if debug: self._stats_table.append(f"#{'' : ^{self._width-2}}#") @@ -312,17 +294,42 @@ def log( self._stats_table.append(f"#{'' : ^{self._width-2}}#") self._stats_table.append(f"# {batch_size_text : <{self._width-3}}#") - self._log_compute_kernel_stats(compute_kernels_to_count) + if not sharding_plan.plan: + rank_size_text = f"World Size: {topology.world_size}" + self._stats_table.append(f"#{'' : ^{self._width-2}}#") + self._stats_table.append(f"# {rank_size_text : <{self._width-3}}#") + + self._log_compute_kernel_stats( + compute_kernels_to_count, description="Compute Kernels Count" + ) + self._log_compute_kernel_stats( + { + k: f"HBM: {round(bytes_to_gb(s.hbm),3)} GB, DDR: {round(bytes_to_gb(s.ddr),3)} GB" + for k, s in compute_kernels_to_storage.items() + }, + description="Compute Kernels Storage", + ) if debug: - self._log_max_perf_and_max_hbm(perf, used_hbm) + if sharding_plan.plan: + # Plan imbalance stats for perf and storage + self._log_plan_imbalance_stats( + perf, + used_hbm, + used_ddr, + ) + + # Max perf and HBM to help root cause imbalance + self._log_max_perf_and_max_hbm(perf, used_hbm, best_plan) self._log_storage_reservation_stats( storage_reservation, topology, - reserved_percent, + reserved_hbm_percent, dense_storage, kjt_storage, ) + if sharding_plan.plan: + self._log_imbalance_tables(best_plan) self._stats_table.append("#" * self._width) @@ -386,48 +393,214 @@ def _get_shard_stats( return ranks, input_sizes, output_sizes - def _log_max_perf_and_max_hbm(self, perf: List[float], used_hbm: List[int]) -> None: - max_perf = max(perf) - max_perf_indices = [i for i in range(len(perf)) if perf[i] == max_perf] - rank_text = "ranks" if len(max_perf_indices) > 1 else "rank" - max_perf_indices = _collapse_consecutive_ranks(max_perf_indices) - max_perf_ranks = f"{rank_text} {','.join(max_perf_indices)}" - longest_critical_path = ( - f"Longest Critical Path: {round(max_perf, 3)} ms on {max_perf_ranks}" + def _log_dist_imbalance_stats( + self, + normalized_dist: List[float], + ) -> None: + for name, (measure, kwargs) in IMBALANCE_STAT_MEASURE.items(): + result_txt = 
f"{name}: {measure(normalized_dist, **kwargs):.3f}" + self._stats_table.append(f"# {result_txt : <{self._width-3}}#") + + def _log_plan_imbalance_stats( + self, perf: List[Perf], used_hbm: List[int], used_ddr: List[int] + ) -> None: + imbalance_logged = False + total_perfs = [perf_i.total for perf_i in perf] + + # Can extend with fwd/bwd perfs if needed + perf_dists = [ + ("Total", total_perfs), + ] + + for name, perf_dist in perf_dists: + if sum(perf_dist) > 0: + imbalance_logged = True + self._stats_table.append(f"#{'' : ^{self._width-2}}#") + self._stats_table.append( + f"# {name + ' Perf Imbalance Statistics' : <{self._width-3}}#" + ) + normalized_perf_dist = _normalize_float(perf_dist) + self._log_dist_imbalance_stats(normalized_perf_dist) + + if sum(used_hbm) > 0: + imbalance_logged = True + self._stats_table.append(f"#{'' : ^{self._width-2}}#") + self._stats_table.append( + f"# {'HBM Imbalance Statistics' : <{self._width-3}}#" + ) + normalized_used_hbm = _normalize_int(used_hbm) + self._log_dist_imbalance_stats(normalized_used_hbm) + + if sum(used_ddr) > 0: + imbalance_logged = True + self._stats_table.append(f"#{'' : ^{self._width-2}}#") + self._stats_table.append( + f"# {'DDR Imbalance Statistics' : <{self._width-3}}#" + ) + normalized_used_ddr = _normalize_int(used_ddr) + self._log_dist_imbalance_stats(normalized_used_ddr) + + if imbalance_logged: + self._stats_table.append(f"#{'' : ^{self._width-2}}#") + self._stats_table.append( + f"# {'Imbalance stats range 0-1, higher means more imbalanced' : <{self._width-3}}#" + ) + + def _log_max_perf_and_max_hbm( + self, perfs: List[Perf], used_hbm: List[int], best_plan: List[ShardingOption] + ) -> None: + total_perfs = [perf.total for perf in perfs] + + max_total_perf_text = ( + f"Maximum of Total Perf: {_generate_max_text(total_perfs)}" + ) + + mean_total_perf = statistics.mean(total_perfs) + mean_total_perf_text = f"Mean Total Perf: {round(mean_total_perf,3)} ms" + + max_total_perf = max(total_perfs) + + total_perf_delta_pct = 0.0 + if mean_total_perf > 0.0: + total_perf_delta_pct = ( + (max_total_perf - mean_total_perf) / mean_total_perf * 100 + ) + + total_perf_delta_text = ( + f"Max Total Perf is {total_perf_delta_pct:.3g}% greater than the mean" + ) + + max_fwd_compute_perf_text = f"Maximum of Forward Compute: {_generate_max_text([perf.fwd_compute for perf in perfs])}" + max_fwd_comms_perf_text = f"Maximum of Forward Comms: {_generate_max_text([perf.fwd_comms for perf in perfs])}" + max_bwd_compute_perf_text = f"Maximum of Backward Compute: {_generate_max_text([perf.bwd_compute for perf in perfs])}" + max_bwd_comms_perf_text = f"Maximum of Backward Comms: {_generate_max_text([perf.bwd_comms for perf in perfs])}" + max_prefetch_compute_perf_text = f"Maximum of Prefetch Compute: {_generate_max_text([perf.prefetch_compute for perf in perfs])}" + + sum_of_maxima = ( + max(perf.fwd_compute for perf in perfs) + + max(perf.fwd_comms for perf in perfs) + + max(perf.bwd_compute for perf in perfs) + + max(perf.bwd_comms for perf in perfs) + + max(perf.prefetch_compute for perf in perfs) + ) + sum_of_maxima_text = f"Sum of Maxima: {round(sum_of_maxima, 3)} ms" + + critical_path_estimate = _calculate_critical_path(best_plan) + + self._stats_table.append(f"#{'' : ^{self._width-2}}#") + self._stats_table.append(f"# {max_total_perf_text : <{self._width-3}}#") + self._stats_table.append(f"# {mean_total_perf_text : <{self._width-3}}#") + self._stats_table.append(f"# {total_perf_delta_text : <{self._width-3}}#") + self._stats_table.append(f"# 
{max_fwd_compute_perf_text : <{self._width-3}}#") + self._stats_table.append(f"# {max_fwd_comms_perf_text : <{self._width-3}}#") + self._stats_table.append(f"# {max_bwd_compute_perf_text : <{self._width-3}}#") + self._stats_table.append(f"# {max_bwd_comms_perf_text : <{self._width-3}}#") + self._stats_table.append( + f"# {max_prefetch_compute_perf_text : <{self._width-3}}#" ) + self._stats_table.append(f"# {sum_of_maxima_text : <{self._width-3}}#") + self._stats_table.append(f"#{'' : ^{self._width-2}}#") - self._stats_table.append(f"# {longest_critical_path : <{self._width-3}}#") - - max_hbm = max(used_hbm) - max_hbm_indices = [i for i in range(len(used_hbm)) if used_hbm[i] == max_hbm] - rank_text = "ranks" if len(max_hbm_indices) > 1 else "rank" - max_hbm_indices = _collapse_consecutive_ranks(max_hbm_indices) - max_hbm_ranks = f"{rank_text} {','.join(max_hbm_indices)}" - peak_memory_pressure = f"Peak Memory Pressure: {round(bytes_to_gb(max_hbm), 3)} GB on {max_hbm_ranks}" + self._stats_table.append( + f"# {'Estimated Sharding Distribution' : <{self._width-2}}#" + ) + self._stats_table.append( + f"# {'Max HBM: '+_generate_rank_hbm_stats(used_hbm, max) : <{self._width-3}}#" + ) + self._stats_table.append( + f"# {'Min HBM: '+_generate_rank_hbm_stats(used_hbm, min) : <{self._width-3}}#" + ) + self._stats_table.append( + f"# {'Mean HBM: '+_generate_rank_hbm_stats(used_hbm, statistics.mean) : <{self._width-3}}#" + ) + self._stats_table.append( + f"# {'Low Median HBM: '+_generate_rank_hbm_stats(used_hbm, statistics.median_low) : <{self._width-3}}#" + ) + self._stats_table.append( + f"# {'High Median HBM: '+_generate_rank_hbm_stats(used_hbm, statistics.median_high) : <{self._width-3}}#" + ) + self._stats_table.append( + f"# {'Critical Path (comms): '+str(round(critical_path_estimate.comms_estimate, 3)) : <{self._width-3}}#" + ) + self._stats_table.append( + f"# {'Critical Path (compute): '+str(round(critical_path_estimate.comp_estimate, 3)) : <{self._width-3}}#" + ) + self._stats_table.append( + f"# {'Critical Path (comms + compute): '+str(round(critical_path_estimate.total(), 3)) : <{self._width-3}}#" + ) + + max_used_hbm = max(used_hbm) + mean_used_hbm = statistics.mean(used_hbm) + hbm_delta_pct = 0.0 + if mean_used_hbm > 0.0: + hbm_delta_pct = (max_used_hbm - mean_used_hbm) / mean_used_hbm * 100 + hbm_delta_text = f"Max HBM is {hbm_delta_pct:.3g}% greater than the mean" + self._stats_table.append(f"# {hbm_delta_text : <{self._width-3}}#") + self._stats_table.append(f"#{'' : ^{self._width-2}}#") - self._stats_table.append(f"# {peak_memory_pressure : <{self._width-3}}#") + per_rank_hbm = copy.copy(used_hbm) + NUM_PEAK_RANK = 5 + peak_memory_pressure = [] + + top_hbm_usage_estimation = f"Top HBM Memory Usage Estimation: {round(bytes_to_gb(max(used_hbm)), 3)} GB" + self._stats_table.append(f"#{'' : ^{self._width-2}}#") + self._stats_table.append(f"# {top_hbm_usage_estimation : <{self._width-3}}#") + + for top in range(NUM_PEAK_RANK): + if not per_rank_hbm: + break + max_hbm = max(per_rank_hbm) + max_hbm_indices = [ + i + for i in range(len(per_rank_hbm)) + if math.isclose( + bytes_to_mb(per_rank_hbm[i]), bytes_to_mb(max_hbm), abs_tol=1.0 + ) + ] + rank_text = "ranks" if len(max_hbm_indices) > 1 else "rank" + max_hbm_indices = _collapse_consecutive_ranks(max_hbm_indices) + max_hbm_ranks = f"{rank_text} {','.join(max_hbm_indices)}" + peak_memory_pressure.append( + f"Top Tier #{top+1} Estimated Peak HBM Pressure: {round(bytes_to_gb(max_hbm), 3)} GB on {max_hbm_ranks}" + ) + per_rank_hbm = [ + hbm + 
for hbm in per_rank_hbm + if not math.isclose(bytes_to_mb(hbm), bytes_to_mb(max_hbm), abs_tol=1.0) + ] + + for peak_rank in reversed(peak_memory_pressure): + self._stats_table.append(f"# {peak_rank : <{self._width-3}}#") def _log_storage_reservation_stats( self, storage_reservation: StorageReservation, topology: Topology, - reserved_percent: float, + reserved_hbm_percent: float, dense_storage: Storage, kjt_storage: Storage, ) -> None: device_storage = topology.devices[0].storage usable_hbm = round( - bytes_to_gb(int((1 - reserved_percent) * device_storage.hbm)), 3 + bytes_to_gb(int((1 - reserved_hbm_percent) * device_storage.hbm)), 3 ) - usable_ddr = round( - bytes_to_gb(int((1 - reserved_percent) * device_storage.ddr)), 3 + reserved_hbm = round( + bytes_to_gb(int(reserved_hbm_percent * device_storage.hbm)), 3 ) + reserved_memory = f"HBM: {reserved_hbm} GB" + reserved_hbm_percentage = f"Percent of Total HBM: {reserved_hbm_percent:.0%}" + usable_ddr = round(bytes_to_gb(int(device_storage.ddr)), 3) usable_memory = f"HBM: {usable_hbm} GB, DDR: {usable_ddr} GB" - usable_percentage = f"Percent of Total: {(1 - reserved_percent):.0%}" + usable_hbm_percentage = ( + f"Percent of Total HBM: {(1 - reserved_hbm_percent):.0%}" + ) self._stats_table.append(f"#{'' : ^{self._width-2}}#") - self._stats_table.append(f"# {'Usable Memory:' : <{self._width-3}}#") + self._stats_table.append(f"# {'Reserved Memory:' : <{self._width-3}}#") + self._stats_table.append(f"# {reserved_memory : <{self._width-6}}#") + self._stats_table.append(f"# {reserved_hbm_percentage : <{self._width-6}}#") + self._stats_table.append(f"# {'Planning Memory:' : <{self._width-3}}#") self._stats_table.append(f"# {usable_memory : <{self._width-6}}#") - self._stats_table.append(f"# {usable_percentage : <{self._width-6}}#") + self._stats_table.append(f"# {usable_hbm_percentage : <{self._width-6}}#") if isinstance(storage_reservation, HeuristicalStorageReservation): dense_hbm = round(bytes_to_gb(dense_storage.hbm), 3) @@ -448,18 +621,298 @@ def _log_storage_reservation_stats( ) self._stats_table.append(f"# {kjt_storage_text : <{self._width-6}}#") + def _log_imbalance_tables(self, best_plan: List[ShardingOption]) -> None: + self._stats_table.append(f"#{'' : ^{self._width-2}}#") + perf_imbalance_tables = _find_imbalance_tables(best_plan) + hbm_imbalance_tables = _find_imbalance_tables(best_plan, target_imbalance="hbm") + self._stats_table.append( + f"# {'Top 5 Tables Causing Max Perf:' : <{self._width-3}}#" + ) + for sharding_option in perf_imbalance_tables[0:5]: + self._stats_table.append(f"# {sharding_option.name : <{self._width-6}}#") + self._stats_table.append( + f"# {'Top 5 Tables Causing Max HBM:' : <{self._width-3}}#" + ) + for sharding_option in hbm_imbalance_tables[0:5]: + storage = sharding_option.shards[0].storage + assert storage is not None # linter friendly optional check + + rank_text = "ranks" if len(sharding_option.shards) > 1 else "rank" + top_table = ( + f"{sharding_option.name}: {round(bytes_to_gb(storage.hbm),3)} GB on {rank_text} " + f"{[shard.rank for shard in sharding_option.shards]}" + ) + self._stats_table.append(f"# {top_table : <{self._width-6}}#") + def _log_compute_kernel_stats( - self, compute_kernels_to_count: Dict[str, int] + self, compute_kernels_stats: Dict[str, Any], description: str ) -> None: compute_kernels_count = [ f"{compute_kernel}: {count}" - for compute_kernel, count in sorted(compute_kernels_to_count.items()) + for compute_kernel, count in sorted(compute_kernels_stats.items()) ] 
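# [Editor's illustration, not part of this patch] Each entry renders as
# "<kernel>: <value>", e.g. "fused: 8" when this helper is called with the count
# map, or "fused_uvm_caching: HBM: 1.2 GB, DDR: 3.4 GB" (hypothetical values)
# when called with the storage map built in log() above.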
self._stats_table.append(f"#{'' : ^{self._width-2}}#") - self._stats_table.append(f"# {'Compute Kernels:' : <{self._width-3}}#") + self._stats_table.append(f"# {description+':' : <{self._width-3}}#") for compute_kernel_count in compute_kernels_count: self._stats_table.append(f"# {compute_kernel_count : <{self._width-6}}#") + def _log_rank_mem_usage_and_perf( + self, + topology: Topology, + used_ddr: List[int], + used_hbm: List[int], + perf: List[Perf], + stats: Dict[int, Dict[str, Any]], + used_sharding_types: Set[str], + reserved_hbm_percent: float, + ) -> List[str]: + table: List[List[Union[str, int]]] = [ + [ + "Rank", + "HBM (GB)", + "DDR (GB)", + "Perf (ms)", + "Input (MB)", + "Output (MB)", + "Shards", + ], + [ + "------", + "----------", + "----------", + "-----------", + "------------", + "-------------", + "--------", + ], + ] + + for rank, device in enumerate(topology.devices): + used_hbm_gb = bytes_to_gb(used_hbm[rank]) + used_hbm_ratio = ( + used_hbm[rank] / ((1 - reserved_hbm_percent) * device.storage.hbm) + if topology.compute_device == "cuda" + and ((1 - reserved_hbm_percent) * device.storage.hbm) != 0 + else 0 + ) + used_ddr_gb = bytes_to_gb(used_ddr[rank]) + used_ddr_ratio = ( + used_ddr[rank] / device.storage.ddr if device.storage.ddr > 0 else 0 + ) + for sharding_type in used_sharding_types: + if sharding_type not in stats[rank]["type"]: + stats[rank]["type"][sharding_type] = 0 + + rank_hbm = f"{round(used_hbm_gb, 3)} ({used_hbm_ratio:.0%})" + rank_ddr = f"{round(used_ddr_gb, 3)} ({used_ddr_ratio:.0%})" + rank_perf = _format_perf_breakdown(perf[rank]) + rank_input = f"{round(stats[rank]['input_sizes'], 2)}" + rank_output = f"{round(stats[rank]['output_sizes'], 2)}" + rank_shards = " ".join( + f"{sharding_type}: {num_tables}" + for sharding_type, num_tables in sorted(stats[rank]["type"].items()) + ) + table.append( + [ + rank, + rank_hbm, + rank_ddr, + rank_perf, + rank_input, + rank_output, + rank_shards, + ] + ) + formatted_table = _format_table(table) + self._width = max(self._width, len(formatted_table[0]) + 8) + return formatted_table + + def _log_sharding_plan( + self, + best_plan: List[ShardingOption], + sharding_plan: ShardingPlan, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + sharders: Optional[List[ModuleSharder[nn.Module]]] = None, + ) -> List[str]: + def _get_embedding_dim(so: ShardingOption) -> str: + embedding_dim = ( + f"{so.tensor.shape[1]} ({so.shards[0].size[1]})" + if so.sharding_type == ShardingType.COLUMN_WISE.value + or so.sharding_type == ShardingType.TABLE_COLUMN_WISE.value + or so.sharding_type == ShardingType.GRID_SHARD.value + else f"{so.tensor.shape[1]}" + ) + return embedding_dim + + def _get_num_poolings( + constraints: Optional[Dict[str, ParameterConstraints]], so: ShardingOption + ) -> List[float]: + num_poolings = ( + cast(List[float], constraints[so.name].num_poolings) + if constraints + and constraints.get(so.name) + and constraints[so.name].num_poolings + else [NUM_POOLINGS] * len(so.input_lengths) + ) + return num_poolings + + def _get_cache_load_factor( + sharder: Optional[ModuleSharder[nn.Module]], so: ShardingOption + ) -> str: + sharder_cache_load_factor = ( + sharder.fused_params.get("cache_load_factor") # pyre-ignore[16] + if hasattr(sharder, "fused_params") and sharder.fused_params + else None + ) + cache_load_factor = "None" + # Surfacing cache load factor does not make sense if not using uvm caching. 
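# [Editor's illustration, not part of this patch] For example, a sharder built
# with fused_params={"cache_load_factor": 0.2} (hypothetical value) would surface
# "0.2" here for a table using the fused_uvm_caching kernel, unless the sharding
# option carries its own cache_load_factor, which takes precedence below.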
+ if so.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value: + cache_load_factor = str( + so.cache_load_factor + if so.cache_load_factor is not None + else sharder_cache_load_factor + ) + return cache_load_factor + + param_table = [ + [ + "FQN", + "Sharding", + "Compute Kernel", + "Perf (ms)", + "Storage (HBM, DDR)", + "Cache Load Factor", + "Sum Pooling Factor", + "Sum Num Poolings", + "Num Indices", + "Output", + "Weighted", + "Sharder", + "Features", + "Emb Dim (CW Dim)", + "Hash Size", + "Ranks", + ], + [ + "-----", # FQN + "----------", # Sharding + "----------------", # Compute Kernel + "-----------", # Perf (ms) + "--------------------", # Storage (HBM, DDR) + "-------------------", # Cache Load Factor + "--------------------", # Sum Pooling Factor + "------------------", # Sum Num Poolings + "-------------", # Num Indices + "--------", # Output + "----------", # Weighted + "---------", # Sharder + "----------", # Features + "------------------", # Emb Dim (CW Dim) + "-----------", # Hash Size + "-------", # Ranks + ], + ] + feat_batch_sizes = [ + ( + constraints[so.name].batch_sizes + if constraints and constraints.get(so.name) + else None + ) + for so in best_plan + ] + + sharder_map: Dict[str, ModuleSharder[nn.Module]] = { + get_sharder_name(sharder.module_type): sharder + # pyre-ignore - this is a ModuleSharder below + for sharder in sharders + if sharders + } + + if include_batch_sizes := any(feat_batch_sizes): + param_table[0].append("Batch Sizes") + param_table[1].append("-------------") + for i, so in enumerate(best_plan): + ranks = sorted([cast(int, shard.rank) for shard in so.shards]) + ranks = _collapse_consecutive_ranks(ranks) + + so_perf = Perf(fwd_compute=0, fwd_comms=0, bwd_compute=0, bwd_comms=0) + for shard in so.shards: + so_perf += cast(Perf, shard.perf) + + shard_perfs = _format_perf_breakdown(so_perf) + + so_storage = Storage(hbm=0, ddr=0) + for shard in so.shards: + so_storage += cast(Storage, shard.storage) + + shard_storages = _format_storage_breakdown(so_storage) + + pooling_factor = str(round(sum(so.input_lengths), 3)) + num_poolings = _get_num_poolings(constraints, so) + num_indices = str( + round(sum(x * y for x, y in zip(so.input_lengths, num_poolings)), 3) + ) + num_poolings = str(round(sum(num_poolings), 3)) + output = "pooled" if so.is_pooled else "sequence" + weighted = "weighted" if so.is_weighted else "unweighted" + sharder = sharder_map.get(get_sharder_name(type(so.module[1])), None) + sharder_name = type(sharder).__name__ + num_features = len(so.input_lengths) + embedding_dim = _get_embedding_dim(so) + cache_load_factor = _get_cache_load_factor(sharder, so) + hash_size = so.tensor.shape[0] + param_table.append( + # pyre-ignore[6] + [ + so.fqn, + _get_sharding_type_abbr(so.sharding_type), + so.compute_kernel, + shard_perfs, + shard_storages, + cache_load_factor, + pooling_factor, + num_poolings, + num_indices, + output, + weighted, + sharder_name, + num_features, + embedding_dim, + hash_size, + ",".join(ranks) if sharding_plan.plan else "None", + ] + ) + if include_batch_sizes: + bs = feat_batch_sizes[i] + param_table[-1].append(_reduce_int_list(bs) if bs else "n/a") + formatted_param_table = _format_table(param_table) # pyre-ignore[6] + self._width = max(self._width, len(formatted_param_table[0]) + 6) + return formatted_param_table + + +def _generate_rank_hbm_stats( + per_rank_hbm: List[int], func: Callable[[Iterable[float]], float] +) -> str: + stats = round(func(per_rank_hbm)) + stats_indicies = [ + i + for i in 
range(len(per_rank_hbm)) + if math.isclose(bytes_to_mb(per_rank_hbm[i]), bytes_to_mb(stats), abs_tol=1.0) + ] + rank_text = "ranks" if len(stats_indicies) > 1 else "rank" + return f"{round(bytes_to_gb(stats), 3)} GB on {rank_text} {stats_indicies}" + + +def _generate_max_text(perfs: List[float]) -> str: + max_perf = max(perfs) + + max_perf_indices = [i for i in range(len(perfs)) if perfs[i] == max_perf] + rank_text = "ranks" if len(max_perf_indices) > 1 else "rank" + max_perf_indices = _collapse_consecutive_ranks(max_perf_indices) + max_perf_ranks = f"{rank_text} {','.join(max_perf_indices)}" + + return f"{round(max_perf, 3)} ms on {max_perf_ranks}" + def _get_sharding_type_abbr(sharding_type: str) -> str: if sharding_type == ShardingType.DATA_PARALLEL.value: @@ -474,12 +927,103 @@ def _get_sharding_type_abbr(sharding_type: str) -> str: return "TWRW" elif sharding_type == ShardingType.TABLE_COLUMN_WISE.value: return "TWCW" + elif sharding_type == ShardingType.GRID_SHARD.value: + return "GS" else: raise ValueError( f"Unrecognized or unsupported sharding type provided: {sharding_type}" ) +def _format_perf_breakdown(perf: Perf) -> str: + breakdown = [ + perf.fwd_compute, + perf.fwd_comms, + perf.bwd_compute, + perf.bwd_comms, + perf.prefetch_compute, + ] + breakdown_string = ",".join( + [str(round(num)) if num >= 1 else round_to_one_sigfig(num) for num in breakdown] + ) + + return f"{str(round(perf.total, 3))} ({breakdown_string})" + + +def _compute_storage( + storage_reservation: StorageReservation, +) -> Tuple[float, Storage, Storage]: + reserved_hbm_percent = ( + storage_reservation._percentage + if isinstance( + storage_reservation, + ( + FixedPercentageStorageReservation, + HeuristicalStorageReservation, + InferenceStorageReservation, + ), + ) + else 0.0 + ) + + dense_storage = ( + storage_reservation._dense_storage + if isinstance( + storage_reservation, + (HeuristicalStorageReservation, InferenceStorageReservation), + ) + and storage_reservation._dense_storage is not None + else Storage(0, 0) + ) + assert dense_storage + kjt_storage = ( + storage_reservation._kjt_storage + if isinstance( + storage_reservation, + (HeuristicalStorageReservation, InferenceStorageReservation), + ) + and storage_reservation._kjt_storage + else Storage(0, 0) + ) + assert kjt_storage + return reserved_hbm_percent, dense_storage, kjt_storage + + +def _compute_mem_usage_and_perf( + topology: Topology, + best_plan: List[ShardingOption], + dense_storage: Storage, + kjt_storage: Storage, +) -> Tuple[List[int], List[int], List[Perf]]: + used_hbm = [0] * topology.world_size + used_ddr = [0] * topology.world_size + perf = [ + Perf(fwd_compute=0, fwd_comms=0, bwd_compute=0, bwd_comms=0) + for _ in range(topology.world_size) + ] + for sharding_option in best_plan: + for shard in sharding_option.shards: + shard_storage = cast(Storage, shard.storage) + rank = cast(int, shard.rank) + used_hbm[rank] += shard_storage.hbm + used_ddr[rank] += shard_storage.ddr + perf[rank] += cast(Perf, shard.perf) + + used_hbm = [hbm + dense_storage.hbm + kjt_storage.hbm for hbm in used_hbm] + used_ddr = [ddr + dense_storage.ddr + kjt_storage.ddr for ddr in used_ddr] + return used_hbm, used_ddr, perf + + +def _format_storage_breakdown(storage: Storage) -> str: + storage_hbm = round(bytes_to_gb(storage.hbm), 3) + storage_ddr = round(bytes_to_gb(storage.ddr), 3) + return f"({storage_hbm} GB, {storage_ddr} GB)" + + +def round_to_one_sigfig(x: float) -> str: + return f'{float(f"{x:.1g}"):g}' + + def _format_table(table: List[List[Union[str, 
int]]]) -> List[str]: longest_cols = [ (max([len(str(row[i])) for row in table]) + 3) for i in range(len(table[0])) @@ -495,3 +1039,100 @@ def _collapse_consecutive_ranks(ranks: List[int]) -> List[str]: return [f"{min(ranks)}-{max(ranks)}"] else: return [str(rank) for rank in ranks] + + +def _reduce_int_list(input_list: List[int]) -> str: + if len(input_list) == 0: + return "" + reduced = [] + count = 1 + prev_num = input_list[0] + + for num in input_list[1:]: + if num == prev_num: + count += 1 + else: + if count > 1: + reduced.append(f"{prev_num} * {count}") + else: + reduced.append(str(prev_num)) + prev_num = num + count = 1 + + # Handle the last number + if count > 1: + reduced.append(f"{prev_num}*{count}") + else: + reduced.append(str(prev_num)) + + return ", ".join(reduced) + + +def _calculate_critical_path(best_plan: List[ShardingOption]) -> CriticalPathEstimate: + """ + Calculates the critical path of the sharding plan. Makes the following assumptions: + + 1. There is a synchronization point across the ranks after each of the 4 events: Fwd/Bwd x Comms/Comp. + 2. There are additional synchronization points during communication (both fwd & bwd) for each module <> sharding type combination. + i. Communication operations for each shard from the same module <> sharding type group are executed sequentially. + ii. Ranks need to synchronize before they can begin the communication operation for the next module <> sharding type group. + 3. There are additional synchronization points during computation (both fwd & bwd) at the rank level. + i. Computation operations for each shard from the same module are executed sequentially. + ii. Ranks need to synchronize before they can begin the next set of events. + """ + comms_data = defaultdict(lambda: defaultdict(float)) + comp_data = defaultdict(lambda: defaultdict(float)) + for so in best_plan: + module = so.module + sharding_type = so.sharding_type + for shard in so.shards: + rank = cast(int, shard.rank) + perf = cast(Perf, shard.perf) + comms_data[(module, sharding_type, "fwd")][rank] += perf.fwd_comms + comms_data[(module, sharding_type, "bwd")][rank] += perf.bwd_comms + comp_data["fwd"][rank] += perf.fwd_compute + comp_data["bwd"][rank] += perf.bwd_compute + comms_rank_agg = { + outer_key: max(inner_dict.values()) + for outer_key, inner_dict in comms_data.items() + } + rank_count = len({cast(int, shard.rank) for so in best_plan for shard in so.shards}) + sharding_types = list({so.sharding_type for so in best_plan}) + adjustment_factor = 1 + # Default bandwidth is 12.5 is used and closer to 40 is right for internode GTT + if ( + rank_count > 8 + and len(sharding_types) == 1 + and sharding_types[0] == "column_wise" + ): + adjustment_factor = 3 + comms_estimate = sum(comms_rank_agg.values()) / adjustment_factor + comp_rank_agg = { + outer_key: max(inner_dict.values()) + for outer_key, inner_dict in comp_data.items() + } + comp_estimate = sum(comp_rank_agg.values()) + + return CriticalPathEstimate(comms_estimate, comp_estimate) + + +class NoopEmbeddingStats(Stats): + """ + Noop Stats for a sharding planner execution. 
+ """ + + def log( + self, + sharding_plan: ShardingPlan, + topology: Topology, + batch_size: int, + storage_reservation: StorageReservation, + num_proposals: int, + num_plans: int, + run_time: float, + best_plan: List[ShardingOption], + constraints: Optional[Dict[str, ParameterConstraints]] = None, + sharders: Optional[List[ModuleSharder[nn.Module]]] = None, + debug: bool = True, + ) -> None: + pass diff --git a/torchrec/distributed/planner/storage_reservations.py b/torchrec/distributed/planner/storage_reservations.py index e8e4ee71c..52909970d 100644 --- a/torchrec/distributed/planner/storage_reservations.py +++ b/torchrec/distributed/planner/storage_reservations.py @@ -5,7 +5,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import copy +import logging import math from typing import Dict, List, Optional, Set, Tuple @@ -13,46 +16,65 @@ from torchrec.distributed.planner.constants import BIGINT_DTYPE, POOLING_FACTOR from torchrec.distributed.planner.types import ( ParameterConstraints, + PlannerError, + PlannerErrorType, Storage, StorageReservation, Topology, ) -from torchrec.distributed.planner.utils import sharder_name -from torchrec.distributed.types import ModuleSharder +from torchrec.distributed.planner.utils import sharder_name, storage_repr_in_gb +from torchrec.distributed.types import get_tensor_size_bytes, ModuleSharder + +logger: logging.Logger = logging.getLogger(__name__) -def _get_module_size(module: nn.Module, multiplier: int) -> int: + +def _get_module_size(module: nn.Module, multiplier: float) -> int: parameters_size = sum( [ - multiplier * parameter.element_size() * parameter.nelement() + multiplier * get_tensor_size_bytes(parameter) for parameter in module.parameters() ] ) - buffers_size = sum( - [buffer.element_size() * buffer.nelement() for buffer in module.buffers()] - ) + buffers_size = sum([get_tensor_size_bytes(buffer) for buffer in module.buffers()]) - return parameters_size + buffers_size + return round(parameters_size + buffers_size) -def _reserve_dense_storage( - topology: Topology, +def _get_dense_tensor_size( module: nn.Module, shardable_modules: Set[nn.Module], - multiplier: int, -) -> Storage: - + multiplier: float = 6.0, +) -> int: dense_tensor_size = _get_module_size(module, multiplier) - sum( [ _get_module_size(shardable_module, multiplier) for shardable_module in shardable_modules ] ) + return dense_tensor_size + + +def _reserve_dense_storage( + topology: Topology, + module: nn.Module, + shardable_modules: Set[nn.Module], + multiplier: float, + dense_tensor_estimate: Optional[int] = None, +) -> Storage: + + dense_tensor_size = _get_dense_tensor_size(module, shardable_modules, multiplier) + if dense_tensor_estimate: + logger.info( + f"We override default dense tensor estimate ({dense_tensor_size} bytes) " + f"with user-provided dense tensor estimate ({dense_tensor_estimate} bytes)." 
+ ) + dense_tensor_size = dense_tensor_estimate dense_tensor_storage = Storage( hbm=dense_tensor_size if topology.compute_device == "cuda" else 0, - ddr=dense_tensor_size if topology.compute_device == "cpu" else 0, + ddr=dense_tensor_size if topology.compute_device in {"cpu", "mtia"} else 0, ) for device in topology.devices: @@ -64,18 +86,15 @@ def _reserve_dense_storage( def _reserve_kjt_storage( topology: Topology, batch_size: int, - input_lengths: List[float], + batch_inputs: List[float], input_data_type_size: int, multiplier: int, ) -> Storage: - kjt_size = ( - math.ceil(float(batch_size) * sum(input_lengths) * float(input_data_type_size)) - * multiplier - ) + kjt_size = math.ceil(sum(batch_inputs) * float(input_data_type_size)) * multiplier kjt_storage = Storage( hbm=kjt_size if topology.compute_device == "cuda" else 0, - ddr=kjt_size if topology.compute_device == "cpu" else 0, + ddr=kjt_size if topology.compute_device in {"cpu", "mtia"} else 0, ) for device in topology.devices: @@ -89,15 +108,17 @@ def _reserve_storage_percentage(topology: Topology, percent: float) -> None: device.storage.hbm = int((1 - percent) * device.storage.hbm) -def _get_input_lengths_and_shardable_parameters( +def _get_batch_inputs_and_shardable_parameters( module: nn.Module, sharders: List[ModuleSharder[nn.Module]], + batch_size: int, constraints: Optional[Dict[str, ParameterConstraints]] = None, ) -> Tuple[List[float], Set[nn.Module]]: sharder_map: Dict[str, ModuleSharder[nn.Module]] = { sharder_name(sharder.module_type): sharder for sharder in sharders } input_lengths: List[float] = [] + batch_sizes: List[int] = [] shardable_modules: Set[nn.Module] = set() def populate_shardable_modules( @@ -113,21 +134,32 @@ def populate_shardable_modules( names = sharder.shardable_parameters(module).keys() shardable_modules.add(module) - input_lengths.extend( - [ - sum(constraints[name].pooling_factors) + for name in names: + pooling_factors = ( + constraints[name].pooling_factors if constraints and constraints.get(name) - else POOLING_FACTOR - for name in names - ] - ) + else [POOLING_FACTOR] + ) + input_lengths.extend(pooling_factors) + batch_sizes.extend( + constraints[name].batch_sizes # pyre-ignore[6] + if constraints + and constraints.get(name) + and constraints[name].batch_sizes + else [batch_size] * len(pooling_factors) + ) populate_shardable_modules(module) - return input_lengths, shardable_modules + batch_inputs: List[float] = [ + input_length * batch_size + for input_length, batch_size in zip(input_lengths, batch_sizes) + ] + + return batch_inputs, shardable_modules -class FixedPercentageReservation(StorageReservation): +class FixedPercentageStorageReservation(StorageReservation): def __init__(self, percentage: float) -> None: assert percentage >= 0 and percentage <= 1 self._percentage: float = percentage @@ -148,20 +180,29 @@ def reserve( class HeuristicalStorageReservation(StorageReservation): """ Reserves storage for model to be sharded with heuristical calculation. The storage - reservation is comprised of nonsharded tensor storage, KJT storage, and an extra - percentage. + reservation is comprised of dense tensor storage, KJT storage, and an extra + percentage of total storage. Args: - percentage (float): extra storage percentage to reserve that acts as a margin of + percentage (float): extra storage percent to reserve that acts as a margin of error beyond heuristic calculation of storage. + parameter_multiplier (float): heuristic multiplier for total parameter storage. 
+ dense_tensor_estimate (Optional[int]): storage estimate for dense tensors, uses + default heuristic estimate if not provided. """ def __init__( self, percentage: float, + # heuristic: 6 * dense parameter size + # parameter + optimizer (~2x parameter) + ddp (~3x parameter) + parameter_multiplier: float = 6.0, + dense_tensor_estimate: Optional[int] = None, ) -> None: assert percentage >= 0 and percentage <= 1 self._percentage: float = percentage + self._parameter_multiplier = parameter_multiplier + self._dense_tensor_estimate = dense_tensor_estimate self._dense_storage: Optional[Storage] = None self._kjt_storage: Optional[Storage] = None @@ -176,8 +217,8 @@ def reserve( ) -> Topology: reserved_topology = copy.deepcopy(topology) - input_lengths, shardable_modules = _get_input_lengths_and_shardable_parameters( - module, sharders, constraints + batch_inputs, shardable_modules = _get_batch_inputs_and_shardable_parameters( + module, sharders, batch_size, constraints ) _reserve_storage_percentage(reserved_topology, self._percentage) @@ -186,40 +227,67 @@ def reserve( topology=reserved_topology, module=module, shardable_modules=shardable_modules, - # heuristic: 6 * dense parameter size - # parameter + optimizer (~2x parameter) + ddp (~3x parameter) - multiplier=6, + multiplier=self._parameter_multiplier, + dense_tensor_estimate=self._dense_tensor_estimate, ) self._kjt_storage = _reserve_kjt_storage( topology=reserved_topology, batch_size=batch_size, - input_lengths=input_lengths, + batch_inputs=batch_inputs, input_data_type_size=BIGINT_DTYPE, # 2 pipelined batches each with 10 internal copies multiplier=20, ) + if reserved_topology.devices[0].storage.hbm < 0: + negative_storage_solution = ( + f"The reserved topology ({storage_repr_in_gb(reserved_topology.devices[0].storage)}) " + "has negative available hbm storage, " + "after taking into account of the reserved hbm percentage, " + "the storage for dense modules, and the kjt storages. Hence " + "it is not possible to find a valid sharding plan. " + "\nPossible solutions:" + "\n 1) If FSDP is used, consider switching to FixedPercentageStorageReservation, since " + f"HeuristicalStorageReservation would not be able to calculate the " + f"dense storage ({storage_repr_in_gb(self._dense_storage)}) correctly. " + f"\n 2) Reduce local batch size ({batch_size}), which can help " + f"reduce the per rank kjt storage ({storage_repr_in_gb(self._kjt_storage)}). " + f"\n 3) Decrease the reserved hbm percentage ({self._percentage}). " + "\n 4) Use hardware with a higher hbm cap (current hardware has " + f"{storage_repr_in_gb(topology.devices[0].storage)} per rank). " + ) + raise PlannerError( + error_type=PlannerErrorType.INSUFFICIENT_STORAGE, + message=negative_storage_solution, + ) + return reserved_topology class InferenceStorageReservation(StorageReservation): """ - Reserves storage for model to be sharded for inference. The storage - reservation is comprised of nonsharded tensor storage, KJT storage, and an extra - percentage. + Reserves storage for model to be sharded for inference. The storage reservation + is comprised of dense tensor storage, KJT storage, and an extra percentage of total + storage. Note that when estimating for storage, dense modules are assumed to be on + GPUs and replicated across ranks. If this is not the case, please override the + estimates with dense_tensor_estimate. Args: percentage (float): extra storage percentage to reserve that acts as a margin of error beyond storage calculation. 
+ dense_tensor_estimate (Optional[int]): storage estimate for dense tensors, use + default heuristic estimate if not provided. """ def __init__( self, percentage: float, + dense_tensor_estimate: Optional[int] = None, ) -> None: assert percentage >= 0 and percentage <= 1 self._percentage: float = percentage + self._dense_tensor_estimate = dense_tensor_estimate self._dense_storage: Optional[Storage] = None self._kjt_storage: Optional[Storage] = None @@ -234,8 +302,8 @@ def reserve( ) -> Topology: reserved_topology = copy.deepcopy(topology) - input_lengths, shardable_modules = _get_input_lengths_and_shardable_parameters( - module, sharders, constraints + batch_inputs, shardable_modules = _get_batch_inputs_and_shardable_parameters( + module, sharders, batch_size, constraints ) _reserve_storage_percentage(reserved_topology, self._percentage) @@ -245,12 +313,13 @@ def reserve( module=module, shardable_modules=shardable_modules, multiplier=1, + dense_tensor_estimate=self._dense_tensor_estimate, ) self._kjt_storage = _reserve_kjt_storage( topology=reserved_topology, batch_size=batch_size, - input_lengths=input_lengths, + batch_inputs=batch_inputs, input_data_type_size=BIGINT_DTYPE, multiplier=1, ) diff --git a/torchrec/distributed/planner/tests/benchmark.py b/torchrec/distributed/planner/tests/benchmark.py new file mode 100644 index 000000000..082cc4783 --- /dev/null +++ b/torchrec/distributed/planner/tests/benchmark.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +"""Stress tests for planner to find problematic scaling behavior.""" + +import time +import unittest + +from typing import List, Tuple + +from torch import nn + +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.planner.constants import BATCH_SIZE +from torchrec.distributed.planner.enumerators import EmbeddingEnumerator +from torchrec.distributed.planner.types import Topology +from torchrec.distributed.test_utils.test_model import TestSparseNN +from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig + + +class TWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.TABLE_WISE.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class TestEnumeratorBenchmark(unittest.TestCase): + @staticmethod + def build( + world_size: int, num_tables: int + ) -> Tuple[EmbeddingEnumerator, nn.Module]: + compute_device = "cuda" + topology = Topology( + world_size=world_size, local_world_size=8, compute_device=compute_device + ) + tables = [ + EmbeddingBagConfig( + num_embeddings=100 + i, + embedding_dim=128, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_tables) + ] + model = TestSparseNN(tables=tables, weighted_tables=[]) + enumerator = EmbeddingEnumerator(topology=topology, batch_size=BATCH_SIZE) + return enumerator, model + + def measure(self, world_size: int, num_tables: int) -> float: + enumerator, model = TestEnumeratorBenchmark.build(world_size, num_tables) + + start_time = 
time.time() + sharding_options = enumerator.enumerate(module=model, sharders=[TWSharder()]) + end_time = time.time() + + self.assertEqual(len(sharding_options), num_tables) + return end_time - start_time + + def test_benchmark(self) -> None: + tests = [(2048, d) for d in [100, 200, 400, 800, 1600, 3200, 6400]] + print("\nEnumerator benchmark:") + for world_size, num_tables in tests: + t = self.measure(world_size, num_tables) + print( + f"world_size={world_size:8} num_tables={num_tables:8} enumerate={t:4.2f}s" + ) + + +def main() -> None: + unittest.main() + + +# This is structured as a unitttest like file so you can use its built-in command +# line argument parsing to control which benchmarks to run, e.g. "-k Enumerator" +if __name__ == "__main__": + main() # pragma: no cover diff --git a/torchrec/distributed/planner/tests/test_constants.py b/torchrec/distributed/planner/tests/test_constants.py new file mode 100644 index 000000000..a3799e290 --- /dev/null +++ b/torchrec/distributed/planner/tests/test_constants.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import List, Optional + +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.planner.constants import ( + DDR_MEM_BW, + HBM_MEM_BW, + HBM_TO_DDR_MEM_BW, + kernel_bw_lookup, +) + + +class TestKernelBWLookup(unittest.TestCase): + def test_uvm_caching_bw(self) -> None: + compute_device: str = "cuda" + computer_kernel: str = EmbeddingComputeKernel.FUSED_UVM_CACHING.value + + caching_ratios: List[float] = [0, 0.25, 0.5, 0.75, 1] + + uvm_caching_bw: list[Optional[float]] = [ + kernel_bw_lookup( + compute_device, + computer_kernel, + HBM_MEM_BW, + DDR_MEM_BW, + HBM_TO_DDR_MEM_BW, + caching_ratio, + ) + for caching_ratio in caching_ratios + ] + expected_uvm_caching_bw: List[float] = [ + 3435973.8368, + 26655640.7808, + 49875307.724800006, + 73094974.6688, + 96314641.6128, + ] + self.assertEqual(expected_uvm_caching_bw, uvm_caching_bw) + + def test_uvm_caching_bw_with_prefetch_pipeline(self) -> None: + compute_device: str = "cuda" + computer_kernel: str = EmbeddingComputeKernel.FUSED_UVM_CACHING.value + prefetch_pipeline: bool = True + + caching_ratios: List[float] = [0, 0.25, 0.5, 0.75, 1] + + uvm_caching_bw: list[Optional[float]] = [ + kernel_bw_lookup( + compute_device, + computer_kernel, + HBM_MEM_BW, + DDR_MEM_BW, + HBM_TO_DDR_MEM_BW, + caching_ratio, + prefetch_pipeline, + ) + for caching_ratio in caching_ratios + ] + expected_uvm_caching_bw: List[float] = [ + 963146416.128, + 963146416.128, + 963146416.128, + 963146416.128, + 963146416.128, + ] + + self.assertEqual(expected_uvm_caching_bw, uvm_caching_bw) diff --git a/torchrec/distributed/planner/tests/test_embedding_utils.py b/torchrec/distributed/planner/tests/test_embedding_utils.py new file mode 100644 index 000000000..3f793a21a --- /dev/null +++ b/torchrec/distributed/planner/tests/test_embedding_utils.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import copy +import unittest + +from torchrec.distributed.embedding import ( + EmbeddingCollectionSharder, + ShardedEmbeddingCollection, +) +from torchrec.distributed.planner import ( + EmbeddingShardingPlanner, + ParameterConstraints, + Topology, +) +from torchrec.distributed.types import ( + BoundsCheckMode, + CacheAlgorithm, + CacheParams, + DataType, +) +from torchrec.modules.embedding_configs import EmbeddingConfig +from torchrec.modules.embedding_modules import EmbeddingCollection + + +class CreateShardingInfoTest(unittest.TestCase): + def setUp(self) -> None: + self.tables = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=4, + num_embeddings=4, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=4, + num_embeddings=4, + ), + ] + + self.constraints = { + "table_0": ParameterConstraints( + cache_params=CacheParams( + algorithm=CacheAlgorithm.LRU, + load_factor=0.1, + reserved_memory=8.0, + precision=DataType.FP16, + ), + enforce_hbm=True, + stochastic_rounding=False, + bounds_check_mode=BoundsCheckMode.IGNORE, + ), + "table_1": ParameterConstraints( + cache_params=CacheParams( + algorithm=CacheAlgorithm.LFU, + load_factor=0.2, + reserved_memory=0.0, + precision=DataType.FP16, + ), + enforce_hbm=True, + stochastic_rounding=False, + bounds_check_mode=BoundsCheckMode.NONE, + ), + } + + self.model = EmbeddingCollection(tables=self.tables) + self.sharder = EmbeddingCollectionSharder() + planner = EmbeddingShardingPlanner( + topology=Topology(world_size=1, compute_device="cpu"), + constraints=self.constraints, + ) + self.expected_plan = planner.plan(self.model, [self.sharder]) # pyre-ignore[6] + + self.expected_sharding_infos = ( + ShardedEmbeddingCollection.create_grouped_sharding_infos( + self.model, + self.expected_plan.get_plan_for_module(""), # pyre-ignore[6] + fused_params=None, + ) + ) + + def test_create_sharding_infos_by_sharding_override(self) -> None: + """ + Test that fused_params from sharders get overridden. + """ + + # with sharder fused params that will get overridden + sharder_fused_params = {"enforce_hbm": False} + overriden_sharding_infos = ( + ShardedEmbeddingCollection.create_grouped_sharding_infos( + self.model, + self.expected_plan.get_plan_for_module(""), + fused_params=sharder_fused_params, + ) + ) + for sharding_type, overriden_sharding_info in overriden_sharding_infos.items(): + expected_sharding_info = self.expected_sharding_infos[sharding_type] + for a, b in zip(expected_sharding_info, overriden_sharding_info): + self.assertEqual(a.fused_params, b.fused_params) + + # with sharder fused params that won't get overridden + sharder_fused_params = {"ABC": True} + not_overriden_sharding_infos = ( + ShardedEmbeddingCollection.create_grouped_sharding_infos( + self.model, + self.expected_plan.get_plan_for_module(""), + fused_params=sharder_fused_params, + ) + ) + for ( + sharding_type, + not_overriden_sharding_info, + ) in not_overriden_sharding_infos.items(): + expected_sharding_info = self.expected_sharding_infos[sharding_type] + for a, b in zip(expected_sharding_info, not_overriden_sharding_info): + self.assertNotEqual(a.fused_params, b.fused_params) + + def test_create_sharding_infos_by_sharding_combine(self) -> None: + """ + Test that fused_params can get info from both sharder and constraints. 
+ """ + + new_constraints = copy.deepcopy(self.constraints) + + # remove two fused_params from constraints + for _, parameter_constraints in new_constraints.items(): + parameter_constraints.enforce_hbm = None + parameter_constraints.stochastic_rounding = None + + new_planner = EmbeddingShardingPlanner( + topology=Topology(world_size=1, compute_device="cpu"), + constraints=new_constraints, + ) + new_plan = new_planner.plan(self.model, [self.sharder]) # pyre-ignore[6] + + # provide that two fused params from sharder + sharder_fused_params = {"enforce_hbm": True, "stochastic_rounding": False} + + combined_sharding_infos = ( + ShardedEmbeddingCollection.create_grouped_sharding_infos( + self.model, + new_plan.get_plan_for_module(""), # pyre-ignore[6] + fused_params=sharder_fused_params, + ) + ) + + # directly assertion won't work, since sharding_infos also have parameter_sharding + for sharding_type, combined_sharding_info in combined_sharding_infos.items(): + expected_sharding_info = self.expected_sharding_infos[sharding_type] + for a, b in zip(expected_sharding_info, combined_sharding_info): + self.assertEqual(a.fused_params, b.fused_params) + + # provide that two fused params from sharder wrongly + sharder_fused_params = {"enforce_hbm": True, "stochastic_rounding": True} + wrong_combined_sharding_infos = ( + ShardedEmbeddingCollection.create_grouped_sharding_infos( + self.model, + new_plan.get_plan_for_module(""), # pyre-ignore[6] + fused_params=sharder_fused_params, + ) + ) + for ( + sharding_type, + wrong_combined_sharding_info, + ) in wrong_combined_sharding_infos.items(): + expected_sharding_info = self.expected_sharding_infos[sharding_type] + for a, b in zip(expected_sharding_info, wrong_combined_sharding_info): + self.assertNotEqual(a.fused_params, b.fused_params) diff --git a/torchrec/distributed/planner/tests/test_embeddingbag_utils.py b/torchrec/distributed/planner/tests/test_embeddingbag_utils.py new file mode 100644 index 000000000..c1dcded69 --- /dev/null +++ b/torchrec/distributed/planner/tests/test_embeddingbag_utils.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
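The override/combine tests above (and their EmbeddingBagCollection counterparts in the file that follows) pin down a precedence rule: fused params derived from per-table ParameterConstraints win over sharder-supplied fused_params for the same keys, while sharder keys the constraints do not set (e.g. the dummy "ABC") pass through untouched. As a rough mental model only, and not the actual create_grouped_sharding_infos implementation, the effect is a dict merge with the constraint-derived entries applied last:

    # Hypothetical sketch of the precedence asserted above; merge_fused_params is an
    # illustrative name, not a torchrec API.
    from typing import Any, Dict, Optional

    def merge_fused_params(
        sharder_fused_params: Optional[Dict[str, Any]],
        constraint_fused_params: Dict[str, Any],
    ) -> Dict[str, Any]:
        merged = dict(sharder_fused_params or {})
        merged.update(constraint_fused_params)  # constraint-derived values take precedence
        return merged

    # The sharder's enforce_hbm=False loses to enforce_hbm=True from the constraints,
    # while an unrelated key such as "ABC" survives the merge.
    assert merge_fused_params({"enforce_hbm": False}, {"enforce_hbm": True}) == {"enforce_hbm": True}
    assert merge_fused_params({"ABC": True}, {"enforce_hbm": True}) == {"ABC": True, "enforce_hbm": True}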
+ +# pyre-strict + +import copy +import unittest + +from torchrec.distributed.embeddingbag import ( + EmbeddingBagCollectionSharder, + ShardedEmbeddingBagCollection, +) +from torchrec.distributed.planner import ( + EmbeddingShardingPlanner, + ParameterConstraints, + Topology, +) +from torchrec.distributed.types import ( + BoundsCheckMode, + CacheAlgorithm, + CacheParams, + DataType, +) +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection + + +class CreateShardingInfoTest(unittest.TestCase): + def setUp(self) -> None: + self.tables = [ + EmbeddingBagConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=4, + num_embeddings=4, + ), + EmbeddingBagConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=4, + num_embeddings=4, + ), + ] + + self.constraints = { + "table_0": ParameterConstraints( + cache_params=CacheParams( + algorithm=CacheAlgorithm.LRU, + load_factor=0.1, + reserved_memory=8.0, + precision=DataType.FP16, + ), + enforce_hbm=True, + stochastic_rounding=False, + bounds_check_mode=BoundsCheckMode.IGNORE, + ), + "table_1": ParameterConstraints( + cache_params=CacheParams( + algorithm=CacheAlgorithm.LFU, + load_factor=0.2, + reserved_memory=0.0, + precision=DataType.FP16, + ), + enforce_hbm=True, + stochastic_rounding=False, + bounds_check_mode=BoundsCheckMode.NONE, + ), + } + + self.model = EmbeddingBagCollection(tables=self.tables) + self.sharder = EmbeddingBagCollectionSharder() + planner = EmbeddingShardingPlanner( + topology=Topology(world_size=1, compute_device="cpu"), + constraints=self.constraints, + ) + self.expected_plan = planner.plan(self.model, [self.sharder]) # pyre-ignore[6] + + self.expected_sharding_infos = ( + ShardedEmbeddingBagCollection.create_grouped_sharding_infos( + self.model, + self.expected_plan.get_plan_for_module(""), # pyre-ignore[6] + prefix="embedding_bags.", + fused_params=None, + ) + ) + + def test_create_sharding_infos_by_group_override(self) -> None: + """ + Test that fused_params from sharders get overridden. + """ + + # with sharder fused params that will get overridden + sharder_fused_params = {"enforce_hbm": False} + overriden_sharding_infos = ( + ShardedEmbeddingBagCollection.create_grouped_sharding_infos( + self.model, + self.expected_plan.get_plan_for_module(""), + prefix="embedding_bags.", + fused_params=sharder_fused_params, + ) + ) + for sharding_type, overriden_sharding_info in overriden_sharding_infos.items(): + expected_sharding_info = self.expected_sharding_infos[sharding_type] + for a, b in zip(expected_sharding_info, overriden_sharding_info): + self.assertEqual(a.fused_params, b.fused_params) + + # with sharder fused params that won't get overridden + sharder_fused_params = {"ABC": True} + not_overriden_sharding_infos = ( + ShardedEmbeddingBagCollection.create_grouped_sharding_infos( + self.model, + self.expected_plan.get_plan_for_module(""), + prefix="embedding_bags.", + fused_params=sharder_fused_params, + ) + ) + for ( + sharding_type, + not_overriden_sharding_info, + ) in not_overriden_sharding_infos.items(): + expected_sharding_info = self.expected_sharding_infos[sharding_type] + for a, b in zip(expected_sharding_info, not_overriden_sharding_info): + self.assertNotEqual(a.fused_params, b.fused_params) + + def test_create_sharding_infos_by_group_combine(self) -> None: + """ + Test that fused_params can get info from both sharder and constraints. 
+ """ + + new_constraints = copy.deepcopy(self.constraints) + + # remove two fused_params from constraints + for _, parameter_constraints in new_constraints.items(): + parameter_constraints.enforce_hbm = None + parameter_constraints.stochastic_rounding = None + + new_planner = EmbeddingShardingPlanner( + topology=Topology(world_size=1, compute_device="cpu"), + constraints=new_constraints, + ) + new_plan = new_planner.plan(self.model, [self.sharder]) # pyre-ignore[6] + + # provide that two fused params from sharder + sharder_fused_params = {"enforce_hbm": True, "stochastic_rounding": False} + + combined_sharding_infos = ( + ShardedEmbeddingBagCollection.create_grouped_sharding_infos( + self.model, + new_plan.get_plan_for_module(""), # pyre-ignore[6] + prefix="embedding_bags.", + fused_params=sharder_fused_params, + ) + ) + + # directly assertion won't work, since sharding_infos also have parameter_sharding + for sharding_type, combined_sharding_info in combined_sharding_infos.items(): + expected_sharding_info = self.expected_sharding_infos[sharding_type] + for a, b in zip(expected_sharding_info, combined_sharding_info): + self.assertEqual(a.fused_params, b.fused_params) + + # provide that two fused params from sharder wrongly + sharder_fused_params = {"enforce_hbm": True, "stochastic_rounding": True} + wrong_combined_sharding_infos = ( + ShardedEmbeddingBagCollection.create_grouped_sharding_infos( + self.model, + new_plan.get_plan_for_module(""), # pyre-ignore[6] + prefix="embedding_bags.", + fused_params=sharder_fused_params, + ) + ) + for ( + sharding_type, + wrong_combined_sharding_info, + ) in wrong_combined_sharding_infos.items(): + expected_sharding_info = self.expected_sharding_infos[sharding_type] + for a, b in zip(expected_sharding_info, wrong_combined_sharding_info): + self.assertNotEqual(a.fused_params, b.fused_params) diff --git a/torchrec/distributed/planner/tests/test_enumerators.py b/torchrec/distributed/planner/tests/test_enumerators.py index 178765743..0ef9141b4 100644 --- a/torchrec/distributed/planner/tests/test_enumerators.py +++ b/torchrec/distributed/planner/tests/test_enumerators.py @@ -5,9 +5,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
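A reading note for the test_enumerators.py changes that start here: the test tables now use embedding_dim=20 + i * 20 instead of 10 + i * 10, and the per-table min_partition constraints are scaled up to match, which is why the expected shard sizes, offsets and Storage constants below all change. The FUSED_UVM_CACHING expectations are additionally rewritten in terms of the new get_expected_cache_aux_size helper, where 0.2 is the cache load factor the test assumes. For the smallest row-wise shard of table_0 (13 rows, embedding dim 20), my arithmetic for the new constants is:

    import math

    # Same formula as the helper added below; 0.2 is the test's assumed cache load factor.
    def get_expected_cache_aux_size(rows: int) -> int:
        return math.ceil(rows * (4 + 0.2 * 16))

    assert get_expected_cache_aux_size(13) == 94  # ceil(13 * 7.2) = ceil(93.6)
    assert 13 * 20 * 4 == 1040                    # consistent with ddr=1040 expected for
                                                  # that shard (13 rows x 20 cols x 4 bytes)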
+# pyre-strict + +import math import unittest from typing import cast, List -from unittest.mock import patch +from unittest.mock import MagicMock, patch import torch from torchrec.distributed.embedding_tower_sharding import ( @@ -16,6 +19,9 @@ ) from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.mc_embeddingbag import ( + ManagedCollisionEmbeddingBagCollectionSharder, +) from torchrec.distributed.planner.constants import BIGINT_DTYPE from torchrec.distributed.planner.enumerators import EmbeddingEnumerator from torchrec.distributed.planner.shard_estimators import ( @@ -32,12 +38,11 @@ from torchrec.distributed.types import ModuleSharder, ShardingType from torchrec.modules.embedding_configs import EmbeddingBagConfig - EXPECTED_RW_SHARD_SIZES = [ - [[13, 10], [13, 10], [13, 10], [13, 10], [13, 10], [13, 10], [13, 10], [9, 10]], - [[14, 20], [14, 20], [14, 20], [14, 20], [14, 20], [14, 20], [14, 20], [12, 20]], - [[15, 30], [15, 30], [15, 30], [15, 30], [15, 30], [15, 30], [15, 30], [15, 30]], - [[17, 40], [17, 40], [17, 40], [17, 40], [17, 40], [17, 40], [17, 40], [11, 40]], + [[13, 20], [13, 20], [13, 20], [13, 20], [13, 20], [13, 20], [13, 20], [9, 20]], + [[14, 40], [14, 40], [14, 40], [14, 40], [14, 40], [14, 40], [14, 40], [12, 40]], + [[15, 60], [15, 60], [15, 60], [15, 60], [15, 60], [15, 60], [15, 60], [15, 60]], + [[17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [11, 80]], ] EXPECTED_RW_SHARD_OFFSETS = [ @@ -47,99 +52,105 @@ [[0, 0], [17, 0], [34, 0], [51, 0], [68, 0], [85, 0], [102, 0], [119, 0]], ] + +def get_expected_cache_aux_size(rows: int) -> int: + # 0.2 is the hardcoded cache load factor assumed in this test + return math.ceil(rows * (4 + 0.2 * 16)) + + EXPECTED_RW_SHARD_STORAGE = [ [ - Storage(hbm=84488, ddr=0), - Storage(hbm=84488, ddr=0), - Storage(hbm=84488, ddr=0), - Storage(hbm=84488, ddr=0), - Storage(hbm=84488, ddr=0), - Storage(hbm=84488, ddr=0), - Storage(hbm=84488, ddr=0), - Storage(hbm=84328, ddr=0), + Storage(hbm=166928, ddr=0), + Storage(hbm=166928, ddr=0), + Storage(hbm=166928, ddr=0), + Storage(hbm=166928, ddr=0), + Storage(hbm=166928, ddr=0), + Storage(hbm=166928, ddr=0), + Storage(hbm=166928, ddr=0), + Storage(hbm=166608, ddr=0), ], [ - Storage(hbm=511072, ddr=0), - Storage(hbm=511072, ddr=0), - Storage(hbm=511072, ddr=0), - Storage(hbm=511072, ddr=0), - Storage(hbm=511072, ddr=0), - Storage(hbm=511072, ddr=0), - Storage(hbm=511072, ddr=0), - Storage(hbm=510912, ddr=0), + Storage(hbm=1003712, ddr=0), + Storage(hbm=1003712, ddr=0), + Storage(hbm=1003712, ddr=0), + Storage(hbm=1003712, ddr=0), + Storage(hbm=1003712, ddr=0), + Storage(hbm=1003712, ddr=0), + Storage(hbm=1003712, ddr=0), + Storage(hbm=1003392, ddr=0), ], [ - Storage(hbm=513800, ddr=0), - Storage(hbm=513800, ddr=0), - Storage(hbm=513800, ddr=0), - Storage(hbm=513800, ddr=0), - Storage(hbm=513800, ddr=0), - Storage(hbm=513800, ddr=0), - Storage(hbm=513800, ddr=0), - Storage(hbm=513800, ddr=0), + Storage(hbm=1007120, ddr=0), + Storage(hbm=1007120, ddr=0), + Storage(hbm=1007120, ddr=0), + Storage(hbm=1007120, ddr=0), + Storage(hbm=1007120, ddr=0), + Storage(hbm=1007120, ddr=0), + Storage(hbm=1007120, ddr=0), + Storage(hbm=1007120, ddr=0), ], [ - Storage(hbm=1340064, ddr=0), - Storage(hbm=1340064, ddr=0), - Storage(hbm=1340064, ddr=0), - Storage(hbm=1340064, ddr=0), - Storage(hbm=1340064, ddr=0), - Storage(hbm=1340064, ddr=0), - Storage(hbm=1340064, ddr=0), 
- Storage(hbm=1339104, ddr=0), + Storage(hbm=2653504, ddr=0), + Storage(hbm=2653504, ddr=0), + Storage(hbm=2653504, ddr=0), + Storage(hbm=2653504, ddr=0), + Storage(hbm=2653504, ddr=0), + Storage(hbm=2653504, ddr=0), + Storage(hbm=2653504, ddr=0), + Storage(hbm=2651584, ddr=0), ], ] EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [ [ - Storage(hbm=84072, ddr=520), - Storage(hbm=84072, ddr=520), - Storage(hbm=84072, ddr=520), - Storage(hbm=84072, ddr=520), - Storage(hbm=84072, ddr=520), - Storage(hbm=84072, ddr=520), - Storage(hbm=84072, ddr=520), - Storage(hbm=84040, ddr=360), + Storage(hbm=166096 + get_expected_cache_aux_size(13), ddr=1040), + Storage(hbm=166096 + get_expected_cache_aux_size(13), ddr=1040), + Storage(hbm=166096 + get_expected_cache_aux_size(13), ddr=1040), + Storage(hbm=166096 + get_expected_cache_aux_size(13), ddr=1040), + Storage(hbm=166096 + get_expected_cache_aux_size(13), ddr=1040), + Storage(hbm=166096 + get_expected_cache_aux_size(13), ddr=1040), + Storage(hbm=166096 + get_expected_cache_aux_size(13), ddr=1040), + Storage(hbm=166032 + get_expected_cache_aux_size(9), ddr=720), ], [ - Storage(hbm=510176, ddr=1120), - Storage(hbm=510176, ddr=1120), - Storage(hbm=510176, ddr=1120), - Storage(hbm=510176, ddr=1120), - Storage(hbm=510176, ddr=1120), - Storage(hbm=510176, ddr=1120), - Storage(hbm=510176, ddr=1120), - Storage(hbm=510144, ddr=960), + Storage(hbm=1001920 + get_expected_cache_aux_size(14), ddr=2240), + Storage(hbm=1001920 + get_expected_cache_aux_size(14), ddr=2240), + Storage(hbm=1001920 + get_expected_cache_aux_size(14), ddr=2240), + Storage(hbm=1001920 + get_expected_cache_aux_size(14), ddr=2240), + Storage(hbm=1001920 + get_expected_cache_aux_size(14), ddr=2240), + Storage(hbm=1001920 + get_expected_cache_aux_size(14), ddr=2240), + Storage(hbm=1001920 + get_expected_cache_aux_size(14), ddr=2240), + Storage(hbm=1001856 + get_expected_cache_aux_size(12), ddr=1920), ], [ - Storage(hbm=512360, ddr=1800), - Storage(hbm=512360, ddr=1800), - Storage(hbm=512360, ddr=1800), - Storage(hbm=512360, ddr=1800), - Storage(hbm=512360, ddr=1800), - Storage(hbm=512360, ddr=1800), - Storage(hbm=512360, ddr=1800), - Storage(hbm=512360, ddr=1800), + Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600), + Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600), + Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600), + Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600), + Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600), + Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600), + Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600), + Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600), ], [ - Storage(hbm=1337888, ddr=2720), - Storage(hbm=1337888, ddr=2720), - Storage(hbm=1337888, ddr=2720), - Storage(hbm=1337888, ddr=2720), - Storage(hbm=1337888, ddr=2720), - Storage(hbm=1337888, ddr=2720), - Storage(hbm=1337888, ddr=2720), - Storage(hbm=1337696, ddr=1760), + Storage(hbm=2649152 + get_expected_cache_aux_size(17), ddr=5440), + Storage(hbm=2649152 + get_expected_cache_aux_size(17), ddr=5440), + Storage(hbm=2649152 + get_expected_cache_aux_size(17), ddr=5440), + Storage(hbm=2649152 + get_expected_cache_aux_size(17), ddr=5440), + Storage(hbm=2649152 + get_expected_cache_aux_size(17), ddr=5440), + Storage(hbm=2649152 + get_expected_cache_aux_size(17), ddr=5440), + Storage(hbm=2649152 + get_expected_cache_aux_size(17), ddr=5440), + Storage(hbm=2648768 + get_expected_cache_aux_size(11), 
ddr=3520), ], ] EXPECTED_TWRW_SHARD_SIZES = [ - [[25, 10], [25, 10], [25, 10], [25, 10]], - [[28, 20], [28, 20], [28, 20], [26, 20]], - [[30, 30], [30, 30], [30, 30], [30, 30]], - [[33, 40], [33, 40], [33, 40], [31, 40]], + [[25, 20], [25, 20], [25, 20], [25, 20]], + [[28, 40], [28, 40], [28, 40], [26, 40]], + [[30, 60], [30, 60], [30, 60], [30, 60]], + [[33, 80], [33, 80], [33, 80], [31, 80]], ] EXPECTED_TWRW_SHARD_OFFSETS = [ @@ -151,58 +162,54 @@ EXPECTED_TWRW_SHARD_STORAGE = [ [ - Storage(hbm=87016, ddr=0), - Storage(hbm=87016, ddr=0), - Storage(hbm=87016, ddr=0), - Storage(hbm=87016, ddr=0), + Storage(hbm=169936, ddr=0), + Storage(hbm=169936, ddr=0), + Storage(hbm=169936, ddr=0), + Storage(hbm=169936, ddr=0), ], [ - Storage(hbm=530624, ddr=0), - Storage(hbm=530624, ddr=0), - Storage(hbm=530624, ddr=0), - Storage(hbm=530464, ddr=0), + Storage(hbm=1024384, ddr=0), + Storage(hbm=1024384, ddr=0), + Storage(hbm=1024384, ddr=0), + Storage(hbm=1024064, ddr=0), ], [ - Storage(hbm=536080, ddr=0), - Storage(hbm=536080, ddr=0), - Storage(hbm=536080, ddr=0), - Storage(hbm=536080, ddr=0), + Storage(hbm=1031200, ddr=0), + Storage(hbm=1031200, ddr=0), + Storage(hbm=1031200, ddr=0), + Storage(hbm=1031200, ddr=0), ], [ - Storage(hbm=1369248, ddr=0), - Storage(hbm=1369248, ddr=0), - Storage(hbm=1369248, ddr=0), - Storage(hbm=1368928, ddr=0), + Storage(hbm=2685248, ddr=0), + Storage(hbm=2685248, ddr=0), + Storage(hbm=2685248, ddr=0), + Storage(hbm=2684608, ddr=0), ], ] EXPECTED_CW_SHARD_SIZES = [ - [[100, 10]], - [[110, 8], [110, 12]], - [[120, 9], [120, 9], [120, 12]], - [[130, 12], [130, 12], [130, 16]], + [[100, 20]], + [[110, 20], [110, 20]], + [[120, 20], [120, 20], [120, 20]], + [[130, 40], [130, 40]], ] EXPECTED_CW_SHARD_OFFSETS = [ [[0, 0]], - [[0, 0], [0, 8]], - [[0, 0], [0, 9], [0, 18]], - [[0, 0], [0, 12], [0, 24]], + [[0, 0], [0, 20]], + [[0, 0], [0, 20], [0, 40]], + [[0, 0], [0, 40]], ] EXPECTED_CW_SHARD_STORAGE = [ - [Storage(hbm=102304, ddr=0)], - [Storage(hbm=347584, ddr=0), Storage(hbm=447648, ddr=0)], - [ - Storage(hbm=315616, ddr=0), - Storage(hbm=315616, ddr=0), - Storage(hbm=366208, ddr=0), - ], + [Storage(hbm=188224, ddr=0)], + [Storage(hbm=647776, ddr=0), Storage(hbm=647776, ddr=0)], [ - Storage(hbm=612448, ddr=0), - Storage(hbm=612448, ddr=0), - Storage(hbm=745600, ddr=0), + Storage(hbm=501120, ddr=0), + Storage(hbm=501120, ddr=0), + Storage(hbm=501120, ddr=0), ], + [Storage(hbm=1544512, ddr=0), Storage(hbm=1544512, ddr=0)], ] EXPECTED_TWCW_SHARD_SIZES: List[List[List[int]]] = EXPECTED_CW_SHARD_SIZES @@ -210,18 +217,14 @@ EXPECTED_TWCW_SHARD_OFFSETS: List[List[List[int]]] = EXPECTED_CW_SHARD_OFFSETS EXPECTED_TWCW_SHARD_STORAGE = [ - [Storage(hbm=102304, ddr=0)], - [Storage(hbm=347584, ddr=0), Storage(hbm=447648, ddr=0)], + [Storage(hbm=188224, ddr=0)], + [Storage(hbm=647776, ddr=0), Storage(hbm=647776, ddr=0)], [ - Storage(hbm=315616, ddr=0), - Storage(hbm=315616, ddr=0), - Storage(hbm=366208, ddr=0), - ], - [ - Storage(hbm=612448, ddr=0), - Storage(hbm=612448, ddr=0), - Storage(hbm=745600, ddr=0), + Storage(hbm=501120, ddr=0), + Storage(hbm=501120, ddr=0), + Storage(hbm=501120, ddr=0), ], + [Storage(hbm=1544512, ddr=0), Storage(hbm=1544512, ddr=0)], ] @@ -346,17 +349,19 @@ def setUp(self) -> None: self.local_world_size = 4 self.constraints = { "table_0": ParameterConstraints(min_partition=20), - "table_1": ParameterConstraints(min_partition=8, pooling_factors=[1, 3, 5]), - "table_2": ParameterConstraints(min_partition=9, pooling_factors=[8, 2]), + "table_1": 
ParameterConstraints( + min_partition=20, pooling_factors=[1, 3, 5] + ), + "table_2": ParameterConstraints(min_partition=20, pooling_factors=[8, 2]), "table_3": ParameterConstraints( - min_partition=12, pooling_factors=[2, 1, 3, 7] + min_partition=40, pooling_factors=[2, 1, 3, 7] ), } self.num_tables = 4 tables = [ EmbeddingBagConfig( num_embeddings=100 + i * 10, - embedding_dim=10 + i * 10, + embedding_dim=20 + i * 20, name="table_" + str(i), feature_names=["feature_" + str(i)], ) @@ -443,7 +448,6 @@ def test_dp_sharding(self) -> None: expected_storage = [ Storage(hbm=storage_size, ddr=0) for storage_size in storage_sizes ] - self.assertEqual( [shard.storage for shard in sharding_option.shards], expected_storage ) @@ -656,6 +660,164 @@ def test_filtering(self) -> None: self.assertIn(sharding_option.compute_kernel, expected_compute_kernels) self.assertNotIn(sharding_option.compute_kernel, unexpected_compute_kernels) + def test_filter_sharding_types_ebc(self) -> None: + constraint = ParameterConstraints( + sharding_types=[ + ShardingType.TABLE_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + ], + ) + constraints = {"table_0": constraint} + enumerator = EmbeddingEnumerator( + topology=MagicMock(), + batch_size=MagicMock(), + constraints=constraints, + ) + + sharder = EmbeddingBagCollectionSharder() + allowed_sharding_types = enumerator._filter_sharding_types( + "table_0", sharder.sharding_types("cuda") + ) + + self.assertEqual( + set(allowed_sharding_types), + { + ShardingType.TABLE_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + }, + ) + + def test_filter_sharding_types_mch_ebc(self) -> None: + constraint = ParameterConstraints( + sharding_types=[ + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + ], + ) + constraints = {"table_0": constraint} + enumerator = EmbeddingEnumerator( + topology=MagicMock(), + batch_size=MagicMock(), + constraints=constraints, + ) + + sharder = ManagedCollisionEmbeddingBagCollectionSharder() + allowed_sharding_types = enumerator._filter_sharding_types( + "table_0", sharder.sharding_types("cuda") + ) + + self.assertEqual( + set(allowed_sharding_types), + { + ShardingType.ROW_WISE.value, + }, + ) + + def test_filter_sharding_types_mch_ebc_no_available(self) -> None: + constraint = ParameterConstraints( + sharding_types=[ + ShardingType.TABLE_ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + ], + ) + constraints = {"table_0": constraint} + enumerator = EmbeddingEnumerator( + topology=MagicMock(), + batch_size=MagicMock(), + constraints=constraints, + ) + + sharder = ManagedCollisionEmbeddingBagCollectionSharder() + with self.assertWarns(Warning): + allowed_sharding_types = enumerator._filter_sharding_types( + "table_0", sharder.sharding_types("cuda") + ) + + self.assertEqual(allowed_sharding_types, []) + + def test_filter_compute_kernels_ebc(self) -> None: + constraint = ParameterConstraints( + compute_kernels=[ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ) + constraints = {"table_0": constraint} + enumerator = EmbeddingEnumerator( + topology=MagicMock(), + batch_size=MagicMock(), + constraints=constraints, + ) + + sharder = EmbeddingBagCollectionSharder() + sharding_type = ShardingType.ROW_WISE.value + allowed_compute_kernels = enumerator._filter_compute_kernels( + "table_0", sharder.compute_kernels(sharding_type, "cuda"), sharding_type + ) + + 
self.assertEqual( + set(allowed_compute_kernels), + { + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM.value, + }, + ) + + def test_filter_compute_kernels_mch_ebc(self) -> None: + constraint = ParameterConstraints( + compute_kernels=[ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ) + constraints = {"table_0": constraint} + enumerator = EmbeddingEnumerator( + topology=MagicMock(), + batch_size=MagicMock(), + constraints=constraints, + ) + + sharder = ManagedCollisionEmbeddingBagCollectionSharder() + sharding_type = ShardingType.ROW_WISE.value + allowed_compute_kernels = enumerator._filter_compute_kernels( + "table_0", sharder.compute_kernels(sharding_type, "cuda"), sharding_type + ) + + self.assertEqual( + set(allowed_compute_kernels), + { + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM.value, + }, + ) + + def test_filter_compute_kernels_mch_ebc_no_available(self) -> None: + constraint = ParameterConstraints( + compute_kernels=[ + EmbeddingComputeKernel.DENSE.value, + ], + ) + constraints = {"table_0": constraint} + enumerator = EmbeddingEnumerator( + topology=MagicMock(), + batch_size=MagicMock(), + constraints=constraints, + ) + + sharder = ManagedCollisionEmbeddingBagCollectionSharder() + sharding_type = ShardingType.ROW_WISE.value + with self.assertWarns(Warning): + allowed_compute_kernels = enumerator._filter_compute_kernels( + "table_0", sharder.compute_kernels(sharding_type, "cuda"), sharding_type + ) + + self.assertEqual(allowed_compute_kernels, []) + def test_tower_sharding(self) -> None: # five tables # tower_0: tables[2], tables[3] @@ -710,3 +872,51 @@ def test_tower_collection_sharding(self) -> None: def test_empty(self) -> None: sharding_options = self.enumerator.enumerate(self.model, sharders=[]) self.assertFalse(sharding_options) + + def test_throw_ex_no_sharding_option_for_table(self) -> None: + cw_constraint = ParameterConstraints( + sharding_types=[ + ShardingType.COLUMN_WISE.value, + ], + compute_kernels=[ + EmbeddingComputeKernel.FUSED.value, + ], + ) + + rw_constraint = ParameterConstraints( + sharding_types=[ + ShardingType.TABLE_ROW_WISE.value, + ], + compute_kernels=[ + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + ], + ) + + constraints = { + "table_0": cw_constraint, + "table_1": rw_constraint, + "table_2": cw_constraint, + "table_3": cw_constraint, + } + + enumerator = EmbeddingEnumerator( + topology=Topology( + world_size=self.world_size, + compute_device=self.compute_device, + local_world_size=self.local_world_size, + ), + batch_size=self.batch_size, + constraints=constraints, + ) + + sharder = cast(ModuleSharder[torch.nn.Module], CWSharder()) + + with self.assertRaises(Exception) as context: + _ = enumerator.enumerate(self.model, [sharder]) + + self.assertEqual( + str(context.exception), + "No available sharding type and compute kernel combination after applying user provided constraints for table_1. " + "Module: torchrec.modules.embedding_modules.EmbeddingBagCollection, sharder: CWSharder, compute device: cuda. 
" + "To debug, search above for warning logs about no available sharding types/compute kernels for table: table_1", + ) diff --git a/torchrec/distributed/planner/tests/test_parallelized_planners.py b/torchrec/distributed/planner/tests/test_parallelized_planners.py deleted file mode 100644 index f6e8a643b..000000000 --- a/torchrec/distributed/planner/tests/test_parallelized_planners.py +++ /dev/null @@ -1,133 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest -from typing import cast, List - -import torch -from torch import nn -from torchrec.distributed.embedding_types import EmbeddingComputeKernel -from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder - -from torchrec.distributed.planner.parallelized_planners import ( - ParallelizedEmbeddingShardingPlanner, -) - -from torchrec.distributed.planner.types import PlannerError, Topology -from torchrec.distributed.test_utils.test_model import TestSparseNN -from torchrec.distributed.types import ModuleSharder, ShardingType -from torchrec.modules.embedding_configs import EmbeddingBagConfig - - -class TWvsRWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]): - def sharding_types(self, compute_device_type: str) -> List[str]: - return [ShardingType.ROW_WISE.value, ShardingType.TABLE_WISE.value] - - def compute_kernels( - self, sharding_type: str, compute_device_type: str - ) -> List[str]: - return [EmbeddingComputeKernel.FUSED.value] - - -class TWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]): - def sharding_types(self, compute_device_type: str) -> List[str]: - return [ShardingType.TABLE_WISE.value] - - def compute_kernels( - self, sharding_type: str, compute_device_type: str - ) -> List[str]: - return [EmbeddingComputeKernel.FUSED.value] - - -class TestParallelizedEmbeddingShardingPlanner(unittest.TestCase): - def setUp(self) -> None: - compute_device = "cuda" - self.topology = Topology( - world_size=2, hbm_cap=1024 * 1024 * 2, compute_device=compute_device - ) - self.planner = ParallelizedEmbeddingShardingPlanner(topology=self.topology) - - def test_tw_solution(self) -> None: - tables = [ - EmbeddingBagConfig( - num_embeddings=100, - embedding_dim=64, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - ) - for i in range(4) - ] - model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) - sharding_plan = self.planner.plan(module=model, sharders=[TWvsRWSharder()]) - expected_ranks = [[0], [0], [1], [1]] - ranks = [ - cast(List[int], param_shard.ranks) - for param_shard in sharding_plan.plan["sparse.ebc"].values() - ] - - self.assertEqual(sorted(expected_ranks), sorted(ranks)) - - def test_hidden_rw_solution(self) -> None: - tables = [ - EmbeddingBagConfig( - num_embeddings=100, - embedding_dim=64, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - ) - for i in range(3) - ] - model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) - sharding_plan = self.planner.plan(module=model, sharders=[TWvsRWSharder()]) - expected_ranks = [[0], [0, 1], [1]] - ranks = [ - cast(List[int], param_shard.ranks) - for param_shard in sharding_plan.plan["sparse.ebc"].values() - ] - - self.assertEqual(sorted(expected_ranks), sorted(ranks)) - - def test_never_fit(self) -> None: - tables = [ - EmbeddingBagConfig( - num_embeddings=10000000, - 
embedding_dim=10000000, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - ) - for i in range(2) - ] - model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) - - with self.assertRaises(PlannerError): - self.planner.plan(module=model, sharders=[TWvsRWSharder()]) - - self.assertEqual(self.planner._num_proposals, 4) - - def test_fail_then_rerun(self) -> None: - tables = [ - EmbeddingBagConfig( - num_embeddings=4096, - embedding_dim=128, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - ) - for i in range(1) - ] - model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) - - with self.assertRaises(PlannerError): - self.planner.plan(module=model, sharders=[TWSharder()]) - - sharding_plan = self.planner.plan(module=model, sharders=[TWvsRWSharder()]) - expected_ranks = [[0, 1]] - ranks = [ - cast(List[int], param_shard.ranks) - for param_shard in sharding_plan.plan["sparse.ebc"].values() - ] - - self.assertEqual(sorted(expected_ranks), sorted(ranks)) diff --git a/torchrec/distributed/planner/tests/test_partitioners.py b/torchrec/distributed/planner/tests/test_partitioners.py index 7b1e380e6..8f46066da 100644 --- a/torchrec/distributed/planner/tests/test_partitioners.py +++ b/torchrec/distributed/planner/tests/test_partitioners.py @@ -5,23 +5,35 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import copy import unittest from typing import cast, List +from unittest.mock import MagicMock from torch import nn from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.planner.constants import BATCH_SIZE from torchrec.distributed.planner.enumerators import EmbeddingEnumerator -from torchrec.distributed.planner.partitioners import GreedyPerfPartitioner +from torchrec.distributed.planner.partitioners import ( + GreedyPerfPartitioner, + MemoryBalancedPartitioner, + OrderedDeviceHardware, +) from torchrec.distributed.planner.types import ( + DeviceHardware, ParameterConstraints, PartitionByType, + Perf, PlannerError, + Shard, + ShardingOption, Storage, Topology, ) +from torchrec.distributed.planner.utils import reset_shard_rank from torchrec.distributed.test_utils.test_model import TestSparseNN from torchrec.distributed.types import ModuleSharder, ShardingType from torchrec.modules.embedding_configs import EmbeddingBagConfig @@ -84,7 +96,7 @@ def setUp(self) -> None: tables = [ EmbeddingBagConfig( num_embeddings=100 + i, - embedding_dim=10 + i, + embedding_dim=4 * (10 + i), name="table_" + str(i), feature_names=["feature_" + str(i)], ) @@ -103,7 +115,9 @@ def test_tw_balanced_perf_device(self) -> None: ) for sharding_option in sharding_options: - sharding_option.shards[0].perf = 100 + sharding_option.shards[0].perf = Perf( + fwd_compute=40, fwd_comms=30, bwd_compute=20, bwd_comms=10 + ) sharding_option.shards[0].storage = Storage(hbm=1000, ddr=1000) candidate_topology = copy.deepcopy(self.topology) @@ -127,8 +141,15 @@ def test_tw_balanced_perf_device(self) -> None: } self.assertEqual(expected_ranks, ranks) - self.assertEqual(solution_topology.devices[0].perf, 200) - self.assertEqual(solution_topology.devices[1].perf, 200) + expected_perf = Perf( + fwd_compute=80, + fwd_comms=60, + bwd_compute=40, + bwd_comms=20, + ) + + self.assertEqual(solution_topology.devices[0].perf, expected_perf) + 
self.assertEqual(solution_topology.devices[1].perf, expected_perf) self.assertEqual( solution_topology.devices[0].storage, @@ -139,13 +160,95 @@ def test_tw_balanced_perf_device(self) -> None: self.topology.devices[1].storage - Storage(2000, 2000), ) + def test_device_partition_heap_invariant(self) -> None: + """Validate that _device_partition maintains the minheap invariant.""" + + def assert_heap(heap: List[OrderedDeviceHardware]) -> None: + for i in range(len(heap)): + left_child = 2 * i + 1 + right_child = 2 * i + 2 + self.assertFalse(left_child < len(heap) and heap[i] > heap[left_child]) + self.assertFalse( + right_child < len(heap) and heap[i] > heap[right_child] + ) + + def device_heaps_equal( + heap1: List[OrderedDeviceHardware], heap2: List[OrderedDeviceHardware] + ) -> None: + # OrderedDeviceHardware 2-key is a partial-order (equally good items might + # permute), so we validate that each heap maintains its heap invariant and + # that device ids are identical between them. + assert_heap(heap1) + assert_heap(heap2) + self.assertEqual( + sorted([id(x.device) for x in heap1]), + sorted([id(x.device) for x in heap2]), + ) + # TODO(damian): with 3-key we have a full total ordering, so we can test + # equivalence with the simpler below. For now leaving in the more complex + # verification that works for both 2-key and 3-key, if we decide on 3-key we + # can delete the more complex equality test. + # self.assertEqual([id(x.device) for x in heap1], + # [id(x.device) for x in heap2]) + + def perf(x: float) -> Perf: + return Perf(fwd_compute=x, fwd_comms=0, bwd_compute=0, bwd_comms=0) + + def empty_devices() -> List[DeviceHardware]: + return [ + DeviceHardware( + rank=x, storage=Storage(hbm=1_000_000, ddr=0), perf=perf(0) + ) + for x in range(6) + ] + + shards = [ + Shard(storage=Storage(hbm=1000, ddr=0), perf=perf(1), size=[], offset=[]) + for _ in range(30) + ] + + sharding_option: ShardingOption = ShardingOption( + name=MagicMock(), + tensor=MagicMock(), + module=MagicMock(), + input_lengths=MagicMock(), + batch_size=MagicMock(), + sharding_type=MagicMock(), + partition_by=MagicMock(), + compute_kernel=MagicMock(), + shards=shards, + ) + local_world_size: int = 3 + + def validate(threshold: float) -> None: + devices = empty_devices() + minheap_devices = GreedyPerfPartitioner._establish_minheap( + devices, local_world_size + ) + + GreedyPerfPartitioner._device_partition( + sharding_option, minheap_devices, threshold + ) + + want_minheap_devices = GreedyPerfPartitioner._establish_minheap( + devices, local_world_size + ) + device_heaps_equal(minheap_devices, want_minheap_devices) + + validate(0) # force heapify + validate(1) # force incremental rebuild + def test_tw_unbalanced_perf_device(self) -> None: sharding_options = self.enumerator.enumerate( module=self.model, sharders=[TWSharder()] ) for i, sharding_option in enumerate(sharding_options): - perf = 100 if i > 0 else 300 + perf = ( + Perf(fwd_compute=40, fwd_comms=30, bwd_compute=20, bwd_comms=10) + if i > 0 + else Perf(fwd_compute=75, fwd_comms=75, bwd_compute=75, bwd_comms=75) + ) sharding_option.shards[0].perf = perf sharding_option.shards[0].storage = Storage(hbm=1000, ddr=1000) @@ -170,8 +273,13 @@ def test_tw_unbalanced_perf_device(self) -> None: } self.assertEqual(expected_ranks, ranks) - self.assertEqual(solution_topology.devices[0].perf, 300) - self.assertEqual(solution_topology.devices[1].perf, 300) + expected_perfs = [ + Perf(fwd_compute=75, fwd_comms=75, bwd_compute=75, bwd_comms=75), + Perf(fwd_compute=120, fwd_comms=90, 
bwd_compute=60, bwd_comms=30), + ] + + self.assertEqual(solution_topology.devices[0].perf, expected_perfs[0]) + self.assertEqual(solution_topology.devices[1].perf, expected_perfs[1]) self.assertEqual( solution_topology.devices[0].storage, @@ -189,7 +297,7 @@ def test_tw_balanced_perf_host(self) -> None: tables = [ EmbeddingBagConfig( num_embeddings=64, - embedding_dim=10 + i, + embedding_dim=4 * (10 + i), name="table_" + str(i), feature_names=["feature_" + str(i)], ) @@ -205,9 +313,10 @@ def test_tw_balanced_perf_host(self) -> None: module=self.model, sharders=[TWRWSharder()] ) for sharding_option in sharding_options: - perf = 100.0 for shard in sharding_option.shards: - shard.perf = perf + shard.perf = Perf( + fwd_compute=40, fwd_comms=30, bwd_compute=20, bwd_comms=10 + ) shard.storage = Storage(hbm=1000, ddr=1000) sharding_option.partition_by = PartitionByType.HOST.value @@ -238,14 +347,20 @@ def test_tw_balanced_perf_host(self) -> None: # there are two shards allocated to each device self.topology.devices[i].storage - Storage(2000, 2000), ) - self.assertEqual(solution_topology.devices[i].perf, 200) + expected_perf = Perf( + fwd_compute=80, + fwd_comms=60, + bwd_compute=40, + bwd_comms=20, + ) + self.assertEqual(solution_topology.devices[i].perf, expected_perf) def test_rw_unbalanced_perf_uniform(self) -> None: self.topology = Topology(world_size=4, compute_device="cuda") tables = [ EmbeddingBagConfig( num_embeddings=64, - embedding_dim=10 + i, + embedding_dim=4 * (10 + i), name="table_" + str(i), feature_names=["feature_" + str(i)], ) @@ -261,9 +376,10 @@ def test_rw_unbalanced_perf_uniform(self) -> None: module=self.model, sharders=[RWSharder()] ) for sharding_option in sharding_options: - perf = 100.0 for shard in sharding_option.shards: - shard.perf = perf + shard.perf = Perf( + fwd_compute=25, fwd_comms=25, bwd_compute=25, bwd_comms=25 + ) shard.storage = Storage(hbm=1000, ddr=1000) sharding_option.partition_by = PartitionByType.UNIFORM.value @@ -299,15 +415,15 @@ def test_twcw_unbalanced_perf_host(self) -> None: world_size=16, local_world_size=8, compute_device="cuda" ) constraints = { - "table_0": ParameterConstraints(min_partition=2), - "table_1": ParameterConstraints(min_partition=10), - "table_2": ParameterConstraints(min_partition=5), - "table_3": ParameterConstraints(min_partition=8), + "table_0": ParameterConstraints(min_partition=4 * 2), + "table_1": ParameterConstraints(min_partition=4 * 10), + "table_2": ParameterConstraints(min_partition=4 * 5), + "table_3": ParameterConstraints(min_partition=4 * 8), } tables = [ EmbeddingBagConfig( num_embeddings=64, - embedding_dim=20 * (i + 1), + embedding_dim=80 * (i + 1), name="table_" + str(i), feature_names=["feature_" + str(i)], ) @@ -324,9 +440,10 @@ def test_twcw_unbalanced_perf_host(self) -> None: sharders=[TWCWSharder()], ) for sharding_option in sharding_options: - perf = 100.0 for shard in sharding_option.shards: - shard.perf = perf + shard.perf = Perf( + fwd_compute=25, fwd_comms=25, bwd_compute=25, bwd_comms=25 + ) shard.storage = Storage(hbm=1000, ddr=1000) sharding_option.partition_by = PartitionByType.HOST.value @@ -357,10 +474,12 @@ def test_twrw_and_twcw_perf_host(self) -> None: sharding_types=[ShardingType.TABLE_ROW_WISE.value] ), "table_1": ParameterConstraints( - sharding_types=[ShardingType.TABLE_COLUMN_WISE.value], min_partition=4 + sharding_types=[ShardingType.TABLE_COLUMN_WISE.value], + min_partition=4 * 8, ), "table_2": ParameterConstraints( - sharding_types=[ShardingType.TABLE_COLUMN_WISE.value], 
min_partition=7 + sharding_types=[ShardingType.TABLE_COLUMN_WISE.value], + min_partition=4 * 10, ), "table_3": ParameterConstraints( sharding_types=[ShardingType.TABLE_ROW_WISE.value] @@ -369,7 +488,7 @@ def test_twrw_and_twcw_perf_host(self) -> None: tables = [ EmbeddingBagConfig( num_embeddings=128, - embedding_dim=20 * (i + 1), + embedding_dim=80 * (i + 1), name="table_" + str(i), feature_names=["feature_" + str(i)], ) @@ -387,9 +506,10 @@ def test_twrw_and_twcw_perf_host(self) -> None: ) for sharding_option in sharding_options: - perf = 100.0 for shard in sharding_option.shards: - shard.perf = perf + shard.perf = Perf( + fwd_compute=25, fwd_comms=25, bwd_compute=25, bwd_comms=25 + ) shard.storage = Storage(hbm=1000, ddr=1000) sharding_option.partition_by = PartitionByType.HOST.value @@ -398,10 +518,10 @@ def test_twrw_and_twcw_perf_host(self) -> None: storage_constraint=self.topology, ) expected_ranks = { - "table_0": [8, 9, 10, 11, 12, 13, 14, 15], - "table_1": [0, 1, 2, 3, 4, 5, 6, 7, 0, 1], - "table_2": [8, 9, 10, 11, 12, 13, 14, 15], - "table_3": [0, 1, 2, 3, 4, 5, 6, 7], + "table_0": [0, 1, 2, 3, 4, 5, 6, 7], + "table_1": [8, 9, 10, 11, 12], + "table_2": [0, 1, 2, 3, 4, 5], + "table_3": [8, 9, 10, 11, 12, 13, 14, 15], } ranks = { @@ -420,10 +540,12 @@ def test_twrw_and_twcw_cohost(self) -> None: sharding_types=[ShardingType.TABLE_ROW_WISE.value] ), "table_1": ParameterConstraints( - sharding_types=[ShardingType.TABLE_COLUMN_WISE.value], min_partition=4 + sharding_types=[ShardingType.TABLE_COLUMN_WISE.value], + min_partition=4 * 8, ), "table_2": ParameterConstraints( - sharding_types=[ShardingType.TABLE_COLUMN_WISE.value], min_partition=7 + sharding_types=[ShardingType.TABLE_COLUMN_WISE.value], + min_partition=4 * 10, ), "table_3": ParameterConstraints( sharding_types=[ShardingType.TABLE_ROW_WISE.value] @@ -432,7 +554,7 @@ def test_twrw_and_twcw_cohost(self) -> None: tables = [ EmbeddingBagConfig( num_embeddings=128, - embedding_dim=20 * (i + 1), + embedding_dim=80 * (i + 1), name="table_" + str(i), feature_names=["feature_" + str(i)], ) @@ -450,9 +572,10 @@ def test_twrw_and_twcw_cohost(self) -> None: ) for i, sharding_option in enumerate(sharding_options): - perf = 100.0 for shard in sharding_option.shards: - shard.perf = perf + shard.perf = Perf( + fwd_compute=25, fwd_comms=25, bwd_compute=25, bwd_comms=25 + ) shard.storage = Storage(hbm=1000, ddr=1000) sharding_option.partition_by = PartitionByType.HOST.value if i <= 2: @@ -464,8 +587,8 @@ def test_twrw_and_twcw_cohost(self) -> None: ) expected_ranks = { "table_0": [0, 1, 2, 3, 4, 5, 6, 7], - "table_1": [0, 1, 2, 3, 4, 5, 6, 7, 0, 1], - "table_2": [2, 3, 4, 5, 6, 7, 0, 1], + "table_1": [0, 1, 2, 3, 4], + "table_2": [5, 6, 7, 0, 1, 2], "table_3": [8, 9, 10, 11, 12, 13, 14, 15], } @@ -480,12 +603,17 @@ def test_twrw_and_twcw_cohost(self) -> None: solution_topology = self.partitioner._topology for i in range(self.topology.world_size): total_storage = Storage(0, 0) - total_perf = 0 + total_perf = Perf( + fwd_compute=0, + fwd_comms=0, + bwd_compute=0, + bwd_comms=0, + ) for sharding_option in sharding_plan: for shard in sharding_option.shards: if shard.rank == i: total_storage += cast(Storage, shard.storage) - total_perf += shard.perf + total_perf += cast(Perf, shard.perf) self.assertEqual( solution_topology.devices[i].storage + total_storage, self.topology.devices[i].storage, @@ -501,10 +629,12 @@ def test_oom(self) -> None: sharding_types=[ShardingType.TABLE_ROW_WISE.value] ), "table_1": ParameterConstraints( - 
sharding_types=[ShardingType.TABLE_COLUMN_WISE.value], min_partition=4 + sharding_types=[ShardingType.TABLE_COLUMN_WISE.value], + min_partition=4 * 4, ), "table_2": ParameterConstraints( - sharding_types=[ShardingType.TABLE_COLUMN_WISE.value], min_partition=7 + sharding_types=[ShardingType.TABLE_COLUMN_WISE.value], + min_partition=4 * 7, ), "table_3": ParameterConstraints( sharding_types=[ShardingType.TABLE_ROW_WISE.value] @@ -531,9 +661,10 @@ def test_oom(self) -> None: ) for i, sharding_option in enumerate(sharding_options): - perf = 100.0 for shard in sharding_option.shards: - shard.perf = perf + shard.perf = Perf( + fwd_compute=25, fwd_comms=25, bwd_compute=25, bwd_comms=25 + ) shard.storage = Storage( # pyre-ignore [6] hbm=self.topology.devices[0].storage.hbm / 2, @@ -549,3 +680,256 @@ def test_oom(self) -> None: proposal=sharding_options, storage_constraint=self.topology, ) + + +class TestMemoryBalancedPartitioner(unittest.TestCase): + def setUp(self) -> None: + compute_device = "cuda" + self.topology = Topology(world_size=2, compute_device=compute_device) + tables = [ + EmbeddingBagConfig( + num_embeddings=100 + i, + embedding_dim=4 * (10 + i), + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(3) + ] + self.topology = Topology( + world_size=2, + compute_device=compute_device, + hbm_cap=2000 * 1024**2, + ) + self.model = TestSparseNN(tables=tables, weighted_tables=[]) + self.enumerator = EmbeddingEnumerator( + topology=self.topology, batch_size=BATCH_SIZE + ) + self.greedy_perf_partitioner = GreedyPerfPartitioner() + self.memory_balanced_partitioner = MemoryBalancedPartitioner(tolerance=100) + + def test_same_sharding_plan(self) -> None: + sharding_options = self.enumerator.enumerate( + module=self.model, sharders=[TWSharder()] + ) + + for sharding_option in sharding_options: + sharding_option.shards[0].perf = Perf( + fwd_compute=40, fwd_comms=30, bwd_compute=20, bwd_comms=10 + ) + sharding_option.shards[0].storage = Storage( + hbm=1000 * 1024**2, ddr=1000 * 1024**2 + ) + + greedy_perf_sharding_plan = self.greedy_perf_partitioner.partition( + proposal=sharding_options, + storage_constraint=self.topology, + ) + greedy_perf_ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in greedy_perf_sharding_plan + } + + reset_shard_rank(sharding_options) + memory_balanced_sharding_plan = self.memory_balanced_partitioner.partition( + proposal=sharding_options, + storage_constraint=self.topology, + ) + memory_balanced_ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in memory_balanced_sharding_plan + } + self.assertEqual(greedy_perf_ranks, memory_balanced_ranks) + + def test_different_sharding_plan(self) -> None: + sharding_options = self.enumerator.enumerate( + module=self.model, sharders=[TWSharder()] + ) + + for i, sharding_option in enumerate(sharding_options): + sharding_option.shards[0].perf = Perf( + fwd_compute=40 * (i + 1), fwd_comms=0, bwd_compute=0, bwd_comms=0 + ) + sharding_option.shards[0].storage = Storage( + hbm=(1500 - i * 500) * 1024**2, ddr=1000 * 1024**2 + ) + + greedy_perf_sharding_plan = self.greedy_perf_partitioner.partition( + proposal=sharding_options, + storage_constraint=self.topology, + ) + greedy_perf_ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in greedy_perf_sharding_plan + } + greedy_perf_expected_ranks = { + "table_0": [0], + "table_1": [1], + "table_2": [0], + } 
+ self.assertEqual(greedy_perf_ranks, greedy_perf_expected_ranks) + + greedy_perf_hbm_uses = [0] * self.topology.world_size + for sharding_option in sharding_options: + for shard in sharding_option.shards: + if shard.storage and shard.rank is not None: + greedy_perf_hbm_uses[ + # pyre-fixme[6]: For 1st argument expected `SupportsIndex` + # but got `Optional[int]`. + shard.rank + ] += shard.storage.hbm # pyre-ignore[16] + + reset_shard_rank(sharding_options) + memory_balanced_sharding_plan = self.memory_balanced_partitioner.partition( + proposal=sharding_options, + storage_constraint=self.topology, + ) + memory_balanced_ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in memory_balanced_sharding_plan + } + memory_balanced_expected_ranks = { + "table_0": [0], + "table_1": [1], + "table_2": [0], + } + self.assertEqual(memory_balanced_ranks, memory_balanced_expected_ranks) + + memory_balanced_hbm_uses = [0.0] * self.topology.world_size + for sharding_option in sharding_options: + for shard in sharding_option.shards: + if shard.storage and shard.rank: + # pyre-fixme[6]: For 1st argument expected `SupportsIndex` but + # got `Optional[int]`. + memory_balanced_hbm_uses[shard.rank] += shard.storage.hbm + + self.assertTrue(max(memory_balanced_hbm_uses) < max(greedy_perf_hbm_uses)) + + +class TestBalanceModules(unittest.TestCase): + def setUp(self) -> None: + compute_device = "cuda" + self.topology = Topology(world_size=2, compute_device=compute_device) + tables = [ + EmbeddingBagConfig( + num_embeddings=100 + i, + embedding_dim=4 * (10 + i), + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(1) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=200 + i, + embedding_dim=8 * (10 + i), + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(3) + ] + self.topology = Topology( + world_size=2, + compute_device=compute_device, + hbm_cap=2000 * 1024**2, + ) + self.model = TestSparseNN(tables=tables, weighted_tables=weighted_tables) + self.enumerator = EmbeddingEnumerator( + topology=self.topology, batch_size=BATCH_SIZE + ) + + self.sharding_options = self.enumerator.enumerate( + module=self.model, sharders=[TWSharder()] + ) + for sharding_option in self.sharding_options: + sharding_option.shards[0].perf = Perf( + fwd_compute=40, fwd_comms=30, bwd_compute=20, bwd_comms=10 + ) + sharding_option.shards[0].storage = Storage( + hbm=10 * 1024**2, ddr=1000 * 1024**2 + ) + + def test_greedy_partitioner(self) -> None: + greedy_partitioner = GreedyPerfPartitioner(balance_modules=False) + balance_modules_greedy_partitioner = GreedyPerfPartitioner(balance_modules=True) + + greedy_sharding_plan = greedy_partitioner.partition( + proposal=self.sharding_options, + storage_constraint=self.topology, + ) + greedy_ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in greedy_sharding_plan + } + + reset_shard_rank(self.sharding_options) + + balance_modules_sharding_plan = balance_modules_greedy_partitioner.partition( + proposal=self.sharding_options, + storage_constraint=self.topology, + ) + balance_modules_ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in balance_modules_sharding_plan + } + + greedy_expected_ranks = { + "weighted_table_0": [0], + "weighted_table_1": [1], + "weighted_table_2": [0], + "table_0": [1], + } + balance_modules_expected_ranks = { + 
"weighted_table_0": [1], + "weighted_table_1": [0], + "weighted_table_2": [1], + "table_0": [0], + } + + self.assertEqual(greedy_expected_ranks, greedy_ranks) + self.assertEqual(balance_modules_expected_ranks, balance_modules_ranks) + + def test_memory_balanced_partitioner(self) -> None: + memory_balanced_partitioner = MemoryBalancedPartitioner( + tolerance=100, balance_modules=False + ) + balance_modules_memory_balanced_partitioner = MemoryBalancedPartitioner( + tolerance=100, balance_modules=True + ) + + memory_balanced_plan = memory_balanced_partitioner.partition( + proposal=self.sharding_options, + storage_constraint=self.topology, + ) + memory_balanced_ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in memory_balanced_plan + } + + reset_shard_rank(self.sharding_options) + + balance_modules_sharding_plan = ( + balance_modules_memory_balanced_partitioner.partition( + proposal=self.sharding_options, + storage_constraint=self.topology, + ) + ) + balance_modules_ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in balance_modules_sharding_plan + } + + memory_balanced_expected_ranks = { + "weighted_table_0": [0], + "weighted_table_1": [1], + "weighted_table_2": [0], + "table_0": [1], + } + balance_modules_expected_ranks = { + "weighted_table_0": [1], + "weighted_table_1": [0], + "weighted_table_2": [1], + "table_0": [0], + } + + self.assertEqual(memory_balanced_expected_ranks, memory_balanced_ranks) + self.assertEqual(balance_modules_expected_ranks, balance_modules_ranks) diff --git a/torchrec/distributed/planner/tests/test_perf_models.py b/torchrec/distributed/planner/tests/test_perf_models.py new file mode 100644 index 000000000..d290b6647 --- /dev/null +++ b/torchrec/distributed/planner/tests/test_perf_models.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import unittest +from unittest.mock import MagicMock + +from torchrec.distributed.planner.perf_models import NoopPerfModel, NoopStorageModel +from torchrec.distributed.planner.types import ( + Perf, + Shard, + ShardingOption, + Storage, + Topology, +) + + +class TestPerfModels(unittest.TestCase): + def setUp(self) -> None: + self.topology = Topology(world_size=2, compute_device="cuda") + self.tables = [ + ShardingOption( + name=MagicMock(), + tensor=MagicMock(), + module=MagicMock(), + input_lengths=MagicMock(), + batch_size=MagicMock(), + sharding_type=MagicMock(), + partition_by=MagicMock(), + compute_kernel=MagicMock(), + shards=[ + Shard( + size=MagicMock(), + offset=MagicMock(), + rank=rank, + perf=Perf( + fwd_compute=2 - rank, + fwd_comms=0, + bwd_compute=0, + bwd_comms=0, + ), + storage=Storage(hbm=100 * (rank + 1), ddr=0), + ), + ], + ) + for rank in range(2) + ] + + def test_noop_perf_model(self) -> None: + perf_model = NoopPerfModel(self.topology) + perf_rating = perf_model.rate(self.tables) + self.assertEqual(perf_rating, 2) + + def test_noop_storage_model(self) -> None: + perf_model = NoopStorageModel(self.topology) + perf_rating = perf_model.rate(self.tables) + self.assertEqual(perf_rating, 200) diff --git a/torchrec/distributed/planner/tests/test_planners.py b/torchrec/distributed/planner/tests/test_planners.py index acb76fcef..abab40ac3 100644 --- a/torchrec/distributed/planner/tests/test_planners.py +++ b/torchrec/distributed/planner/tests/test_planners.py @@ -5,17 +5,36 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest -from typing import cast, List +from typing import cast, List, Optional import torch from torch import nn from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.planner import ParameterConstraints from torchrec.distributed.planner.planners import EmbeddingShardingPlanner -from torchrec.distributed.planner.types import PlannerError, PlannerErrorType, Topology +from torchrec.distributed.planner.proposers import EmbeddingOffloadScaleupProposer +from torchrec.distributed.planner.types import ( + PlannerError, + PlannerErrorType, + ShardingOption, + Topology, +) +from torchrec.distributed.sharding_plan import get_default_sharders from torchrec.distributed.test_utils.test_model import TestSparseNN -from torchrec.distributed.types import ModuleSharder, ShardingPlan, ShardingType +from torchrec.distributed.types import ( + BoundsCheckMode, + CacheAlgorithm, + CacheParams, + DataType, + EmbeddingModuleShardingPlan, + ModuleSharder, + ShardingPlan, + ShardingType, +) from torchrec.modules.embedding_configs import EmbeddingBagConfig @@ -62,7 +81,9 @@ def test_tw_solution(self) -> None: expected_ranks = [[0], [0], [1], [1]] ranks = [ cast(List[int], param_shard.ranks) - for param_shard in sharding_plan.plan["sparse.ebc"].values() + for param_shard in cast( + EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ebc"] + ).values() ] self.assertEqual(sorted(expected_ranks), sorted(ranks)) @@ -82,7 +103,9 @@ def test_hidden_rw_solution(self) -> None: expected_ranks = [[0], [0, 1], [1]] ranks = [ cast(List[int], param_shard.ranks) - for param_shard in sharding_plan.plan["sparse.ebc"].values() + for param_shard in cast( + EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ebc"] + ).values() ] 
self.assertEqual(sorted(expected_ranks), sorted(ranks)) @@ -105,7 +128,8 @@ def test_never_fit(self) -> None: context.exception.error_type, PlannerErrorType.INSUFFICIENT_STORAGE ) - self.assertEqual(self.planner._num_proposals, 4) + # since it has negative storage_constraint + self.assertEqual(self.planner._num_proposals, 0) def test_fail_then_rerun(self) -> None: tables = [ @@ -129,7 +153,9 @@ def test_fail_then_rerun(self) -> None: expected_ranks = [[0, 1]] ranks = [ cast(List[int], param_shard.ranks) - for param_shard in sharding_plan.plan["sparse.ebc"].values() + for param_shard in cast( + EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ebc"] + ).values() ] self.assertEqual(sorted(expected_ranks), sorted(ranks)) @@ -148,3 +174,188 @@ def test_no_sharders(self) -> None: sharding_plan = self.planner.plan(module=model, sharders=[]) self.assertEqual(sharding_plan, ShardingPlan({})) + + +class TestEmbeddingShardingPlannerWithConstraints(unittest.TestCase): + def setUp(self) -> None: + compute_device = "cuda" + self.topology = Topology( + world_size=2, hbm_cap=1024 * 1024 * 2, compute_device=compute_device + ) + self.tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=64, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(4) + ] + self.constraints = { + "table_0": ParameterConstraints( + enforce_hbm=True, + cache_params=CacheParams( + algorithm=CacheAlgorithm.LFU, + ), + feature_names=self.tables[0].feature_names, + ), + "table_1": ParameterConstraints( + enforce_hbm=False, + stochastic_rounding=True, + feature_names=self.tables[1].feature_names, + ), + "table_2": ParameterConstraints( + bounds_check_mode=BoundsCheckMode.FATAL, + feature_names=self.tables[2].feature_names, + ), + "table_3": ParameterConstraints( + cache_params=CacheParams( + algorithm=CacheAlgorithm.LFU, + load_factor=0.1, + reserved_memory=1.0, + precision=DataType.FP16, + ), + feature_names=self.tables[3].feature_names, + ), + } + self.planner = EmbeddingShardingPlanner( + topology=self.topology, constraints=self.constraints + ) + + def test_fused_paramters_from_constraints(self) -> None: + model = TestSparseNN(tables=self.tables, sparse_device=torch.device("meta")) + sharding_plan = self.planner.plan(module=model, sharders=get_default_sharders()) + + expected_fused_params = { + "table_0": ( + CacheParams( + algorithm=CacheAlgorithm.LFU, + load_factor=None, + reserved_memory=None, + precision=None, + ), + True, + None, + None, + ), + "table_1": (None, False, True, None), + "table_2": (None, None, None, BoundsCheckMode.FATAL), + "table_3": ( + CacheParams( + algorithm=CacheAlgorithm.LFU, + load_factor=0.1, + reserved_memory=1.0, + precision=DataType.FP16, + ), + None, + None, + None, + ), + } + + table_names = ["table_" + str(i) for i in range(4)] + for table in table_names: + parameter_sharding = cast( + EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ebc"] + )[table] + self.assertEqual( + ( + parameter_sharding.cache_params, + parameter_sharding.enforce_hbm, + parameter_sharding.stochastic_rounding, + parameter_sharding.bounds_check_mode, + ), + expected_fused_params[table], + ) + + def test_passing_info_through_constraints(self) -> None: + model = TestSparseNN(tables=self.tables, sparse_device=torch.device("meta")) + _ = self.planner.plan(module=model, sharders=get_default_sharders()) + + best_plan: Optional[List[ShardingOption]] = self.planner._best_plan + self.assertIsNotNone(best_plan) + + for table, constraint, sharding_option in zip( + 
self.tables, self.constraints.values(), best_plan + ): + self.assertEqual(table.name, sharding_option.name) + + self.assertEqual(table.feature_names, sharding_option.feature_names) + self.assertEqual(table.feature_names, constraint.feature_names) + + self.assertEqual(constraint.cache_params, sharding_option.cache_params) + self.assertEqual(constraint.enforce_hbm, sharding_option.enforce_hbm) + self.assertEqual( + constraint.stochastic_rounding, sharding_option.stochastic_rounding + ) + self.assertEqual( + constraint.bounds_check_mode, sharding_option.bounds_check_mode + ) + self.assertEqual(constraint.is_weighted, sharding_option.is_weighted) + + +class AutoSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.ROW_WISE.value, ShardingType.TABLE_WISE.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [ + k.value + for k in EmbeddingComputeKernel + if k is not EmbeddingComputeKernel.CUSTOMIZED_KERNEL + ] + + +class TestAutoPlannerWithScaleupProposer(unittest.TestCase): + def setUp(self) -> None: + compute_device = "cuda" + self.topology = Topology( + world_size=2, + hbm_cap=1024 * 1024 * 2, + compute_device=compute_device, + ) + self.tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=64, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(4) + ] + self.constraints = { + f"table_{i}": ParameterConstraints( + # Just needs to be non-None for ScaleupProposer to work. + cache_params=CacheParams(algorithm=CacheAlgorithm.LRU), + ) + for i in range(4) + } + self.planner = EmbeddingShardingPlanner( + topology=self.topology, + proposer=EmbeddingOffloadScaleupProposer(), + constraints=self.constraints, + ) + + def test_auto_sharder_solution(self) -> None: + model = TestSparseNN(tables=self.tables, sparse_device=torch.device("meta")) + sharding_plan = self.planner.plan(module=model, sharders=[AutoSharder()]) + expected_ranks = [[0, 1], [0, 1], [0, 1], [0, 1]] + ranks = [ + cast(List[int], param_shard.ranks) + for param_shard in cast( + EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ebc"] + ).values() + ] + compute_kernels = { + param_shard.compute_kernel + for param_shard in cast( + EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ebc"] + ).values() + } + + self.assertEqual(sorted(expected_ranks), sorted(ranks)) + self.assertSetEqual( + {EmbeddingComputeKernel.FUSED_UVM_CACHING.value}, compute_kernels + ) diff --git a/torchrec/distributed/planner/tests/test_proposers.py b/torchrec/distributed/planner/tests/test_proposers.py index f5d7c562f..c048f63e5 100644 --- a/torchrec/distributed/planner/tests/test_proposers.py +++ b/torchrec/distributed/planner/tests/test_proposers.py @@ -5,23 +5,48 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import unittest -from typing import cast, List, Optional +from typing import cast, List, Optional, Type from unittest.mock import MagicMock import torch +from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.planner.constants import BATCH_SIZE from torchrec.distributed.planner.enumerators import EmbeddingEnumerator from torchrec.distributed.planner.proposers import ( + DynamicProgrammingProposer, + EmbeddingOffloadScaleupProposer, GreedyProposer, GridSearchProposer, proposers_to_proposals_list, UniformProposer, ) -from torchrec.distributed.planner.types import Proposer, ShardingOption, Topology +from torchrec.distributed.planner.shard_estimators import ( + _calculate_storage_specific_sizes, + EmbeddingPerfEstimator, + EmbeddingStorageEstimator, +) +from torchrec.distributed.planner.types import ( + Enumerator, + ParameterConstraints, + Perf, + Proposer, + Shard, + ShardingOption, + Storage, + Topology, +) from torchrec.distributed.test_utils.test_model import TestSparseNN -from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.distributed.types import ( + CacheAlgorithm, + CacheParams, + CacheStatistics, + ModuleSharder, + ShardingType, +) from torchrec.modules.embedding_configs import EmbeddingBagConfig @@ -29,6 +54,7 @@ class MockProposer(Proposer): def load( self, search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, ) -> None: pass @@ -37,6 +63,7 @@ def feedback( partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, ) -> None: pass @@ -44,6 +71,84 @@ def propose(self) -> Optional[List[ShardingOption]]: pass +class MockCacheStatistics(CacheStatistics): + def __init__(self, expected_lookups: int, cacheability: float) -> None: + self._expected_lookups = expected_lookups + self._cacheability = cacheability + + @property + def expected_lookups(self) -> int: + return self._expected_lookups + + def expected_miss_rate(self, clf: float) -> float: + return clf + + @property + def cacheability(self) -> float: + return self._cacheability + + +# Mocking _calculate_storage_specific_sizes to skip cache aux state accounting for +# simpler testing +def mock_calculate_storage_specific_sizes( + storage: int, + shape: torch.Size, + shard_sizes: List[List[int]], + sharding_type: str, + optimizer_class: Optional[Type[torch.optim.Optimizer]] = None, + is_inference: bool = False, + clf: Optional[float] = None, +) -> List[int]: + return _calculate_storage_specific_sizes( + storage, shape, shard_sizes, sharding_type, optimizer_class, is_inference, None + ) + + +def make_sharding_option( + name: str, + raw_size: int, + clf: Optional[float], + perf: Optional[Perf] = None, +) -> ShardingOption: + """ + Convenience factory method for creating a sharding option with a single shard. 
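+
+    A clf (cache load factor) of None produces a plain "fused" HBM option with no
+    cache_params; any float produces a "fused_uvm_caching" option whose CacheParams
+    carry that load factor. raw_size becomes the second dimension of the single shard.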
+ """ + return ShardingOption( + name=name, + tensor=torch.zeros(1), + # pyre-ignore + module=("model", None), + input_lengths=[], + batch_size=8, + sharding_type="row_wise", + partition_by="DEVICE", + compute_kernel="fused" if clf is None else "fused_uvm_caching", + shards=[ + Shard( + size=[1, raw_size], + offset=[0, 0], + perf=( + Perf( + fwd_compute=0, + fwd_comms=0, + bwd_compute=0, + bwd_comms=0, + ) + if perf is None + else perf + ), + ) + ], + cache_params=( + None + if clf is None + else CacheParams( + load_factor=clf, + ) + ), + ) + + class TestProposers(unittest.TestCase): def setUp(self) -> None: topology = Topology(world_size=2, compute_device="cuda") @@ -51,6 +156,9 @@ def setUp(self) -> None: self.greedy_proposer = GreedyProposer() self.uniform_proposer = UniformProposer() self.grid_search_proposer = GridSearchProposer() + self.dynamic_programming_proposer = DynamicProgrammingProposer() + self._sharding_types = [x.value for x in ShardingType] + self.maxDiff = None def test_greedy_two_table(self) -> None: tables = [ @@ -67,6 +175,17 @@ def test_greedy_two_table(self) -> None: feature_names=["feature_1"], ), ] + """ + GRID_SHARD only is available if specified by user in parameter constraints, however, + adding parameter constraints does not work because of the non deterministic nature of + _filter_sharding_types (set & set) operation when constraints are present. This means + the greedy proposer will have a different order of sharding types on each test invocation + which we cannot have a harcoded "correct" answer for. We mock the call to _filter_sharding_types + to ensure the order of the sharding types list is always the same. + """ + self.enumerator._filter_sharding_types = MagicMock( + return_value=self._sharding_types + ) model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) search_space = self.enumerator.enumerate( @@ -83,7 +202,7 @@ def test_greedy_two_table(self) -> None: proposal = cast(List[ShardingOption], self.greedy_proposer.propose()) proposal.sort( key=lambda sharding_option: ( - max([shard.perf for shard in sharding_option.shards]), + max([shard.perf.total for shard in sharding_option.shards]), sharding_option.name, ) ) @@ -109,16 +228,16 @@ def test_greedy_two_table(self) -> None: ("table_1", "row_wise", "fused"), ], [ + ("table_0", "grid_shard", "fused"), ("table_1", "row_wise", "fused"), - ("table_0", "data_parallel", "dense"), ], [ - ("table_1", "table_row_wise", "fused"), + ("table_1", "row_wise", "fused"), ("table_0", "data_parallel", "dense"), ], [ + ("table_1", "table_row_wise", "fused"), ("table_0", "data_parallel", "dense"), - ("table_1", "data_parallel", "dense"), ], ] @@ -174,7 +293,7 @@ def test_uniform_three_table(self) -> None: while proposal: proposal.sort( key=lambda sharding_option: ( - max([shard.perf for shard in sharding_option.shards]), + max([shard.perf.total for shard in sharding_option.shards]), sharding_option.name, ) ) @@ -275,6 +394,16 @@ def test_grid_search_three_table(self) -> None: for i in range(1, 4) ] model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) + """ + GRID_SHARD only is available if specified by user in parameter constraints, however, + adding parameter constraints does not work because of the non deterministic nature of + _filter_sharding_types (set & set) operation when constraints are present, we mock the + call to _filter_sharding_types to ensure the order of the sharding types list is always + the same. 
+ """ + self.enumerator._filter_sharding_types = MagicMock( + return_value=self._sharding_types + ) search_space = self.enumerator.enumerate( module=model, sharders=[ @@ -289,7 +418,7 @@ def test_grid_search_three_table(self) -> None: - fused_uvm DP will have 1 possible compute kernel: dense So the total number of pruned options will be: - (num_sharding_types - 1) * 3 + 1 = 16 + (num_sharding_types - 1) * 3 + 1 = 19 """ num_pruned_options = (len(ShardingType) - 1) * 3 + 1 self.grid_search_proposer.load(search_space) @@ -309,6 +438,482 @@ def test_grid_search_three_table(self) -> None: self.assertEqual(num_pruned_options ** len(tables), num_proposals) + def test_dynamic_programming_three_table(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=100 * i, + embedding_dim=10 * i, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(1, 4) + ] + model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) + search_space = self.enumerator.enumerate( + module=model, + sharders=[ + cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()) + ], + ) + + self.dynamic_programming_proposer.load(search_space) + + num_proposals = 0 + proposal = self.dynamic_programming_proposer.propose() + GB = 1024 * 1024 * 1024 + storage_constraint = Topology( + world_size=2, compute_device="cuda", hbm_cap=100 * GB, ddr_cap=1000 * GB + ) + while proposal: + self.dynamic_programming_proposer.feedback( + partitionable=True, storage_constraint=storage_constraint + ) + proposal = self.dynamic_programming_proposer.propose() + num_proposals += 1 + self.assertEqual(2, num_proposals) + + def test_get_scalable_sharding_options(self) -> None: + def make_so( + name: str, clf: Optional[float], stats: Optional[CacheStatistics] + ) -> ShardingOption: + so = make_sharding_option(name, 1, clf) + if clf: + assert so.cache_params + so.cache_params.stats = stats + return so + + proposal = [ + make_so("fused", None, None), + make_so("caching-no-stats", 0.5, None), + make_so( + "caching-stats", + 0.5, + MockCacheStatistics(expected_lookups=1, cacheability=0.42), + ), + make_so( + "caching-stats-no-data", + 0, + MockCacheStatistics(expected_lookups=0, cacheability=0), + ), + ] + got = EmbeddingOffloadScaleupProposer.get_scalable_sharding_options(proposal) + want = [proposal[-2]] + self.assertEqual(got, want) + + def test_allocate_budget(self) -> None: + model = torch.tensor([[1.0, 0.0], [2.0, 3.0], [4.0, 5.0]]) + got = EmbeddingOffloadScaleupProposer.clf_to_bytes( + model, torch.tensor([0, 0.5, 1]) + ) + torch.testing.assert_close(got, torch.tensor([0, 4, 9], dtype=torch.float64)) + + # Scenario 1, enough budget to scale everything to 1.0 + model = torch.tensor( + [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]] + ) + fused_hbm_ceiling = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 1.0) + mins = torch.tensor([0.1, 0.1, 1]) + budget = 100_000_000 + got = EmbeddingOffloadScaleupProposer.allocate_budget( + model, + fused_hbm_ceiling=fused_hbm_ceiling, + clfs=mins, + budget=budget, + allocation_priority=torch.tensor([2, 2, 2]), + ) + torch.testing.assert_close(got, torch.tensor([1.0, 1.0, 1.0])) + increase = ( + EmbeddingOffloadScaleupProposer.clf_to_bytes(model, got).sum() + - EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum() + ).item() + self.assertLess(increase, budget) + + # Scenario 2, limited budget, uniform scale up + model = torch.tensor( + [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]] + ) + 
fused_hbm_ceiling = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 1.0) + mins = torch.tensor([0.1, 0.1, 1]) + budget = 10_000_000 + got = EmbeddingOffloadScaleupProposer.allocate_budget( + model, + fused_hbm_ceiling=fused_hbm_ceiling, + clfs=mins, + budget=budget, + allocation_priority=torch.tensor([2, 2, 2]), + ) + torch.testing.assert_close(got, torch.tensor([0.26667, 0.26667, 1.0])) + increase = ( + EmbeddingOffloadScaleupProposer.clf_to_bytes(model, got).sum() + - EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum() + ) + self.assertEqual(increase, budget) + + # Scenario 3, limited budget, skewed scale up + model = torch.tensor( + [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]] + ) + fused_hbm_ceiling = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 1.0) + mins = torch.tensor([0.1, 0.1, 1]) + budget = 10_000_000 + got = EmbeddingOffloadScaleupProposer.allocate_budget( + model, + fused_hbm_ceiling=fused_hbm_ceiling, + clfs=mins, + budget=budget, + allocation_priority=torch.tensor([2, 4, 2]), + ) + # increase is twice as much for table 2 (started at 0.1) + torch.testing.assert_close( + got, torch.tensor([0.1 + 0.11111, 0.1 + 2 * 0.11111, 1.0]) + ) + increase = ( + EmbeddingOffloadScaleupProposer.clf_to_bytes(model, got).sum() + - EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum() + ) + self.assertEqual(int(increase), budget) + + # Scenario 4, multi-pass scale up + model = torch.tensor( + [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]] + ) + fused_hbm_ceiling = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 1.0) + mins = torch.tensor([0.1, 0.3, 0.5]) + budget = 50_000_000 + got = EmbeddingOffloadScaleupProposer.allocate_budget( + model, + fused_hbm_ceiling=fused_hbm_ceiling, + clfs=mins, + budget=budget, + allocation_priority=torch.tensor([1, 2, 100]), + ) + torch.testing.assert_close(got, torch.tensor([0.56667, 1.0, 1.0])) + increase = ( + EmbeddingOffloadScaleupProposer.clf_to_bytes(model, got).sum() + - EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum() + ) + self.assertEqual(increase, budget) + + # Scenario 5, prefetch overhead causing early promotion + # like scenario 4, but we set fused size to 80%, which saves enough memory + # to promote all 3 to HBM inside the same budget. + model = torch.tensor( + [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]] + ) + fused_hbm_ceiling = ( + EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 1.0) * 0.8 + ) + mins = torch.tensor([0.1, 0.3, 0.5]) + budget = 50_000_000 + got = EmbeddingOffloadScaleupProposer.allocate_budget( + model, + fused_hbm_ceiling=fused_hbm_ceiling, + clfs=mins, + budget=budget, + allocation_priority=torch.tensor([1, 2, 100]), + ) + torch.testing.assert_close(got, torch.tensor([1.0, 1.0, 1.0])) + self.assertLessEqual( + fused_hbm_ceiling.sum().item(), + EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum().item() + + budget, + ) + + @unittest.mock.patch( + "torchrec.distributed.planner.shard_estimators._calculate_storage_specific_sizes", + side_effect=mock_calculate_storage_specific_sizes, + ) + def test_scaleup(self, _) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=2_000_000, + embedding_dim=10000, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(4) + ] + + # Place first three tables into cache, 4th table leave on hbm. table_1 has a + # larger cacheability score so budget should be skewed to scaling table_1 more + # than table_0. 
table_2 is a deprecated feature we have no stats for (so + # expected_lookups 0), we want to see that left at its original load factor, + # i.e. doesn't participate in scaleup. + constraints = { + "table_0": ParameterConstraints( + sharding_types=[ShardingType.COLUMN_WISE.value], + compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], + cache_params=CacheParams( + load_factor=0.1, + stats=MockCacheStatistics(expected_lookups=2, cacheability=0.2), + ), + ), + "table_1": ParameterConstraints( + sharding_types=[ShardingType.COLUMN_WISE.value], + compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], + cache_params=CacheParams( + load_factor=0.1, + stats=MockCacheStatistics(expected_lookups=2, cacheability=0.5), + ), + ), + "table_2": ParameterConstraints( + sharding_types=[ShardingType.COLUMN_WISE.value], + compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], + cache_params=CacheParams( + load_factor=0.002, + stats=MockCacheStatistics(expected_lookups=0, cacheability=0.5), + ), + ), + "table_3": ParameterConstraints( + sharding_types=[ShardingType.COLUMN_WISE.value], + compute_kernels=[EmbeddingComputeKernel.FUSED.value], + cache_params=CacheParams(), + ), + } + + GB = 1024 * 1024 * 1024 + storage_constraint = Topology( + world_size=2, compute_device="cuda", hbm_cap=100 * GB, ddr_cap=1000 * GB + ) + # Ignoring table_2, the remainder require 224GB if all placed on HBM. We only + # have 200GB, so we can't promote both uvm tables. Initial plan needs uses 90GB, + # with 110GB of available budget. + model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) + enumerator = EmbeddingEnumerator( + topology=storage_constraint, + batch_size=BATCH_SIZE, + constraints=constraints, + estimator=[ + EmbeddingPerfEstimator( + topology=storage_constraint, constraints=constraints + ), + EmbeddingStorageEstimator( + topology=storage_constraint, + constraints=constraints, + ), + ], + ) + search_space = enumerator.enumerate( + module=model, + sharders=[ + cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()) + ], + ) + proposer = EmbeddingOffloadScaleupProposer() + proposer.load(search_space, enumerator=enumerator) + + proposal = proposer.propose() + best_plan = None + best_perf = 1e99 + proposals = -1 + while proposal is not None: + proposals += 1 + mem = sum(so.total_storage.hbm for so in proposal) + # simple perf model, assume partitioner gives a lowest score around 150GB of memory. 
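+            # perf_rating is just the distance from that 150GB sweet spot; feeding it
+            # back (together with partitionable=True) lets the proposer steer its
+            # search over cache scale-up budgets.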
+ perf = abs(mem - (150 * GB)) + plan = { + "mem": mem, + "proposal": [ + ( + candidate.name, + candidate.compute_kernel, + ( + candidate.cache_params.load_factor + if candidate.cache_params + else None + ), + ) + for candidate in proposal + ], + } + if perf < best_perf: + best_plan = plan + best_perf = perf + proposer.feedback( + partitionable=True, + plan=proposal, + perf_rating=perf, + storage_constraint=storage_constraint, + ) + proposal = proposer.propose() + + self.assertEqual(proposals, 16) + self.assertEqual( + best_plan, + { + # 146GB, close to target of 150GB + "mem": 157178896800, + # table 1 has been scaled up 2.5x more than table 0 (vs original 0.1) + # which aligns with their different cacheability scores + # table_2 has been left alone (deprecated feature, expected zero lookups in stats) + "proposal": [ + ("table_0", "fused_uvm_caching", 0.3173336386680603), + ("table_1", "fused_uvm_caching", 0.6433340907096863), + ("table_2", "fused_uvm_caching", 0.002), + ("table_3", "fused", None), + ], + }, + ) + + def test_promote_high_prefetch_overheaad_table_to_hbm(self) -> None: + def expect_sharding_option( + so: List[ShardingOption], + names: List[str], + total_hbms: List[int], + total_ddrs: List[int], + kernels: List[str], + ) -> None: + for s, name, hbm, ddr, kernel in zip( + so, names, total_hbms, total_ddrs, kernels + ): + self.assertEqual(s.name, name) + self.assertEqual(s.total_storage.hbm, hbm) + self.assertEqual(s.total_storage.ddr, ddr) + self.assertEqual(s.compute_kernel, kernel) + + def mock_storage_estimator_func(so: List[ShardingOption]) -> None: + # This mock storage estimator will give table2 a consistent penalty + # for using UVM caching + for s in so: + size = s.shards[0].size[0] * s.shards[0].size[1] + if s.compute_kernel == "fused_uvm_caching": + assert s.cache_params + assert s.cache_params.load_factor + penalty = 50 if s.name == "table-2" else 0 + s.shards[0].storage = Storage( + ddr=size, + hbm=int(size * s.cache_params.load_factor) + penalty, + ) + else: + s.shards[0].storage = Storage( + ddr=0, + hbm=size, + ) + + mock_enumerator = MagicMock() + mock_enumerator.populate_estimates.side_effect = mock_storage_estimator_func + + # Run1: table-2 is getting penalized but HBM is still lower when offloaded. + # so it won't change after promotion + p = EmbeddingOffloadScaleupProposer() + proposal1 = [ + make_sharding_option("table-1", 50, None), + make_sharding_option("table-2", 90, 0.3), + make_sharding_option("table-3", 100, 0.5), + ] + mock_storage_estimator_func(proposal1) + p.load(search_space=proposal1, enumerator=mock_enumerator) + expect_sharding_option( + p.starting_proposal, + ["table-1", "table-2", "table-3"], + [50, 77, 50], + [0, 90, 100], + ["fused", "fused_uvm_caching", "fused_uvm_caching"], + ) + + # Run1: table-2 will have higher HBM after penalty (90 * 0.8 + 50 > 90). So promote it. 
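+        # In this second run the mock estimator charges table-2 90 * 0.8 + 50 = 122
+        # HBM when cached, more than the 90 it needs fully in HBM, so load() should
+        # promote it to the plain "fused" kernel, as the expected options below show.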
+ p = EmbeddingOffloadScaleupProposer() + proposal2 = [ + make_sharding_option("table-1", 50, None), + make_sharding_option("table-2", 90, 0.8), + make_sharding_option("table-3", 100, 0.9), + ] + mock_storage_estimator_func(proposal2) + p.load(search_space=proposal2, enumerator=mock_enumerator) + expect_sharding_option( + p.starting_proposal, + ["table-1", "table-2", "table-3"], + [50, 90, 90], + [0, 0, 100], + ["fused", "fused", "fused_uvm_caching"], + ) + + @unittest.mock.patch( + "torchrec.distributed.planner.shard_estimators._calculate_storage_specific_sizes", + side_effect=mock_calculate_storage_specific_sizes, + ) + def test_budget_shrink(self, _) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=2_000_000, + embedding_dim=10000, + name="table_0", + feature_names=["feature_0"], + ) + ] + constraints = { + "table_0": ParameterConstraints( + compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], + cache_params=CacheParams( + load_factor=0.1, + stats=MockCacheStatistics(expected_lookups=2, cacheability=0.2), + ), + ), + } + + GB = 1024 * 1024 * 1024 + storage_constraint = Topology( + world_size=1, compute_device="cuda", hbm_cap=100 * GB, ddr_cap=1000 * GB + ) + model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) + enumerator = EmbeddingEnumerator( + topology=storage_constraint, + batch_size=BATCH_SIZE, + constraints=constraints, + estimator=[ + EmbeddingPerfEstimator( + topology=storage_constraint, constraints=constraints + ), + EmbeddingStorageEstimator( + topology=storage_constraint, + constraints=constraints, + ), + ], + ) + search_space = enumerator.enumerate( + module=model, + sharders=[ + cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()) + ], + ) + proposer = EmbeddingOffloadScaleupProposer() + proposer.load(search_space, enumerator=enumerator) + + proposal = proposer.propose() + best_plan = None + best_perf = 1e99 + proposals = -1 + initial_mem = None + while proposal is not None: + proposals += 1 + mem = sum(so.total_storage.hbm for so in proposal) + if initial_mem is None: + initial_mem = mem + # Budget given constraints: + # unscaled plan=7.47 GB, cache scale up budget=92.53 GB, peak scale up budget need=67.06 GB, exploring plans of size [7.47, 74.53] GB + # + # Simple perf model, assume partitioner gives a lowest score at 7.9GB, and + # anything larger than 8GB fails to partition. This is very hard to hit when + # exploring the larger [7.47, 100] range with limited iterations without + # shrinkage. 
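+            # Anything at or above 8GB is reported back as unpartitionable (with
+            # perf_rating=None), which is what forces the proposer to shrink its
+            # budget toward the narrow 7.9GB sweet spot.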
+ perf = abs(mem - (7.9 * GB)) + partitionable = mem < 8 * GB + if perf < best_perf: + best_plan = mem + best_perf = perf + proposer.feedback( + partitionable=partitionable, + plan=proposal, + perf_rating=perf if partitionable else None, + storage_constraint=storage_constraint, + ) + proposal = proposer.propose() + + self.assertEqual(proposals, 16) + self.assertNotEqual(initial_mem, best_plan, "couldn't find a better plan") + # goal is 7.9, we get very close + self.assertEqual(best_plan, 7.9028974287211895 * GB) + def test_proposers_to_proposals_list(self) -> None: def make_mock_proposal(name: str) -> List[ShardingOption]: return [ @@ -375,3 +980,115 @@ def make_mock_proposal(name: str) -> List[ShardingOption]: expected_list_names = ["p1so1", "p1so2", "p2so1", "p2so2", "p3so1", "p3so2"] self.assertEqual(proposals_list_names, expected_list_names) + + def test_embedding_offload_scaleup_proposer_uses_fused_kernel_when_possible( + self, + ) -> None: + def mock_storage_estimator_func(so: List[ShardingOption]) -> None: + # This mock storage estimator will give all tables a penalty + # for using UVM caching. + for s in so: + size = s.shards[0].size[0] * s.shards[0].size[1] + if s.compute_kernel == "fused_uvm_caching": + assert s.cache_params + assert s.cache_params.load_factor + penalty = 50 + s.shards[0].storage = Storage( + ddr=size, + hbm=int(size * s.cache_params.load_factor) + penalty, + ) + else: + s.shards[0].storage = Storage( + ddr=0, + hbm=size, + ) + + mock_enumerator = MagicMock() + mock_enumerator.populate_estimates.side_effect = mock_storage_estimator_func + + p = EmbeddingOffloadScaleupProposer() + search_space = [ + # We should pick the first option since it has a fused_uvm_caching kernel + # even though it doesn't have the best perf. + make_sharding_option("table-1", 5000, 0.1, Perf(9, 9, 9, 9)), + make_sharding_option("table-1", 5000, None, Perf(1, 1, 1, 1)), + make_sharding_option("table-1", 1, None, Perf(1, 1, 1, 1)), + # Neither option has fused_uvm_caching, so we should pick the one with best perf. + make_sharding_option("table-2", 10, None, Perf(9, 9, 9, 9)), + make_sharding_option("table-2", 50, None, Perf(1, 1, 1, 1)), + ] + + mock_storage_estimator_func(search_space) + p.load(search_space=search_space, enumerator=mock_enumerator) + + self.assertEqual(p.starting_proposal[0].name, "table-1") + self.assertEqual(p.starting_proposal[0].compute_kernel, "fused_uvm_caching") + self.assertEqual(p.starting_proposal[0].total_storage.hbm, 550) + self.assertEqual(p.starting_proposal[0].total_storage.ddr, 5000) + + self.assertEqual(p.starting_proposal[1].name, "table-2") + self.assertEqual(p.starting_proposal[1].compute_kernel, "fused") + self.assertEqual(p.starting_proposal[1].total_storage.hbm, 50) + self.assertEqual(p.starting_proposal[1].total_storage.ddr, 0) + + @unittest.mock.patch("torchrec.distributed.planner.proposers.logger") + def test_build_proposal_from_sharding_options(self, mock_logger: MagicMock) -> None: + table_4_sharding_option = make_sharding_option("table-4", 1, 0.1) + assert table_4_sharding_option.cache_params # appease pyre + table_4_sharding_option.cache_params.algorithm = CacheAlgorithm.LFU + + sharding_options_by_fqn = { + # Case 1: Only one option, use it even though it isn't fused_uvm_caching. Don't log anything. + "table-1": [make_sharding_option("table-1", 1, None)], + # Case 2: Multiple options, 1+ with fused_uvm_caching, use the first with fused_uvm_caching. Log warning. 
+ "table-2": [ + make_sharding_option("table-2", 1, None), + make_sharding_option("table-2", 1, 0.1), + make_sharding_option("table-2", 1, None), + make_sharding_option("table-2", 1, 0.1), + ], + # Case 3: Multiple options, none with fused_uvm_caching, use the first one. Log warning. + "table-3": [ + make_sharding_option("table-3", 1, None), + make_sharding_option("table-3", 1, None), + ], + # Case 4: One option, but it's using LFU cache constraints. Use it, but log error. + "table-4": [table_4_sharding_option], + } + + proposer = EmbeddingOffloadScaleupProposer() + proposal = proposer._build_proposal_from_sharding_options( + sharding_options_by_fqn + ) + + self.assertEqual(len(proposal), 4) + self.assertEqual(mock_logger.warning.call_count, 2) + self.assertEqual(mock_logger.error.call_count, 1) + + # Case 1 + self.assertEqual(proposal[0], sharding_options_by_fqn["table-1"][0]) + + # Case 2 + self.assertEqual(proposal[1], sharding_options_by_fqn["table-2"][1]) + self.assertRegex( + mock_logger.warning.call_args_list[0].args[0], + r"^EmbeddingOffloadScaleupProposer - ignored \d+ sharding options for table table-2", + ) + + # Case 3 + self.assertEqual(proposal[2], sharding_options_by_fqn["table-3"][0]) + self.assertRegex( + mock_logger.warning.call_args_list[1].args[0], + r"^EmbeddingOffloadScaleupProposer - ignored \d+ sharding options for table table-3", + ) + + # Case 4 + self.assertEqual(proposal[3], sharding_options_by_fqn["table-4"][0]) + self.assertIn( + "EmbeddingOffloadScaleupProposer - proposer only supports LRU cache algorithm", + mock_logger.error.call_args_list[0].args[0], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchrec/distributed/planner/tests/test_shard_estimators.py b/torchrec/distributed/planner/tests/test_shard_estimators.py index 7a988e51b..093cb4d31 100644 --- a/torchrec/distributed/planner/tests/test_shard_estimators.py +++ b/torchrec/distributed/planner/tests/test_shard_estimators.py @@ -5,27 +5,60 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + +import math import unittest -from typing import cast +from typing import cast, Dict, List, Tuple + +from unittest.mock import MagicMock, Mock, patch import torch import torchrec.optim as trec_optim from torchrec.distributed.embedding import EmbeddingCollectionSharder +from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder -from torchrec.distributed.planner.constants import BATCH_SIZE +from torchrec.distributed.fbgemm_qcomm_codec import ( + CommType, + get_qcomm_codecs_registry, + QCommsConfig, +) +from torchrec.distributed.planner.constants import ( + BATCH_SIZE, + CROSS_NODE_BANDWIDTH, + INTRA_NODE_BANDWIDTH, +) from torchrec.distributed.planner.enumerators import EmbeddingEnumerator from torchrec.distributed.planner.shard_estimators import ( _calculate_storage_specific_sizes, + EmbeddingOffloadStats, EmbeddingPerfEstimator, + EmbeddingStorageEstimator, +) +from torchrec.distributed.planner.types import ( + BasicCommsBandwidths, + ParameterConstraints, + Perf, + Topology, ) -from torchrec.distributed.planner.types import Topology from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder -from torchrec.distributed.test_utils.test_model import TestSparseNN -from torchrec.distributed.tests.test_quant_model_parallel import _quantize +from torchrec.distributed.test_utils.infer_utils import quantize +from torchrec.distributed.test_utils.test_model import TestEBCSharder, TestSparseNN from torchrec.distributed.tests.test_sequence_model import TestSequenceSparseNN -from torchrec.distributed.types import ModuleSharder, ShardingType -from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig +from torchrec.distributed.types import ( + CacheParams, + CacheStatistics, + ModuleSharder, + MultiPassPrefetchConfig, + PipelineType, + ShardingType, +) +from torchrec.modules.embedding_configs import ( + DataType, + EmbeddingBagConfig, + EmbeddingConfig, +) class TestEmbeddingPerfEstimator(unittest.TestCase): @@ -35,6 +68,29 @@ def setUp(self) -> None: self.enumerator = EmbeddingEnumerator( topology=self.topology, batch_size=BATCH_SIZE, estimator=self.estimator ) + self._sharding_types = [x.value for x in ShardingType] + + def test_basic_comms_bandwidth(self) -> None: + # Ensure the generalized comms setup is identical if we use BasicComms with defaults. + topology2 = Topology( + world_size=2, + compute_device="cuda", + generalized_comms_bandwidths=BasicCommsBandwidths(), + ) + + self.assertEqual(topology2.inter_host_bw, self.topology.inter_host_bw) + self.assertEqual(topology2.intra_host_bw, self.topology.intra_host_bw) + + # Ensure the generalized comms setup is identical if we pass defaults to bw. 
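+        # INTRA_NODE_BANDWIDTH and CROSS_NODE_BANDWIDTH are the planner's default
+        # bandwidth constants, so topology3 should come out identical to the
+        # default-constructed self.topology.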
+ topology3 = Topology( + world_size=2, + compute_device="cuda", + intra_host_bw=INTRA_NODE_BANDWIDTH, + inter_host_bw=CROSS_NODE_BANDWIDTH, + ) + + self.assertEqual(topology3.inter_host_bw, self.topology.inter_host_bw) + self.assertEqual(topology3.intra_host_bw, self.topology.intra_host_bw) def test_1_table_perf(self) -> None: tables = [ @@ -46,6 +102,16 @@ def test_1_table_perf(self) -> None: ) ] model = TestSparseNN(tables=tables, weighted_tables=[]) + """ + GRID_SHARD only is available if specified by user in parameter constraints, however, + adding parameter constraints does not work because of the non deterministic nature of + _filter_sharding_types (set & set) operation when constraints are present, we mock the + call to _filter_sharding_types to ensure the order of the sharding types list is always + the same. + """ + self.enumerator._filter_sharding_types = MagicMock( + return_value=self._sharding_types + ) sharding_options = self.enumerator.enumerate( module=model, sharders=[ @@ -55,41 +121,217 @@ def test_1_table_perf(self) -> None: expected_perfs = { ("dense", "data_parallel"): [ - 0.0004935158269195386, - 0.0004935158269195386, - ], - ("fused", "table_wise"): [0.0011095368078055323], - ("fused_uvm", "table_wise"): [0.1729105033126532], - ("fused_uvm_caching", "table_wise"): [0.040145097917908434], - ("fused", "column_wise"): [0.0011095368078055323], - ("fused_uvm", "column_wise"): [0.1729105033126532], - ("fused_uvm_caching", "column_wise"): [0.040145097917908434], - ("fused", "table_column_wise"): [0.0011095368078055323], - ("fused_uvm", "table_column_wise"): [0.1729105033126532], - ("fused_uvm_caching", "table_column_wise"): [0.040145097917908434], + Perf( + fwd_compute=9.356002212235228e-05, + fwd_comms=0, + bwd_compute=0.00018712004424470456, + bwd_comms=0.000225593945387348, + ), + Perf( + fwd_compute=9.356002212235228e-05, + fwd_comms=0, + bwd_compute=0.00018712004424470456, + bwd_comms=0.000225593945387348, + ), + ], + ("fused", "table_wise"): [ + Perf( + fwd_compute=0.000327460077428233, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.000654920154856466, + bwd_comms=6.357828776041667e-05, + ) + ], + ("fused_uvm", "table_wise"): [ + Perf( + fwd_compute=0.09179115295410156, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.18358230590820312, + bwd_comms=6.357828776041667e-05, + ) + ], + ("fused_uvm_caching", "table_wise"): [ + Perf( + fwd_compute=0.01432837509527439, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.02865675019054878, + bwd_comms=6.357828776041667e-05, + ) + ], + ("fused", "column_wise"): [ + Perf( + fwd_compute=0.000327460077428233, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.000654920154856466, + bwd_comms=6.357828776041667e-05, + ) + ], + ("fused_uvm", "column_wise"): [ + Perf( + fwd_compute=0.09179115295410156, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.18358230590820312, + bwd_comms=6.357828776041667e-05, + ) + ], + ("fused_uvm_caching", "column_wise"): [ + Perf( + fwd_compute=0.01432837509527439, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.02865675019054878, + bwd_comms=6.357828776041667e-05, + ) + ], + ("fused", "table_column_wise"): [ + Perf( + fwd_compute=0.000327460077428233, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.000654920154856466, + bwd_comms=6.357828776041667e-05, + ) + ], + ("fused_uvm", "table_column_wise"): [ + Perf( + fwd_compute=0.09179115295410156, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.18358230590820312, + bwd_comms=6.357828776041667e-05, + ) + ], + ("fused_uvm_caching", 
"table_column_wise"): [ + Perf( + fwd_compute=0.01432837509527439, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.02865675019054878, + bwd_comms=6.357828776041667e-05, + ) + ], ("fused", "row_wise"): [ - 0.00043569201211068144, - 0.00043569201211068144, + Perf( + fwd_compute=6.804365245261984e-05, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0001360873049052397, + bwd_comms=0.00016798276699240525, + ), + Perf( + fwd_compute=6.804365245261984e-05, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0001360873049052397, + bwd_comms=0.00016798276699240525, + ), ], ("fused_uvm", "row_wise"): [ - 0.054393095128676475, - 0.054393095128676475, + Perf( + fwd_compute=0.019073486328125, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.03814697265625, + bwd_comms=0.029329458872477215, + ), + Perf( + fwd_compute=0.019073486328125, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.03814697265625, + bwd_comms=0.029329458872477215, + ), ], ("fused_uvm_caching", "row_wise"): [ - 0.012695561962491483, - 0.012695561962491483, + Perf( + fwd_compute=0.0029773246951219513, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0059546493902439025, + bwd_comms=0.004631910866838161, + ), + Perf( + fwd_compute=0.0029773246951219513, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0059546493902439025, + bwd_comms=0.004631910866838161, + ), ], ("fused", "table_row_wise"): [ - 0.00043569201211068144, - 0.00043569201211068144, + Perf( + fwd_compute=6.804365245261984e-05, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0001360873049052397, + bwd_comms=0.00016798276699240525, + ), + Perf( + fwd_compute=6.804365245261984e-05, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0001360873049052397, + bwd_comms=0.00016798276699240525, + ), ], ("fused_uvm", "table_row_wise"): [ - 0.054393095128676475, - 0.054393095128676475, + Perf( + fwd_compute=0.019073486328125, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.03814697265625, + bwd_comms=0.029329458872477215, + ), + Perf( + fwd_compute=0.019073486328125, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.03814697265625, + bwd_comms=0.029329458872477215, + ), ], ("fused_uvm_caching", "table_row_wise"): [ - 0.012695561962491483, - 0.012695561962491483, + Perf( + fwd_compute=0.0029773246951219513, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0059546493902439025, + bwd_comms=0.004631910866838161, + ), + Perf( + fwd_compute=0.0029773246951219513, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0059546493902439025, + bwd_comms=0.004631910866838161, + ), + ], + # grid_shard is the same as table_row_wise + ("fused", "grid_shard"): [ + Perf( + fwd_compute=6.804365245261984e-05, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0001360873049052397, + bwd_comms=0.00016798276699240525, + ), + Perf( + fwd_compute=6.804365245261984e-05, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0001360873049052397, + bwd_comms=0.00016798276699240525, + ), + ], + ("fused_uvm", "grid_shard"): [ + Perf( + fwd_compute=0.019073486328125, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.03814697265625, + bwd_comms=0.029329458872477215, + ), + Perf( + fwd_compute=0.019073486328125, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.03814697265625, + bwd_comms=0.029329458872477215, + ), + ], + ("fused_uvm_caching", "grid_shard"): [ + Perf( + fwd_compute=0.0029773246951219513, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0059546493902439025, + bwd_comms=0.004631910866838161, + ), + Perf( + fwd_compute=0.0029773246951219513, + fwd_comms=6.357828776041667e-05, + 
bwd_compute=0.0059546493902439025, + bwd_comms=0.004631910866838161, + ), ], } @@ -103,6 +345,88 @@ def test_1_table_perf(self) -> None: self.assertEqual(expected_perfs, perfs) + def test_1_table_perf_with_fp8_comm(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_0", + feature_names=["feature_0"], + ) + ] + model = TestSparseNN(tables=tables, weighted_tables=[]) + + # will get warning for POOLED_EMBEDDINGS_REDUCE_SCATTER not supporting fp8 + qcomm_codecs_registry = get_qcomm_codecs_registry( + qcomms_config=QCommsConfig( + forward_precision=CommType.FP8, backward_precision=CommType.FP8 + ) + ) + + """ + GRID_SHARD only is available if specified by user in parameter constraints, however, + adding parameter constraints does not work because of the non deterministic nature of + _filter_sharding_types (set & set) operation when constraints are present, we mock the + call to _filter_sharding_types to ensure the order of the sharding types list is always + the same. + """ + self.enumerator._filter_sharding_types = MagicMock( + return_value=self._sharding_types + ) + + sharding_options = self.enumerator.enumerate( + module=model, + sharders=[ + cast( + ModuleSharder[torch.nn.Module], + EmbeddingBagCollectionSharder( + qcomm_codecs_registry=qcomm_codecs_registry + ), + ) + ], + ) + + expected_total_perfs = { + ("dense", "data_parallel"): [0.0005062740117544049, 0.0005062740117544049], + ("fused", "table_wise"): [0.000846718200207288], + ("fused_uvm", "table_wise"): [0.22846659024556476], + ("fused_uvm_caching", "table_wise"): [0.03568990443780169], + ("fused", "column_wise"): [0.000846718200207288], + ("fused_uvm", "column_wise"): [0.22846659024556476], + ("fused_uvm_caching", "column_wise"): [0.03568990443780169], + ("fused", "table_column_wise"): [0.000846718200207288], + ("fused_uvm", "table_column_wise"): [0.22846659024556476], + ("fused_uvm_caching", "table_column_wise"): [0.03568990443780169], + ("fused", "row_wise"): [0.0002561205605599394, 0.0002561205605599394], + ("fused_uvm", "row_wise"): [0.05403558413187663, 0.05403558413187663], + ("fused_uvm_caching", "row_wise"): [ + 0.008488476760988312, + 0.008488476760988312, + ], + ("fused", "table_row_wise"): [0.0002561205605599394, 0.0002561205605599394], + ("fused_uvm", "table_row_wise"): [0.05403558413187663, 0.05403558413187663], + ("fused_uvm_caching", "table_row_wise"): [ + 0.008488476760988312, + 0.008488476760988312, + ], + ("fused", "grid_shard"): [0.0002561205605599394, 0.0002561205605599394], + ("fused_uvm", "grid_shard"): [0.05403558413187663, 0.05403558413187663], + ("fused_uvm_caching", "grid_shard"): [ + 0.008488476760988312, + 0.008488476760988312, + ], + } + + total_perfs = { + ( + sharding_option.compute_kernel, + sharding_option.sharding_type, + ): [cast(Perf, shard.perf).total for shard in sharding_option.shards] + for sharding_option in sharding_options + } + + self.assertEqual(expected_total_perfs, total_perfs) + def test_sequence_2_table_perf(self) -> None: tables = [ EmbeddingConfig( @@ -126,37 +450,31 @@ def test_sequence_2_table_perf(self) -> None: ], ) - expected_perfs = { - ("dense", "data_parallel"): [ - 0.002677347614879459, - 0.002677347614879459, - ], + expected_total_perfs = { + ("dense", "data_parallel"): [0.0026901057997143255, 0.0026901057997143255], ("fused", "table_wise"): [0.001880471390093715], - ("fused_uvm", "table_wise"): [0.25958192114736517], - ("fused_uvm_caching", "table_wise"): [0.060433813055248066], + ("fused_uvm", "table_wise"): 
[0.41346708933512366], + ("fused_uvm_caching", "table_wise"): [0.06488458897040142], ("fused", "column_wise"): [0.001880471390093715], - ("fused_uvm", "column_wise"): [0.25958192114736517], - ("fused_uvm_caching", "column_wise"): [0.060433813055248066], - ("fused", "row_wise"): [ - 0.0007915177871551004, - 0.0007915177871551004, - ], - ("fused_uvm", "row_wise"): [0.1036341050091912, 0.1036341050091912], + ("fused_uvm", "column_wise"): [0.41346708933512366], + ("fused_uvm_caching", "column_wise"): [0.06488458897040142], + ("fused", "row_wise"): [0.0007915177871551004, 0.0007915177871551004], + ("fused_uvm", "row_wise"): [0.16504605611165366, 0.16504605611165366], ("fused_uvm_caching", "row_wise"): [ - 0.024158779217047007, - 0.024158779217047007, + 0.025934979198424798, + 0.025934979198424798, ], } - perfs = { + total_perfs = { ( sharding_option.compute_kernel, sharding_option.sharding_type, - ): [shard.perf for shard in sharding_option.shards] + ): [cast(Perf, shard.perf).total for shard in sharding_option.shards] for sharding_option in sharding_options } - self.assertEqual(expected_perfs, perfs) + self.assertEqual(expected_total_perfs, total_perfs) def test_inference_1_table_perf(self) -> None: tables = [ @@ -168,7 +486,7 @@ def test_inference_1_table_perf(self) -> None: ) ] model = TestSparseNN(tables=tables, weighted_tables=[]) - quant_model = _quantize(model, inplace=True) + quant_model = quantize(model, inplace=True) inference_estimator = EmbeddingPerfEstimator( topology=self.topology, is_inference=True @@ -185,21 +503,233 @@ def test_inference_1_table_perf(self) -> None: ], ) - expected_perfs = { + expected_total_perfs = { ("quant", "table_wise"): [0.0001296231579222408], - ("quant_uvm", "table_wise"): [0.018350937787224266], - ("quant_uvm_caching", "table_wise"): [0.004269758427175579], + ("quant_uvm", "table_wise"): [0.029231707255045574], + ("quant_uvm_caching", "table_wise"): [0.004584459754509654], + ("quant", "row_wise"): [5.5200413052187844e-05, 5.5200413052187844e-05], + ("quant_uvm", "row_wise"): [0.008370081583658854, 0.008370081583658854], + ("quant_uvm_caching", "row_wise"): [ + 0.0013280108692200203, + 0.0013280108692200203, + ], + ("quant", "column_wise"): [0.0001296231579222408], + ("quant_uvm", "column_wise"): [0.029231707255045574], + ("quant_uvm_caching", "column_wise"): [0.004584459754509654], } - perfs = { + total_perfs = { ( sharding_option.compute_kernel, sharding_option.sharding_type, - ): [shard.perf for shard in sharding_option.shards] + ): [cast(Perf, shard.perf).total for shard in sharding_option.shards] for sharding_option in sharding_options } - self.assertEqual(perfs, expected_perfs) + self.assertEqual(total_perfs, expected_total_perfs) + + def test_prefetch_compute(self) -> None: + class MyCacheStatistics(CacheStatistics): + def __init__(self, expected_lookups: int, cacheability: float) -> None: + self._expected_lookups = expected_lookups + self._cacheability = cacheability + + @property + def expected_lookups(self) -> int: + return self._expected_lookups + + def expected_miss_rate(self, clf: float) -> float: + return clf + + @property + def cacheability(self) -> float: + return self._cacheability + + tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_0", + feature_names=["feature_0"], + ), + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_1", + feature_names=["feature_1"], + ), + ] + constraints = { + "table_0": ParameterConstraints( + 
compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], + cache_params=CacheParams( + load_factor=0.1, + stats=MyCacheStatistics(expected_lookups=200_000, cacheability=0.2), + ), + ), + # simulate promoting a uvm caching table to HBM during scaleup. + "table_1": ParameterConstraints( + compute_kernels=[EmbeddingComputeKernel.FUSED.value], + cache_params=CacheParams( + load_factor=None, + stats=MyCacheStatistics(expected_lookups=200_000, cacheability=0.2), + ), + ), + } + enumerator = EmbeddingEnumerator( + topology=self.topology, + batch_size=BATCH_SIZE, + estimator=self.estimator, + constraints=constraints, + ) + """ + GRID_SHARD only is available if specified by user in parameter constraints, however, + adding parameter constraints does not work because of the non deterministic nature of + _filter_sharding_types (set & set) operation when constraints are present, we mock the + call to _filter_sharding_types to ensure the order of the sharding types list is always + the same. + """ + enumerator._filter_sharding_types = MagicMock(return_value=self._sharding_types) + model = TestSparseNN(tables=tables, weighted_tables=[]) + sharding_options = enumerator.enumerate( + module=model, + sharders=[ + cast( + ModuleSharder[torch.nn.Module], + EmbeddingBagCollectionSharder( + fused_params={"cache_load_factor": 0.2} + ), + ) + ], + ) + + expected_prefetch_computes = { + ("table_0", "fused_uvm_caching", "column_wise"): [0.023283064365386963], + ("table_0", "fused_uvm_caching", "row_wise"): [ + 0.011641532182693481, + 0.011641532182693481, + ], + ("table_0", "fused_uvm_caching", "table_column_wise"): [ + 0.023283064365386963 + ], + ("table_0", "fused_uvm_caching", "table_row_wise"): [ + 0.011641532182693481, + 0.011641532182693481, + ], + ("table_0", "fused_uvm_caching", "grid_shard"): [ + 0.011641532182693481, + 0.011641532182693481, + ], + ("table_0", "fused_uvm_caching", "table_wise"): [0.023283064365386963], + ("table_1", "fused", "column_wise"): [0.0], + ("table_1", "fused", "row_wise"): [0.0, 0.0], + ("table_1", "fused", "table_column_wise"): [0.0], + ("table_1", "fused", "table_row_wise"): [0.0, 0.0], + ("table_1", "fused", "table_wise"): [0.0], + ("table_1", "fused", "grid_shard"): [0.0, 0.0], + } + + prefetch_computes = { + ( + sharding_option.name, + sharding_option.compute_kernel, + sharding_option.sharding_type, + ): [ + shard.perf.prefetch_compute if shard.perf else -1 + for shard in sharding_option.shards + ] + for sharding_option in sharding_options + } + self.assertEqual(expected_prefetch_computes, prefetch_computes) + + def test_weighted_feature_bwd_compute_multiplier(self) -> None: + def _get_bwd_computes( + model: torch.nn.Module, + weighted_feature_bwd_compute_multiplier: float, + ) -> Dict[Tuple[str, str, str], List[float]]: + topology = Topology( + world_size=2, + compute_device="cuda", + weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier, + ) + estimator = EmbeddingPerfEstimator(topology=topology) + enumerator = EmbeddingEnumerator( + topology=topology, batch_size=BATCH_SIZE, estimator=estimator + ) + sharding_options = enumerator.enumerate( + module=model, + sharders=[ + cast( + ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder() + ) + ], + ) + bwd_computes = { + ( + sharding_option.name, + sharding_option.compute_kernel, + sharding_option.sharding_type, + ): [ + shard.perf.bwd_compute if shard.perf else -1 + for shard in sharding_option.shards + ] + for sharding_option in sharding_options + } + return bwd_computes + + tables = 
[ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_0", + feature_names=["feature_0"], + ) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="weighted_table_0", + feature_names=["weighted_feature_0"], + ) + ] + model = TestSparseNN(tables=tables, weighted_tables=weighted_tables) + + MULTIPLIER = 7 + bwd_computes_1 = _get_bwd_computes( + model, weighted_feature_bwd_compute_multiplier=1 + ) + bwd_computes_2 = _get_bwd_computes( + model, + weighted_feature_bwd_compute_multiplier=2, + ) + bwd_computes_n = _get_bwd_computes( + model, + weighted_feature_bwd_compute_multiplier=MULTIPLIER, + ) + self.assertEqual(bwd_computes_1.keys(), bwd_computes_2.keys()) + self.assertEqual(bwd_computes_1.keys(), bwd_computes_n.keys()) + for key in bwd_computes_1.keys(): + table_name, _, sharding_type = key + if table_name.startswith("weighted"): + self.assertEqual(len(bwd_computes_1), len(bwd_computes_2)) + self.assertEqual(len(bwd_computes_1), len(bwd_computes_n)) + for bwd_compute_1, bwd_compute_2, bwd_compute_n in zip( + bwd_computes_1[key], bwd_computes_2[key], bwd_computes_n[key] + ): + # bwd_compute_1 = base_bwd_compute + offset + # bwd_compute_2 = base_bwd_compute * 2 + offset + # bwd_compute_n = base_bwd_compute * MULTIPLIER + offset + # (where offset = bwd_grad_indice_weights_kernel in production + # https://fburl.com/code/u9hq6vhf) + base_bwd_compute = bwd_compute_2 - bwd_compute_1 + offset = bwd_compute_1 - base_bwd_compute + self.assertAlmostEqual( + base_bwd_compute * MULTIPLIER, + bwd_compute_n - offset, + ) + else: + self.assertEqual(bwd_computes_1[key], bwd_computes_2[key]) # pyre-ignore[3] @@ -209,35 +739,627 @@ def calculate_storage_specific_size_data_provider(): "sharding_type": ShardingType.TABLE_ROW_WISE, "optimizer_class": torch.optim.SGD, "expected_storage": [50, 50], + "clf": None, }, { "sharding_type": ShardingType.COLUMN_WISE, "optimizer_class": torch.optim.Adam, - "expected_storage": [150, 150], + "expected_storage": [ + 150 + math.ceil(5 * (4 + 0.5 * 16)), + 150 + math.ceil(5 * (4 + 0.5 * 16)), + ], + "clf": 0.5, }, { "sharding_type": ShardingType.TABLE_ROW_WISE, "optimizer_class": None, - "expected_storage": [50, 50], + "expected_storage": [ + 50 + math.ceil(5 * (4 + 0.0 * 16)), + 50 + math.ceil(5 * (4 + 0.0 * 16)), + ], + "clf": 0.0, }, { "sharding_type": ShardingType.DATA_PARALLEL, "optimizer_class": trec_optim.RowWiseAdagrad, - "expected_storage": [134, 134], + "expected_storage": [ + 134 + math.ceil(5 * (4 + 1.0 * 16)), + 134 + math.ceil(5 * (4 + 1.0 * 16)), + ], + "clf": 1.0, }, ) +class TestEmbeddingPerfEstimatorWithGeneralizedComms(unittest.TestCase): + def setUp(self) -> None: + # Testing with non-default invokes BasicCommsBandwidths. 
+ self.topology = Topology( + world_size=2, + compute_device="cuda", + inter_host_bw=40.0 * 1024**3 / 1000, + intra_host_bw=300.0 * 1024**3 / 1000, + ) + self.estimator = EmbeddingPerfEstimator(topology=self.topology) + self.enumerator = EmbeddingEnumerator( + topology=self.topology, batch_size=BATCH_SIZE, estimator=self.estimator + ) + self._sharding_types = [x.value for x in ShardingType] + + self.topology2 = Topology( + world_size=2, + compute_device="cuda", + generalized_comms_bandwidths=BasicCommsBandwidths( + inter_host_bw=40.0 * 1024**3 / 1000, + intra_host_bw=300.0 * 1024**3 / 1000, + ), + ) + self.estimator2 = EmbeddingPerfEstimator(topology=self.topology2) + self.enumerator2 = EmbeddingEnumerator( + topology=self.topology2, batch_size=BATCH_SIZE, estimator=self.estimator2 + ) + self._sharding_types2 = [x.value for x in ShardingType] + + def test_1_table_perf(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_0", + feature_names=["feature_0"], + ) + ] + model = TestSparseNN(tables=tables, weighted_tables=[]) + """ + GRID_SHARD only is available if specified by user in parameter constraints, however, + adding parameter constraints does not work because of the non deterministic nature of + _filter_sharding_types (set & set) operation when constraints are present, we mock the + call to _filter_sharding_types to ensure the order of the sharding types list is always + the same. + """ + self.enumerator._filter_sharding_types = MagicMock( + return_value=self._sharding_types + ) + sharding_options = self.enumerator.enumerate( + module=model, + sharders=[ + cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()) + ], + ) + self.enumerator2._filter_sharding_types = MagicMock( + return_value=self._sharding_types2 + ) + sharding_options2 = self.enumerator2.enumerate( + module=model, + sharders=[ + cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()) + ], + ) + + expected_perfs = { + ("dense", "data_parallel"): [ + Perf( + fwd_compute=9.356002212235228e-05, + fwd_comms=0, + bwd_compute=0.00018712004424470456, + bwd_comms=0.00012314846217964537, + ), + Perf( + fwd_compute=9.356002212235228e-05, + fwd_comms=0, + bwd_compute=0.00018712004424470456, + bwd_comms=0.00012314846217964537, + ), + ], + ("fused", "table_wise"): [ + Perf( + fwd_compute=0.000327460077428233, + fwd_comms=6.357828776041667e-05 + * 2, # bw is set to half in this test + bwd_compute=0.000654920154856466, + bwd_comms=6.357828776041667e-05 + * 2, # bw is set to half in this test + ) + ], + ("fused_uvm", "table_wise"): [ + Perf( + fwd_compute=0.09179115295410156, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.18358230590820312, + bwd_comms=6.357828776041667e-05 * 2, + ) + ], + ("fused_uvm_caching", "table_wise"): [ + Perf( + fwd_compute=0.01432837509527439, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.02865675019054878, + bwd_comms=6.357828776041667e-05 * 2, + ) + ], + ("fused", "column_wise"): [ + Perf( + fwd_compute=0.000327460077428233, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.000654920154856466, + bwd_comms=6.357828776041667e-05 * 2, + ) + ], + ("fused_uvm", "column_wise"): [ + Perf( + fwd_compute=0.09179115295410156, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.18358230590820312, + bwd_comms=6.357828776041667e-05 * 2, + ) + ], + ("fused_uvm_caching", "column_wise"): [ + Perf( + fwd_compute=0.01432837509527439, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.02865675019054878, + 
bwd_comms=6.357828776041667e-05 * 2, + ) + ], + ("fused", "table_column_wise"): [ + Perf( + fwd_compute=0.000327460077428233, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.000654920154856466, + bwd_comms=6.357828776041667e-05 * 2, + ) + ], + ("fused_uvm", "table_column_wise"): [ + Perf( + fwd_compute=0.09179115295410156, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.18358230590820312, + bwd_comms=6.357828776041667e-05 * 2, + ) + ], + ("fused_uvm_caching", "table_column_wise"): [ + Perf( + fwd_compute=0.01432837509527439, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.02865675019054878, + bwd_comms=6.357828776041667e-05 * 2, + ) + ], + ("fused", "row_wise"): [ + Perf( + fwd_compute=6.804365245261984e-05, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.0001360873049052397, + bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + ), + Perf( + fwd_compute=6.804365245261984e-05, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.0001360873049052397, + bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + ), + ], + ("fused_uvm", "row_wise"): [ + Perf( + fwd_compute=0.019073486328125, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.03814697265625, + bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + ), + Perf( + fwd_compute=0.019073486328125, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.03814697265625, + bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + ), + ], + ("fused_uvm_caching", "row_wise"): [ + Perf( + fwd_compute=0.0029773246951219513, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.0059546493902439025, + bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + ), + Perf( + fwd_compute=0.0029773246951219513, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.0059546493902439025, + bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + ), + ], + ("fused", "table_row_wise"): [ + Perf( + fwd_compute=6.804365245261984e-05, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.0001360873049052397, + bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + ), + Perf( + fwd_compute=6.804365245261984e-05, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.0001360873049052397, + bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + ), + ], + ("fused_uvm", "table_row_wise"): [ + Perf( + fwd_compute=0.019073486328125, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.03814697265625, + bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + ), + Perf( + fwd_compute=0.019073486328125, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.03814697265625, + bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + ), + ], + ("fused_uvm_caching", "table_row_wise"): [ + Perf( + fwd_compute=0.0029773246951219513, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.0059546493902439025, + bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + ), + Perf( + fwd_compute=0.0029773246951219513, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.0059546493902439025, + bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + ), + ], + # grid_shard is the same as table_row_wise + ("fused", "grid_shard"): [ + Perf( + fwd_compute=6.804365245261984e-05, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.0001360873049052397, + bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + ), + Perf( + 
fwd_compute=6.804365245261984e-05, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.0001360873049052397, + bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + ), + ], + ("fused_uvm", "grid_shard"): [ + Perf( + fwd_compute=0.019073486328125, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.03814697265625, + bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + ), + Perf( + fwd_compute=0.019073486328125, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.03814697265625, + bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + ), + ], + ("fused_uvm_caching", "grid_shard"): [ + Perf( + fwd_compute=0.0029773246951219513, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.0059546493902439025, + bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05, + ), + Perf( + fwd_compute=0.0029773246951219513, + fwd_comms=6.357828776041667e-05 * 2, + bwd_compute=0.0059546493902439025, + bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + ), + ], + } + + perfs = { + ( + sharding_option.compute_kernel, + sharding_option.sharding_type, + ): [shard.perf for shard in sharding_option.shards] + for sharding_option in sharding_options + } + + perfs2 = { + ( + sharding_option.compute_kernel, + sharding_option.sharding_type, + ): [shard.perf for shard in sharding_option.shards] + for sharding_option in sharding_options2 + } + self.assertEqual(expected_perfs, perfs) + self.assertEqual(expected_perfs, perfs2) + + class TestEmbeddingStorageEstimator(unittest.TestCase): def test_calculate_storage_specific_sizes(self) -> None: for inputs in calculate_storage_specific_size_data_provider(): - sharding_type, optimizer_class, expected_storage = inputs.values() + sharding_type, optimizer_class, expected_storage, clf = inputs.values() estimates = _calculate_storage_specific_sizes( storage=100, shape=torch.Size((10, 5, 3)), shard_sizes=[[5, 5, 3], [5, 5, 3]], sharding_type=sharding_type.value, optimizer_class=optimizer_class, + clf=clf, ) self.assertEqual(estimates, expected_storage) + + @patch( + "torchrec.distributed.planner.shard_estimators._calculate_shard_io_sizes", + return_value=([1024], [3333]), + ) + @patch( + "torchrec.distributed.planner.shard_estimators._calculate_storage_specific_sizes", + return_value=[100], + ) + def test_pipelined_storage(self, p1: Mock, p2: Mock) -> None: + for pipeline_type in list(PipelineType): + for run_embedding_at_peak_memory in [False, True]: + topology = Topology(world_size=2, compute_device="cuda") + estimator = EmbeddingStorageEstimator( + topology=topology, + pipeline_type=pipeline_type, + run_embedding_at_peak_memory=run_embedding_at_peak_memory, + ) + tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_0", + feature_names=["feature_0"], + ), + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_1", + feature_names=["feature_1"], + ), + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_2", + feature_names=["feature_2"], + ), + ] + constraints = { + "table_0": ParameterConstraints( + compute_kernels=[ + EmbeddingComputeKernel.FUSED_UVM_CACHING.value + ], + sharding_types=[ShardingType.TABLE_WISE.value], + cache_params=CacheParams( + load_factor=0.1, + ), + ), + # simulate promoting a uvm caching table to HBM during scaleup. 
+ "table_1": ParameterConstraints( + compute_kernels=[EmbeddingComputeKernel.FUSED.value], + sharding_types=[ShardingType.TABLE_WISE.value], + cache_params=CacheParams( + load_factor=None, + ), + ), + "table_2": ParameterConstraints( + compute_kernels=[ + EmbeddingComputeKernel.FUSED_UVM_CACHING.value + ], + sharding_types=[ShardingType.TABLE_WISE.value], + cache_params=CacheParams( + load_factor=0.1, + multipass_prefetch_config=MultiPassPrefetchConfig( + num_passes=10, + ), + ), + ), + } + enumerator = EmbeddingEnumerator( + topology=topology, + batch_size=BATCH_SIZE, + estimator=estimator, + constraints=constraints, + ) + + model = TestSparseNN(tables=tables, weighted_tables=[]) + sharding_options = enumerator.enumerate( + module=model, + sharders=[ + cast( + ModuleSharder[torch.nn.Module], + EmbeddingBagCollectionSharder( + fused_params={ + "cache_load_factor": 0.2, + } + ), + ) + ], + ) + + output_on_pipeline = 3333 if run_embedding_at_peak_memory else 0 + if pipeline_type == PipelineType.TRAIN_SPARSE_DIST: + expected_storage = { + ("table_0", "fused_uvm_caching", "table_wise"): [ + (100 + 2048 + output_on_pipeline, 100) + ], + ("table_1", "fused", "table_wise"): [ + (100 + 2048 + output_on_pipeline, 100) + ], + ("table_2", "fused_uvm_caching", "table_wise"): [ + (100 + 2048 + output_on_pipeline, 100) + ], + } + elif pipeline_type == PipelineType.TRAIN_PREFETCH_SPARSE_DIST: + expected_storage = { + ("table_0", "fused_uvm_caching", "table_wise"): [ + (100 + 1024 * 10 + output_on_pipeline, 100) + ], + ("table_1", "fused", "table_wise"): [ + (100 + 3072 + output_on_pipeline, 100) + ], + ("table_2", "fused_uvm_caching", "table_wise"): [ + (100 + 1024 * 3 + int(1024 * 1.6) + output_on_pipeline, 100) + ], + } + else: + # Backward compatible path, using old formula when pipeline + # type is None or unrecognized. 
+ expected_storage = { + ("table_0", "fused_uvm_caching", "table_wise"): [ + (100 + 3333 + 1024, 100) + ], + ("table_1", "fused", "table_wise"): [(100 + 3333 + 1024, 100)], + ("table_2", "fused_uvm_caching", "table_wise"): [ + (100 + 3333 + 1024, 100) + ], + } + actual_storage = { + ( + sharding_option.name, + sharding_option.compute_kernel, + sharding_option.sharding_type, + ): [ + (shard.storage.hbm, shard.storage.ddr) + for shard in sharding_option.shards + if shard.storage is not None + ] + for sharding_option in sharding_options + } + self.assertEqual(expected_storage, actual_storage) + + def test_default_output_sizes(self) -> None: + topology = Topology(world_size=2, compute_device="cuda") + constraint_list = [ + None, + {"table_0": ParameterConstraints(output_dtype=DataType.FP32)}, + ] + + table_list = [ + [ + EmbeddingBagConfig( + num_embeddings=50, + embedding_dim=10, + name="table_0", + feature_names=["feature_0"], + data_type=DataType.FP32, + ) + ], + [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_0", + feature_names=["feature_0"], + data_type=DataType.FP16, + ) + ], + ] + hbms = [] + + for tables, constraints in zip(table_list, constraint_list): + enumerator = EmbeddingEnumerator( + topology=topology, batch_size=BATCH_SIZE, constraints=constraints + ) + model = TestSparseNN(tables=tables, weighted_tables=[]) + sharding_options = enumerator.enumerate( + module=model, + sharders=[ + cast( + ModuleSharder[torch.nn.Module], + TestEBCSharder( + sharding_type=ShardingType.TABLE_WISE.value, + kernel_type=EmbeddingComputeKernel.FUSED.value, + ), + ) + ], + ) + self.assertEqual(len(sharding_options), 1) + self.assertEqual(len(sharding_options[0].shards), 1) + self.assertIsNotNone(sharding_options[0].shards[0].storage) + hbms.append(sharding_options[0].shards[0].storage.hbm) # pyre-ignore + + self.assertEqual(hbms[0], hbms[1]) + + +class TestEmbeddingOffloadStats(unittest.TestCase): + def test_basic(self) -> None: + stats = EmbeddingOffloadStats( + cacheability=0.42, + expected_lookups=31, + mrc_hist_counts=torch.tensor([99, 98, 97]), + height=92, + ) + self.assertEqual(stats.cacheability, 0.42) + self.assertEqual(stats.expected_lookups, 31) + self.assertEqual(stats.expected_miss_rate(0), 1.0) + self.assertEqual(stats.expected_miss_rate(1), 0.0) + self.assertAlmostEqual( + stats.expected_miss_rate(0.5), 1 - (99 + 98) / (99 + 98 + 97) + ) + + def test_estimate_cache_miss_rate(self) -> None: + hist = torch.tensor([0, 6, 0, 8]) + bins = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0]) + miss_rates = EmbeddingOffloadStats.estimate_cache_miss_rate( + torch.tensor([0, 1, 2, 3, 4]), hist, bins + ) + m = 1 - (6 / (6 + 8)) # from hist counts above + want = [ + 1, # size 0 - 100% miss + 1, # size 1 - 100%, no immediate repetitions + m, # size 2 - m (~57%) miss, 6 occurrences + m, # size 3 - same as size 2, no 3 stack distances, + # so increasing cache by 1 doesn't help + 0, # size 4 - 0% miss rate, everything fits + ] + torch.testing.assert_close(miss_rates, torch.tensor(want)) + + # test with bigger bins to better validate boundary conditions + # create simple linear miss rate curve + trace = torch.arange(100.0) + hist = torch.histc(trace, bins=10, min=0, max=100) + bins = torch.linspace(0, 100, len(hist) + 1) + cache_heights = [0, 9, 10, 11, 89, 99, 100] + miss_rates = EmbeddingOffloadStats.estimate_cache_miss_rate( + torch.tensor(cache_heights), hist, bins + ) + want = [ + 1, # 0 -> no cache, 100% miss + 0.9, # 9 -> bin 0, which is all cache sizes <= 10, has 90 misses of 
100, so 90% miss + 0.9, # 10 -> bin 0, same as above + 0.8, # 11 -> bin 1, cache sizes (10, 20], 80 misses out of 100, so 80% miss + 0.1, # 89 -> bin 8, cache sizes (80, 90], 10 misses out of 100, so 10% miss + 0, # 99 -> bin 9, cache sizes (90, 100], final last bin gets scaled to 1, so 0% misses + 0, # 100 -> off the end of the histogram, 0% misses + ] + torch.testing.assert_close(miss_rates, torch.tensor(want)) + # test using 0-d tensors works as well + miss_rates = torch.tensor( + [ + EmbeddingOffloadStats.estimate_cache_miss_rate( + torch.tensor(x), hist, bins + ) + for x in cache_heights + ] + ) + torch.testing.assert_close(miss_rates, torch.tensor(want)) + + # test features no with no data return non-nan + hist = torch.tensor([0, 0]) + bins = torch.tensor([0, 1, 2]) + miss_rates = EmbeddingOffloadStats.estimate_cache_miss_rate( + torch.tensor([0, 1, 2]), hist, bins + ) + torch.testing.assert_close(miss_rates, torch.tensor([0.0, 0.0, 0.0])) + # test 0-d case + miss_rates = torch.tensor( + [ + EmbeddingOffloadStats.estimate_cache_miss_rate( + torch.tensor(x), hist, bins + ) + for x in [0, 1, 2] + ] + ) + torch.testing.assert_close(miss_rates, torch.tensor([0.0, 0.0, 0.0])) diff --git a/torchrec/distributed/planner/tests/test_stats.py b/torchrec/distributed/planner/tests/test_stats.py new file mode 100644 index 000000000..10849b517 --- /dev/null +++ b/torchrec/distributed/planner/tests/test_stats.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import math +import unittest +from typing import List + +import hypothesis.strategies as st + +import torch +from hypothesis import given, settings +from torch import nn +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.planner.planners import EmbeddingShardingPlanner +from torchrec.distributed.planner.stats import ( + _calc_max_chi_sq_divergence, + _calc_max_kl_divergence, + _chi_sq_divergence, + _kl_divergence, + _normalize_float, + _normalize_int, + _total_distance, + _total_variation, + EmbeddingStats, + NoopEmbeddingStats, +) +from torchrec.distributed.planner.types import Topology +from torchrec.distributed.test_utils.test_model import TestSparseNN +from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig + + +class TWvsRWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.ROW_WISE.value, ShardingType.TABLE_WISE.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.FUSED.value] + + +class TestEmbeddingStats(unittest.TestCase): + def setUp(self) -> None: + compute_device = "cuda" + self.topology = Topology( + world_size=2, hbm_cap=1024 * 1024 * 2, compute_device=compute_device + ) + tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=64, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(4) + ] + self.model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) + + def test_embedding_stats_runs(self) -> None: + planner = EmbeddingShardingPlanner(topology=self.topology) + 
_ = planner.plan(module=self.model, sharders=[TWvsRWSharder()]) + self.assertEqual(len(planner._stats), 1) + stats_data = planner._stats[0] + assert isinstance(stats_data, EmbeddingStats) + stats: List[str] = stats_data._stats_table + self.assertTrue(isinstance(stats, list)) + self.assertTrue(stats[0].startswith("####")) + + def test_empty_embedding_stats_runs(self) -> None: + planner = EmbeddingShardingPlanner(topology=self.topology, stats=[]) + _ = planner.plan(module=self.model, sharders=[TWvsRWSharder()]) + self.assertEqual(len(planner._stats), 0) + + def test_noop_embedding_stats_runs(self) -> None: + planner = EmbeddingShardingPlanner( + topology=self.topology, stats=NoopEmbeddingStats() + ) + _ = planner.plan(module=self.model, sharders=[TWvsRWSharder()]) + self.assertEqual(len(planner._stats), 1) + + def test_embedding_stats_output_with_top_hbm_usage(self) -> None: + planner = EmbeddingShardingPlanner(topology=self.topology) + _ = planner.plan(module=self.model, sharders=[TWvsRWSharder()]) + self.assertEqual(len(planner._stats), 1) + stats_data = planner._stats[0] + assert isinstance(stats_data, EmbeddingStats) + stats: List[str] = stats_data._stats_table + self.assertTrue(isinstance(stats, list)) + top_hbm_usage_keyword = "Top HBM Memory Usage Estimation:" + self.assertTrue(any(top_hbm_usage_keyword in row for row in stats)) + top_hbm_mem_usage = None + for row in stats: + if top_hbm_usage_keyword in row: + top_hbm_mem_usage = float(row.split(" ")[6]) + self.assertIsNotNone(top_hbm_mem_usage) + + def test_normalize_float(self) -> None: + p = [2.0, 2.0] + self.assertEqual(_normalize_float(p), [0.5, 0.5]) + + def test_normalize_int(self) -> None: + p = [2, 2] + self.assertEqual(_normalize_int(p), [0.5, 0.5]) + + def test_total_variation(self) -> None: + p_1 = [0.5, 0.5] + self.assertEqual(_total_variation(p_1), 0.0) + + p_2 = [0.0, 1.0] + self.assertEqual(_total_variation(p_2), 0.5) + + def test_total_distance(self) -> None: + p_1 = [0.5, 0.5] + self.assertEqual(_total_distance(p_1), 0.0) + + p_2 = [0.0, 1.0] + self.assertEqual(_total_distance(p_2), 1.0) + + def test_chi_divergence(self) -> None: + p_1 = [0.5, 0.5] + self.assertEqual(_chi_sq_divergence(p_1), 0.0) + + p_2 = [0.0, 1.0] + self.assertEqual(_chi_sq_divergence(p_2), 1.0) + + def test_kl_divergence(self) -> None: + p_1 = [0.5, 0.5] + self.assertEqual(_kl_divergence(p_1), 0.0) + + p_2 = [0.1, 0.9] + self.assertAlmostEqual(_kl_divergence(p_2), 0.368, 3) + + # pyre-ignore + @given( + N=st.integers(min_value=10, max_value=200), + ) + @settings(max_examples=4, deadline=None) + def test_kl_divergence_upper_bound(self, N: int) -> None: + # Generate most imbalanced distribution + normalized_p = [ + 1.0, + ] + [ + 0.0 + ] * (N - 1) + N = len(normalized_p) + self.assertEqual(_kl_divergence(normalized_p), _calc_max_kl_divergence(N)) + + # pyre-ignore + @given( + N=st.integers(min_value=10, max_value=200), + alpha=st.floats(min_value=1.0, max_value=5.0), + ) + @settings(max_examples=4, deadline=None) + def test_chi_divergence_upper_bound(self, N: int, alpha: float) -> None: + # Generate most imbalanced distribution + normalized_p = [ + 1.0, + ] + [ + 0.0 + ] * (N - 1) + N = len(normalized_p) + + self.assertTrue( + math.isclose( + _chi_sq_divergence(normalized_p), + _calc_max_chi_sq_divergence(N), + abs_tol=1e-10, + ) + ) diff --git a/torchrec/distributed/planner/tests/test_storage_reservations.py b/torchrec/distributed/planner/tests/test_storage_reservations.py index f5d3dc1df..4bc2b40f3 100644 --- 
a/torchrec/distributed/planner/tests/test_storage_reservations.py +++ b/torchrec/distributed/planner/tests/test_storage_reservations.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest from typing import cast, List @@ -144,3 +146,36 @@ def test_storage_reservations_tower_nested_sharders(self) -> None: # pyre-ignore heuristical_storage_reservation._dense_storage.hbm, ) + + def test_storage_reservations_with_dense_estimation(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_0", + feature_names=["feature_0"], + ) + ] + + ebc = EmbeddingBagCollection(tables) + model = TestModel(shardable_sparse=ebc) + + dense_tensor_estimate = 1000000 + heuristical_storage_reservation = HeuristicalStorageReservation( + percentage=0.0, dense_tensor_estimate=dense_tensor_estimate + ) + + heuristical_storage_reservation.reserve( + topology=Topology(world_size=2, compute_device="cuda"), + batch_size=10, + module=model, + sharders=cast( + List[ModuleSharder[nn.Module]], [EmbeddingBagCollectionSharder()] + ), + ) + + self.assertEqual( + dense_tensor_estimate, + # pyre-ignore + heuristical_storage_reservation._dense_storage.hbm, + ) diff --git a/torchrec/distributed/planner/tests/test_types.py b/torchrec/distributed/planner/tests/test_types.py new file mode 100644 index 000000000..bb8280f00 --- /dev/null +++ b/torchrec/distributed/planner/tests/test_types.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import unittest +from typing import cast +from unittest.mock import MagicMock + +import torch +from torchrec.distributed.embedding_types import EmbeddingComputeKernel + +from torchrec.distributed.planner.types import Shard, ShardingOption +from torchrec.distributed.types import ( + BoundsCheckMode, + CacheAlgorithm, + CacheParams, + DataType, + ShardingType, +) +from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) +from torchrec.modules.mc_embedding_modules import ( + ManagedCollisionCollection, + ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, +) +from torchrec.modules.mc_modules import ( + DistanceLFU_EvictionPolicy, + ManagedCollisionModule, + MCHManagedCollisionModule, +) + + +class TestShardingOption(unittest.TestCase): + def test_hash_sharding_option(self) -> None: + shard_size = [10000, 80] + shard_offsets = [[0, 0], [0, 80]] + sharding_option: ShardingOption = ShardingOption( + name="table_0", + tensor=torch.empty( + (10000, 160), dtype=torch.float16, device=torch.device("meta") + ), + module=("ebc", MagicMock()), + input_lengths=MagicMock(), + batch_size=MagicMock(), + sharding_type=ShardingType.COLUMN_WISE.value, + partition_by=MagicMock(), + compute_kernel=EmbeddingComputeKernel.FUSED.value, + shards=[Shard(size=shard_size, offset=offset) for offset in shard_offsets], + cache_params=CacheParams( + algorithm=CacheAlgorithm.LRU, + load_factor=0.5, + reserved_memory=0.0, + precision=DataType.FP16, + prefetch_pipeline=True, + ), + enforce_hbm=True, + stochastic_rounding=False, + bounds_check_mode=BoundsCheckMode.WARNING, + ) + self.assertTrue(map(hash, [sharding_option])) + + def test_module_pooled_ebc(self) -> None: + eb_config = EmbeddingBagConfig( + name="table_0", + embedding_dim=160, + num_embeddings=10000, + feature_names=["f1"], + data_type=DataType.FP16, + ) + ebc = EmbeddingBagCollection(tables=[eb_config]) + + sharding_option: ShardingOption = ShardingOption( + name="table_0", + tensor=torch.empty( + (10000, 160), dtype=torch.float16, device=torch.device("meta") + ), + module=("ebc", ebc), + input_lengths=MagicMock(), + batch_size=MagicMock(), + sharding_type=ShardingType.COLUMN_WISE.value, + partition_by=MagicMock(), + compute_kernel=EmbeddingComputeKernel.FUSED.value, + shards=[ + Shard(size=[10000, 80], offset=offset) for offset in [[0, 0], [0, 80]] + ], + ) + self.assertEqual(sharding_option.is_pooled, True) + + def test_module_pooled_mch_ebc(self) -> None: + eb_config = EmbeddingBagConfig( + name="table_0", + embedding_dim=160, + num_embeddings=10000, + feature_names=["f1"], + data_type=DataType.FP16, + ) + ebc = EmbeddingBagCollection(tables=[eb_config]) + mc_modules = { + "table_0": cast( + ManagedCollisionModule, + MCHManagedCollisionModule( + zch_size=10000, + device=torch.device("meta"), + eviction_interval=1, + eviction_policy=DistanceLFU_EvictionPolicy(), + ), + ), + } + mcc = ManagedCollisionCollection( + managed_collision_modules=mc_modules, + embedding_configs=[eb_config], + ) + mch_ebc = ManagedCollisionEmbeddingBagCollection(ebc, mcc) + + sharding_option: ShardingOption = ShardingOption( + name="table_0", + tensor=torch.empty( + (10000, 80), dtype=torch.float16, device=torch.device("meta") + ), + module=("mch_ebc", mch_ebc), + input_lengths=MagicMock(), + batch_size=MagicMock(), + sharding_type=ShardingType.COLUMN_WISE.value, + partition_by=MagicMock(), + 
compute_kernel=EmbeddingComputeKernel.FUSED.value, + shards=[ + Shard(size=[10000, 80], offset=offset) for offset in [[0, 0], [0, 80]] + ], + ) + self.assertEqual(sharding_option.is_pooled, True) + + def test_module_pooled_ec(self) -> None: + e_config = EmbeddingConfig( + name="table_0", + embedding_dim=80, + num_embeddings=10000, + feature_names=["f1"], + data_type=DataType.FP16, + ) + ec = EmbeddingCollection(tables=[e_config]) + + shard_size = [10000, 80] + shard_offsets = [[0, 0], [0, 80]] + sharding_option: ShardingOption = ShardingOption( + name="table_0", + tensor=torch.empty( + (10000, 160), dtype=torch.float16, device=torch.device("meta") + ), + module=("ec", ec), + input_lengths=MagicMock(), + batch_size=MagicMock(), + sharding_type=ShardingType.COLUMN_WISE.value, + partition_by=MagicMock(), + compute_kernel=EmbeddingComputeKernel.FUSED.value, + shards=[Shard(size=shard_size, offset=offset) for offset in shard_offsets], + ) + self.assertEqual(sharding_option.is_pooled, False) + + def test_module_pooled_mch_ec(self) -> None: + e_config = EmbeddingConfig( + name="table_0", + embedding_dim=80, + num_embeddings=10000, + feature_names=["f1"], + data_type=DataType.FP16, + ) + ec = EmbeddingCollection(tables=[e_config]) + mc_modules = { + "table_0": cast( + ManagedCollisionModule, + MCHManagedCollisionModule( + zch_size=10000, + device=torch.device("meta"), + eviction_interval=1, + eviction_policy=DistanceLFU_EvictionPolicy(), + ), + ), + } + mcc = ManagedCollisionCollection( + managed_collision_modules=mc_modules, + embedding_configs=[e_config], + ) + mch_ec = ManagedCollisionEmbeddingCollection(ec, mcc) + + shard_size = [10000, 80] + shard_offsets = [[0, 0], [0, 80]] + sharding_option: ShardingOption = ShardingOption( + name="table_0", + tensor=torch.empty( + (10000, 160), dtype=torch.float16, device=torch.device("meta") + ), + module=("mch_ec", mch_ec), + input_lengths=MagicMock(), + batch_size=MagicMock(), + sharding_type=ShardingType.COLUMN_WISE.value, + partition_by=MagicMock(), + compute_kernel=EmbeddingComputeKernel.FUSED.value, + shards=[Shard(size=shard_size, offset=offset) for offset in shard_offsets], + ) + self.assertEqual(sharding_option.is_pooled, False) diff --git a/torchrec/distributed/planner/tests/test_utils.py b/torchrec/distributed/planner/tests/test_utils.py new file mode 100644 index 000000000..803e263f2 --- /dev/null +++ b/torchrec/distributed/planner/tests/test_utils.py @@ -0,0 +1,384 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import math +import unittest +from typing import Callable, List, Optional +from unittest.mock import MagicMock + +import torch + +from torchrec.distributed.planner.types import Perf, Shard, ShardingOption, Storage +from torchrec.distributed.planner.utils import ( + _find_imbalance_tables, + BinarySearchPredicate, + LuusJaakolaSearch, + reset_shard_rank, +) +from torchrec.distributed.types import ShardingType + + +class TestFindImbalanceTables(unittest.TestCase): + def setUp(self) -> None: + self.best_plan: List[ShardingOption] = [] + for i in range(10): + shard_size = [100 * i, 8] + shard_offsets = [[0, 0], [0, 8]] + self.best_plan.append( + ShardingOption( + name=f"table_{i}", + tensor=MagicMock(), + module=MagicMock(), + input_lengths=MagicMock(), + batch_size=MagicMock(), + sharding_type=ShardingType.COLUMN_WISE.value, + partition_by=MagicMock(), + compute_kernel=MagicMock(), + shards=[ + Shard(size=shard_size, offset=offset) + for offset in shard_offsets + ], + ) + ) + + def test_find_perf_imbalance_tables(self) -> None: + reset_shard_rank(self.best_plan) + for i, sharding_option in enumerate(self.best_plan): + for j, shard in enumerate(sharding_option.shards): + shard.rank = 2 * i + j + shard.perf = Perf( + fwd_compute=2 * i, + fwd_comms=2 * i, + bwd_compute=2 * i, + bwd_comms=2 * i, + ) + + expected_max_perf_table_names = ["table_9"] + max_perf_table_names = [ + sharding_option.name + for sharding_option in _find_imbalance_tables(self.best_plan) + ] + self.assertTrue(expected_max_perf_table_names, max_perf_table_names) + + def test_find_hbm_imbalance_tables(self) -> None: + reset_shard_rank(self.best_plan) + for i, sharding_option in enumerate(self.best_plan): + for j, shard in enumerate(sharding_option.shards): + shard.rank = 2 * i + j + shard.storage = Storage( + hbm=2 * (10 - i), + ddr=0, + ) + + expected_max_hbm_table_names = ["table_0"] + max_hbm_table_names = [ + sharding_option.name + for sharding_option in _find_imbalance_tables( + self.best_plan, target_imbalance="hbm" + ) + ] + self.assertTrue(expected_max_hbm_table_names, max_hbm_table_names) + + +class TestBinarySearchPredicate(unittest.TestCase): + def test_binary_search_predicate(self) -> None: + def F(x: int) -> bool: + return x < 90 + + def probes( + search: BinarySearchPredicate, f: Callable[[int], bool] + ) -> List[int]: + r = [] + probe = search.next(True) + while probe is not None: + r.append(probe) + probe = search.next(f(probe)) + return r + + got = probes(BinarySearchPredicate(0, 100, 0), F) + self.assertEqual(got, [50, 75, 88, 94, 91, 89, 90]) + got = probes(BinarySearchPredicate(0, 100, 3), F) + self.assertEqual(got, [50, 75, 88, 94, 91]) + got = probes(BinarySearchPredicate(0, 100, 20), F) + self.assertEqual(got, [50, 75, 88]) + + got = probes(BinarySearchPredicate(91, 100, 0), F) + self.assertEqual(got, [95, 92, 91]) + got = probes(BinarySearchPredicate(1, 10, 0), F) + self.assertEqual(got, [5, 8, 9, 10]) + + got = probes(BinarySearchPredicate(1, 1, 0), F) + self.assertEqual(got, [1]) + got = probes(BinarySearchPredicate(1, 0, 0), F) + self.assertEqual(got, []) + + +class TestLuusJaakolaSearch(unittest.TestCase): + + # Find minimum of f between x0 and x1. + # Evaluate multiple times with different random seeds to ensure we're not + # just getting lucky. + # Returns a Nx2 tensor of [xs, ys] of discovered minimums. 
+ @staticmethod + def evaluate( + x0: float, + x1: float, + f: Callable[[float], float], + left_cost: Optional[float] = None, + ) -> torch.Tensor: + xs = [] + ys = [] + iterations = 16 + for i in range(5): + search = LuusJaakolaSearch(x0, x1, iterations, seed=i, left_cost=left_cost) + y = search.next(0.0) + while y is not None: + fy = f(y) + y = search.next(fy) + x, y = search.best() + xs.append(x) + ys.append(y) + return torch.stack([torch.tensor(xs), torch.tensor(ys)], dim=1) + + def test_simple(self) -> None: + # See N4816561 to view these results graphically + def f1(x: float) -> float: + return x + + def f2(x: float) -> float: + return x * x - 10 * x + 10 # min at x = 5 + + def f3(x: float) -> float: + # bumpy function, overall min at x=30 + return (x - 30) ** 2 + 100 * math.sin(x) + + def f4(x: float) -> float: + # spiky/non-smooth function, min at x = 30 + return (x - 30) ** 2 + (x % 10) * 100 + + results = TestLuusJaakolaSearch.evaluate(0, 100, f1) + want = torch.tensor([[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], dtype=torch.int64) + torch.testing.assert_close(results, want) + + results = TestLuusJaakolaSearch.evaluate(0, 100, f2) + want = torch.tensor( + [ + [3.51914, -12.80705], + [4.22958, -14.40646], + [5.41303, -14.82940], + [2.35012, -7.97811], + [4.18552, -14.33662], + ] + ) + torch.testing.assert_close(results, want) + + results = TestLuusJaakolaSearch.evaluate(0, 100, f3) + want = torch.tensor( + [ + [36.58517, -46.37988], + [29.73184, -99.28705], + [37.67208, 56.15779], + [35.85468, -62.00219], + [41.76223, 58.69744], + ] + ) + torch.testing.assert_close(results, want) + + results = TestLuusJaakolaSearch.evaluate(0, 100, f4) + want = torch.tensor( + [ + [23.68681, 408.53735], + [31.62534, 165.17535], + [32.81968, 289.91898], + [42.81567, 445.80777], + [22.53002, 308.80225], + ] + ) + torch.testing.assert_close(results, want) + + def test_iterations(self) -> None: + search = LuusJaakolaSearch(0, 1, 3) + y = search.next(0) + probes = 0 + while y is not None: + probes += 1 + fy = y + y = search.next(fy) + self.assertEqual(probes, 3) + + # https://github.com/pytorch/pytorch/issues/50334 + @staticmethod + def interp(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor: + """One-dimensional linear interpolation for monotonically increasing sample + points. + + Returns the one-dimensional piecewise linear interpolant to a function with + given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`. + + Args: + x: the :math:`x`-coordinates at which to evaluate the interpolated + values. + xp: the :math:`x`-coordinates of the data points, must be increasing. + fp: the :math:`y`-coordinates of the data points, same length as `xp`. + + Returns: + the interpolated values, same size as `x`. + """ + m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1]) + b = fp[:-1] - (m * xp[:-1]) + + indicies = torch.sum(torch.ge(x[:, None], xp[None, :]), 1) - 1 + indicies = torch.clamp(indicies, 0, len(m) - 1) + + return m[indicies] * x + b[indicies] + + def test_real(self) -> None: + # See N4816561 to view these results graphically + + # Real data collected from bin packing has non-smooth surface and many local minimums. 
+ # mem vs cost from cmf_icvr bin packing + cmf_icvr = torch.tensor( + [ + [4.6741845183e11, 2.3563506569e02], + [4.6741845240e11, 2.3563506569e02], + [4.7121749230e11, 2.3506600864e02], + [4.7501653103e11, 2.3468280680e02], + [4.7881557076e11, 2.3430065943e02], + [4.8261460996e11, 2.3396533990e02], + [4.8641364892e11, 2.3367888393e02], + [4.9021268717e11, 2.3339395760e02], + [4.9401172728e11, 2.3316084540e02], + [4.9781076708e11, 2.3292654771e02], + [5.0160980674e11, 2.3275780179e02], + [5.0540884491e11, 2.3256067684e02], + [5.0920788486e11, 2.3235742684e02], + [5.1300692424e11, 2.3219262609e02], + [5.1680596356e11, 2.3206849693e02], + [5.2060500162e11, 2.3193348320e02], + [5.2440404195e11, 2.3180536764e02], + [5.2820308146e11, 2.3170546631e02], + [5.3200212032e11, 2.3158138440e02], + [5.3580115967e11, 2.3146545816e02], + [5.3960019895e11, 2.3138856778e02], + [5.4339923878e11, 2.3128211641e02], + [5.4719827815e11, 2.3121699239e02], + [5.5099731798e11, 2.3169756090e02], + [5.5479635643e11, 2.3103278320e02], + [5.5859539575e11, 2.3171106005e02], + [5.6239443438e11, 2.3091072319e02], + [5.6619349259e11, 2.3084920287e02], + [5.6999251415e11, 2.3078335619e02], + [5.7379155310e11, 2.3113596330e02], + [5.7759059204e11, 2.3069988094e02], + [5.8138963104e11, 2.3127273113e02], + [5.8518866978e11, 2.3172034584e02], + [5.8898770984e11, 2.3083009711e02], + [5.9278674971e11, 2.3080842049e02], + [5.9658578920e11, 2.3176370343e02], + [6.0038482804e11, 2.3071235199e02], + [6.0418386709e11, 2.3213900014e02], + [6.0798290658e11, 2.3332448570e02], + [6.1178194561e11, 2.3275468168e02], + [6.1558098586e11, 2.3028775311e02], + [6.1938002497e11, 2.3099002246e02], + [6.2317906405e11, 2.3169044278e02], + [6.2697810321e11, 2.3387964670e02], + [6.3077714335e11, 2.3211138392e02], + [6.3457618280e11, 2.3106450194e02], + [6.3837522051e11, 2.3392878354e02], + [6.4217426058e11, 2.3260742338e02], + [6.4597330044e11, 2.3212726336e02], + [6.4977233953e11, 2.3355375214e02], + [6.5357137911e11, 2.3370492744e02], + [6.5737041818e11, 2.3274859312e02], + [6.6116945832e11, 2.3454963160e02], + [6.6496849695e11, 2.3314306687e02], + [6.6876753631e11, 2.3387508611e02], + [6.7256657578e11, 2.3164114924e02], + [6.7636561494e11, 2.3335876240e02], + [6.8016465549e11, 2.3259160444e02], + [6.8396369350e11, 2.3472844839e02], + [6.8776273363e11, 2.3402051674e02], + [6.9156177298e11, 2.3574191998e02], + [6.9536081174e11, 2.3853930635e02], + [6.9915984917e11, 2.3440978885e02], + [7.0295889084e11, 2.3613333429e02], + [7.0675792895e11, 2.3783556448e02], + [7.1055696937e11, 2.3596357613e02], + [7.1435600664e11, 2.4035834255e02], + [7.1815504705e11, 2.3882352229e02], + [7.2195408724e11, 2.4316494619e02], + [7.2575312535e11, 2.4125740709e02], + [7.2955216606e11, 2.3700425464e02], + [7.3335120460e11, 2.4198517463e02], + [7.3715024347e11, 2.4290543902e02], + [7.4094928544e11, 2.3961167246e02], + [7.4474832211e11, 2.4162098068e02], + [7.4854736178e11, 2.4791162259e02], + [7.5234640124e11, 2.4706576073e02], + [7.5614544041e11, 2.4682659631e02], + [7.5994447978e11, 2.4839164423e02], + [7.6374351905e11, 2.5108968132e02], + [7.6754255785e11, 2.5344371602e02], + [7.7134159724e11, 2.6063943014e02], + [7.7514063682e11, 2.4953670969e02], + [7.7893967570e11, 2.5865807123e02], + [7.8273871453e11, 2.6094569799e02], + [7.8653775458e11, 2.6653191005e02], + [7.9033679421e11, 2.6909497473e02], + [7.9413583349e11, 2.7149400968e02], + [7.9793487494e11, 2.7245403781e02], + [8.0173391173e11, 2.8131908812e02], + [8.0553295106e11, 2.9112192412e02], + [8.0933199067e11, 
2.9245070076e02], + [8.1313102998e11, 2.8235347505e02], + [8.1693006950e11, 2.9033406803e02], + [8.2072910826e11, 3.0580905927e02], + [8.2452814772e11, 3.1147292572e02], + [8.2832723864e11, 3.0812470431e02], + [8.3212622721e11, 3.4879506066e02], + [8.3592526617e11, 3.2790815984e02], + [8.3972430401e11, 3.6465536216e02], + [8.4352334347e11, 3.9066552303e02], + ], + dtype=torch.float64, + ) + + mem: torch.Tensor = cmf_icvr[:, 0] + cost: torch.Tensor = cmf_icvr[:, 1] + + def f(x: float) -> float: + return TestLuusJaakolaSearch.interp(torch.tensor([x]), mem, cost).item() + + results = TestLuusJaakolaSearch.evaluate(mem.min().item(), mem.max().item(), f) + want = torch.tensor( + [ + [5.370294e11, 2.314406e02], + [5.426136e11, 2.313041e02], + [5.908549e11, 2.308194e02], + [5.755533e11, 2.309337e02], + [6.184178e11, 2.308121e02], + ], + ) + torch.testing.assert_close(results, want) + + results = TestLuusJaakolaSearch.evaluate( + mem.min().item(), mem.max().item(), f, left_cost=cost[0].item() + ) + want = torch.tensor( + [ + [5.370294e11, 2.314406e02], + # 2nd search finds better result given left_cost + [5.918126e11, 2.308140e02], + [5.908549e11, 2.308194e02], + [5.755533e11, 2.309337e02], + [6.184178e11, 2.308121e02], + ], + ) + torch.testing.assert_close(results, want) diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index e474ba2b5..7643c232f 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import abc from copy import deepcopy from dataclasses import dataclass, field @@ -15,14 +17,89 @@ from torch import nn from torchrec.distributed.planner.constants import ( BATCH_SIZE, + BWD_COMPUTE_MULTIPLIER, CROSS_NODE_BANDWIDTH, DDR_CAP, + DDR_MEM_BW, HBM_CAP, + HBM_MEM_BW, + HBM_TO_DDR_MEM_BW, INTRA_NODE_BANDWIDTH, POOLING_FACTOR, + WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER, +) +from torchrec.distributed.types import ( + BoundsCheckMode, + CacheParams, + KeyValueParams, + ModuleSharder, + ShardingPlan, ) -from torchrec.distributed.types import ModuleSharder, ShardingPlan +from torchrec.modules.embedding_configs import DataType from torchrec.modules.embedding_modules import EmbeddingCollectionInterface +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection + +# ---- Perf ---- # + + +@dataclass(repr=True, eq=True) +class Perf: + """ + Representation of the breakdown of the perf estimate for a single shard of an + embedding table. + """ + + fwd_compute: float + fwd_comms: float + bwd_compute: float + bwd_comms: float + prefetch_compute: float = 0.0 + + @property + def total(self) -> float: + # When using embedding offload, there is a prefetch compute component. This + # prefetch can overlap with fwd_compute + fwd_comms and dense fwd (some of it + # overlaps with fwd_compute) and dense bwd. (fwd_compute and bwd_compute are + # embedding fwd/bwd, nothing to do with dense). Only when prefetch is longer + # than fwd_compute + dense_fwd + dense_bwd will it block bwd_compute. However, + # we don't have an effective way to estimate dense fwd/bwd at this point, so our + # cost model is too simplistic. Instead, prefetch is always considered blocking. + # + # Also note, measuring prefetch blocking can only be done after partitioning; + # here we only have the per-shard estimates. 
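+        #
+        # For intuition (illustrative numbers only, not taken from any real run): a
+        # shard with fwd_compute=2.0, fwd_comms=1.0, bwd_compute=4.0, bwd_comms=1.5
+        # and prefetch_compute=0.5 has total == 9.0, i.e. prefetch_compute is simply
+        # added and treated as fully blocking.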
+ # + # However adding a per-shard prefetch component to the cost model does have the + # benefit that 1) it enables the ScaleupProposer to explore the trade off + # between increasing cache sizes vs more difficult bin-packing constraints. 2) + # it helps balance the prefetch compute across the ranks. + return ( + self.fwd_compute + + self.bwd_compute + + self.fwd_comms + + self.bwd_comms + + self.prefetch_compute + ) + + def __add__(self, other: "Perf") -> "Perf": + return Perf( + fwd_compute=self.fwd_compute + other.fwd_compute, + fwd_comms=self.fwd_comms + other.fwd_comms, + bwd_compute=self.bwd_compute + other.bwd_compute, + bwd_comms=self.bwd_comms + other.bwd_comms, + prefetch_compute=self.prefetch_compute + other.prefetch_compute, + ) + + def __hash__(self) -> int: + return hash( + ( + self.fwd_compute, + self.fwd_comms, + self.bwd_compute, + self.bwd_comms, + self.prefetch_compute, + ) + ) + # ---- TOPOLOGY ---- # @@ -64,7 +141,110 @@ class DeviceHardware: rank: int storage: Storage - perf: float = 0 + perf: Perf + + +class CustomTopologyData: + """ + Custom device data for individual device in a topology. + """ + + supported_fields = ["ddr_cap", "hbm_cap"] + + def __init__( + self, + data: Dict[str, List[int]], + world_size: int, + ) -> None: + assert all( + key in self.supported_fields for key in data.keys() + ), f"{data.keys()} not supported in CustomTopologyData" + assert all( + len(v) == world_size for v in data.values() + ), f"{data.values()} must be positive" + self._data = data + self._world_size = world_size + + def get_data(self, key: str) -> List[int]: + assert ( + key in self.supported_fields + ), f"{key} not supported in CustomTopologyData" + return self._data[key] + + def has_data(self, key: str) -> bool: + return key in self._data + + +class CollectiveType(Enum): + ALL_TO_ALL = "all_to_all" + REDUCE_SCATTER = "reduce_scatter" + ALL_GATHER = "all_gather" + ALL_REDUCE = "all_reduce" + + +class GeneralizedCommsBandwidth(abc.ABC): + @abc.abstractmethod + def get_bw( + self, + local_world_size: int, + world_size: int, + collective_type: CollectiveType, + ) -> float: + """ + Get Bandwidth Corresponding to a collective communication where involving world_size ranks + spread equally across world_size / local_world_size nodes + """ + pass + + @property + @abc.abstractmethod + def intra_host_bw(self) -> float: + """this must be implemented for backward compatibility""" + pass + + @property + @abc.abstractmethod + def inter_host_bw(self) -> float: + """this must be implemented for backward compatibility""" + pass + + +class BasicCommsBandwidths(GeneralizedCommsBandwidth): + def __init__( + self, + inter_host_bw: float = CROSS_NODE_BANDWIDTH, + intra_host_bw: float = INTRA_NODE_BANDWIDTH, + ) -> None: + self.name = "BasicCommsBandwidths" + self._inter_host_bw = inter_host_bw + self._intra_host_bw = intra_host_bw + + def __str__(self) -> str: + return ( + self.name + + f": inter_host_bw={self.inter_host_bw}, intra_host_bw={self.intra_host_bw}" + ) + + @property + def inter_host_bw(self) -> float: + return self._inter_host_bw + + @property + def intra_host_bw(self) -> float: + return self._intra_host_bw + + def get_bw( + self, + local_world_size: int, + world_size: int, + collective_type: CollectiveType, + ) -> float: + if collective_type == CollectiveType.ALL_REDUCE: + return self.inter_host_bw * local_world_size # 1 NIC per GPU + if world_size <= local_world_size: + return self.intra_host_bw + else: + return self.inter_host_bw class Topology: @@ -73,41 +253,86 @@ def __init__( 
world_size: int, compute_device: str, hbm_cap: Optional[int] = None, - ddr_cap: int = DDR_CAP, + ddr_cap: Optional[int] = None, local_world_size: Optional[int] = None, + hbm_mem_bw: float = HBM_MEM_BW, + ddr_mem_bw: float = DDR_MEM_BW, + hbm_to_ddr_mem_bw: float = HBM_TO_DDR_MEM_BW, intra_host_bw: float = INTRA_NODE_BANDWIDTH, inter_host_bw: float = CROSS_NODE_BANDWIDTH, + bwd_compute_multiplier: float = BWD_COMPUTE_MULTIPLIER, + custom_topology_data: Optional[CustomTopologyData] = None, + weighted_feature_bwd_compute_multiplier: float = WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER, + uneven_sharding_perf_multiplier: float = 1.0, + generalized_comms_bandwidths: Optional[GeneralizedCommsBandwidth] = None, ) -> None: """ Representation of a network of devices in a cluster. + + If a GeneralizedCommsBandwidth is passed to generalized_comms_bandwidths, this object will + take precedence over the formulation using only intra_host_bw and inter_host_bw. + If it's not passed, we will create a BasicCommsBandwidths object with the provided bandwidths. """ # validate input assert compute_device in [ "cpu", "cuda", + "mtia", ], f"unsupported compute device {compute_device}" self._compute_device = compute_device self._world_size = world_size - hbm_per_device = 0 + hbm_per_device = [0] * world_size if self._compute_device == "cuda": - hbm_per_device = hbm_cap if hbm_cap else HBM_CAP + hbm_per_device = [hbm_cap if hbm_cap else HBM_CAP] * world_size + ddr_cap_per_rank = [ddr_cap if ddr_cap else DDR_CAP] * world_size + + if custom_topology_data: + if custom_topology_data.has_data("hbm_cap"): + hbm_per_device = custom_topology_data.get_data("hbm_cap") + assert ( + len(hbm_per_device) == world_size + ), "Must provide individual hbm_cap for each device" + if custom_topology_data.has_data("ddr_cap"): + ddr_cap_per_rank = custom_topology_data.get_data("ddr_cap") + assert ( + len(ddr_cap_per_rank) == world_size + ), "Must provide individual ddr_cap for each device" self._devices: List[DeviceHardware] = [] for rank in range(world_size): self._devices.append( DeviceHardware( rank=rank, - storage=Storage(hbm=hbm_per_device, ddr=ddr_cap), + storage=Storage( + hbm=hbm_per_device[rank], ddr=ddr_cap_per_rank[rank] + ), + perf=Perf(fwd_compute=0, fwd_comms=0, bwd_compute=0, bwd_comms=0), ) ) self._local_world_size: int = ( local_world_size if local_world_size else world_size ) - self._intra_host_bw = intra_host_bw - self._inter_host_bw = inter_host_bw + self._hbm_mem_bw = hbm_mem_bw + self._ddr_mem_bw = ddr_mem_bw + self._hbm_to_ddr_mem_bw = hbm_to_ddr_mem_bw + + self._comms_bandwidths: GeneralizedCommsBandwidth = ( + generalized_comms_bandwidths + if generalized_comms_bandwidths is not None + else BasicCommsBandwidths( + intra_host_bw=intra_host_bw, inter_host_bw=inter_host_bw + ) + ) + + self._bwd_compute_multiplier = bwd_compute_multiplier + self._custom_topology_data = custom_topology_data + self._weighted_feature_bwd_compute_multiplier = ( + weighted_feature_bwd_compute_multiplier + ) + self._uneven_sharding_perf_multiplier = uneven_sharding_perf_multiplier @property def compute_device(self) -> str: @@ -125,13 +350,41 @@ def world_size(self) -> int: def local_world_size(self) -> int: return self._local_world_size + @property + def hbm_mem_bw(self) -> float: + return self._hbm_mem_bw + + @property + def ddr_mem_bw(self) -> float: + return self._ddr_mem_bw + + @property + def hbm_to_ddr_mem_bw(self) -> float: + return self._hbm_to_ddr_mem_bw + @property def intra_host_bw(self) -> float: - return self._intra_host_bw + 
return self._comms_bandwidths.intra_host_bw @property def inter_host_bw(self) -> float: - return self._inter_host_bw + return self._comms_bandwidths.inter_host_bw + + @property + def comms_bandwidths(self) -> GeneralizedCommsBandwidth: + return self._comms_bandwidths + + @property + def bwd_compute_multiplier(self) -> float: + return self._bwd_compute_multiplier + + @property + def weighted_feature_bwd_compute_multiplier(self) -> float: + return self._weighted_feature_bwd_compute_multiplier + + @property + def uneven_sharding_perf_multiplier(self) -> float: + return self._uneven_sharding_perf_multiplier def __repr__(self) -> str: topology_repr: str = f"world_size={self._world_size} \n" @@ -140,8 +393,7 @@ def __repr__(self) -> str: for idx, device in enumerate(self._devices): topology_repr += f"\tdevice {idx} {device}\n" topology_repr += f"local_world_size={self._local_world_size} \n" - topology_repr += f"intra_host_bw={self._intra_host_bw} \n" - topology_repr += f"inter_host_bw={self._inter_host_bw} \n" + topology_repr += str(self._comms_bandwidths) + "\n" return topology_repr @@ -159,7 +411,7 @@ class Shard: size: List[int] offset: List[int] storage: Optional[Storage] = None - perf: Optional[float] = None + perf: Optional[Perf] = None rank: Optional[int] = None def __hash__(self) -> int: @@ -173,10 +425,52 @@ def __hash__(self) -> int: ) ) + def __str__(self) -> str: + return f"Shard size: {tuple(self.size)}, offset: {tuple(self.offset)}, storage: {str(self.storage)}, perf: {str(self.perf)}, rank: {self.rank}" + class ShardingOption: """ - One way of sharding an embedding table. + One way of sharding an embedding table. In the enumerator, we generate + multiple sharding options per table, but in the planner output, there + should only be one sharding option per table. + + Attributes: + name (str): name of the sharding option. + tensor (torch.Tensor): tensor of the sharding option. Usually on meta + device. + module (Tuple[str, nn.Module]): module and its fqn that contains the + table. + input_lengths (List[float]): list of pooling factors of the feature for + the table. + batch_size (int): batch size of training / eval job. + sharding_type (str): sharding type of the table. Value of enum ShardingType. + compute_kernel (str): compute kernel of the table. Value of enum + EmbeddingComputeKernel. + shards (List[Shard]): list of shards of the table. + cache_params (Optional[CacheParams]): cache parameters to be used by this table. + These are passed to FBGEMM's Split TBE kernel. + enforce_hbm (Optional[bool]): whether to place all weights/momentums in HBM when + using cache. + stochastic_rounding (Optional[bool]): whether to do stochastic rounding. This is + passed to FBGEMM's Split TBE kernel. Stochastic rounding is + non-deterministic, but important to maintain accuracy in longer + term with FP16 embedding tables. + bounds_check_mode (Optional[BoundsCheckMode]): bounds check mode to be used by + FBGEMM's Split TBE kernel. Bounds check means checking if values + (i.e. row id) is within the table size. If row id exceeds table + size, it will be set to 0. + dependency (Optional[str]): dependency of the table. Related to + Embedding tower. + is_pooled (Optional[bool]): whether the table is pooled. Pooling can be + sum pooling or mean pooling. Unpooled tables are also known as + sequence embeddings. + feature_names (Optional[List[str]]): list of feature names for this table. + output_dtype (Optional[DataType]): output dtype to be used by this table. + The default is FP32. 
If not None, the output dtype will also be used + by the planner to produce a more balanced plan. + key_value_params (Optional[KeyValueParams]): Params for SSD TBE, either + for SSD or PS. """ def __init__( @@ -190,7 +484,15 @@ def __init__( partition_by: str, compute_kernel: str, shards: List[Shard], + cache_params: Optional[CacheParams] = None, + enforce_hbm: Optional[bool] = None, + stochastic_rounding: Optional[bool] = None, + bounds_check_mode: Optional[BoundsCheckMode] = None, dependency: Optional[str] = None, + is_pooled: Optional[bool] = None, + feature_names: Optional[List[str]] = None, + output_dtype: Optional[DataType] = None, + key_value_params: Optional[KeyValueParams] = None, ) -> None: self.name = name self._tensor = tensor @@ -203,7 +505,16 @@ def __init__( # relevant to planner output, must be populated if sharding option # part of final solution self.shards = shards + self.cache_params = cache_params + self.enforce_hbm = enforce_hbm + self.stochastic_rounding = stochastic_rounding + self.bounds_check_mode = bounds_check_mode self.dependency = dependency + self._is_pooled = is_pooled + self.is_weighted: Optional[bool] = None + self.feature_names: Optional[List[str]] = feature_names + self.output_dtype: Optional[DataType] = output_dtype + self.key_value_params: Optional[KeyValueParams] = key_value_params @property def tensor(self) -> torch.Tensor: @@ -217,6 +528,12 @@ def module(self) -> Tuple[str, nn.Module]: def fqn(self) -> str: return self.module[0] + "." + self.name + @property + def cache_load_factor(self) -> Optional[float]: + if self.cache_params is not None: + return self.cache_params.load_factor + return None + @property def path(self) -> str: return self.module[0] @@ -236,14 +553,36 @@ def total_storage(self) -> Storage: storage += cast(Storage, shard.storage) return storage + @property + def total_perf(self) -> float: + perf: float = 0 + for shard in self.shards: + # pyre-ignore: Undefined attribute [16] + perf += shard.perf.total + return perf + @property def is_pooled(self) -> bool: - if isinstance(self.module[1], EmbeddingCollectionInterface): + if self._is_pooled is None: + self._is_pooled = ShardingOption.module_pooled(self.module[1], self.name) + return self._is_pooled + + @staticmethod + def module_pooled(module: nn.Module, sharding_option_name: str) -> bool: + """Determine if module pools output (e.g. 
EmbeddingBag) or uses unpooled/sequential output.""" + if isinstance(module, EmbeddingCollectionInterface) or isinstance( + module, ManagedCollisionEmbeddingCollection + ): return False - for name, module in self.module[1].named_modules(): - if self.name in name: - if isinstance(module, EmbeddingCollectionInterface): - return False + + for submodule in module.modules(): + if isinstance(submodule, EmbeddingCollectionInterface) or isinstance( + submodule, ManagedCollisionEmbeddingCollection + ): + for name, _ in submodule.named_parameters(): + if sharding_option_name in name: + return False + return True def __hash__(self) -> int: @@ -253,6 +592,7 @@ def __hash__(self) -> int: self.sharding_type, self.compute_kernel, tuple(self.shards), + self.cache_params, ) ) @@ -268,6 +608,17 @@ def __deepcopy__( setattr(result, k, deepcopy(v, memo)) return result + def __str__(self) -> str: + str_obj: str = "" + str_obj += f"name: {self.name}" + str_obj += f"\nsharding type: {self.sharding_type}" + str_obj += f"\ncompute kernel: {self.compute_kernel}" + str_obj += f"\nnum shards: {len(self.shards)}" + for shard in self.shards: + str_obj += f"\n\t{str(shard)}" + + return str_obj + class PartitionByType(Enum): """ @@ -280,6 +631,8 @@ class PartitionByType(Enum): HOST = "host" # Uniform, (ie. fixed layout) UNIFORM = "uniform" + # Partitioning based on multiple hosts + MULTI_HOST = "multi_host" @dataclass @@ -289,17 +642,66 @@ class ParameterConstraints: If provided, `pooling_factors`, `num_poolings`, and `batch_sizes` must match in length, as per sample. + + Attributes: + sharding_types (Optional[List[str]]): sharding types allowed for the table. + Values of enum ShardingType. + compute_kernels (Optional[List[str]]): compute kernels allowed for the table. + Values of enum EmbeddingComputeKernel. + min_partition (Optional[int]): lower bound for dimension of column wise shards. + Planner will search for the column wise shard dimension in the + range of [min_partition, embedding_dim], as long as the column wise + shard dimension divides embedding_dim and is divisible by 4. Used + for column wise sharding only. + pooling_factors (Optional[List[float]]): pooling factors for each feature of the + table. This is the average number of values each sample has for + the feature. Length of pooling_factors should match the number of + features. + num_poolings (OptionalList[float]]): number of poolings for each feature of the + table. Length of num_poolings should match the number of features. + batch_sizes (Optional[List[int]]): batch sizes for each feature of the table. Length + of batch_sizes should match the number of features. + is_weighted (Optional[bool]): whether the table is weighted. + cache_params (Optional[CacheParams]): cache parameters to be used by this table. + These are passed to FBGEMM's Split TBE kernel. + enforce_hbm (Optional[bool]): whether to place all weights/momentums in HBM when + using cache. + stochastic_rounding (Optional[bool]): whether to do stochastic rounding. This is + passed to FBGEMM's Split TBE kernel. Stochastic rounding is + non-deterministic, but important to maintain accuracy in longer + term with FP16 embedding tables. + bounds_check_mode (Optional[BoundsCheckMode]): bounds check mode to be used by + FBGEMM's Split TBE kernel. Bounds check means checking if values + (i.e. row id) is within the table size. If row id exceeds table + size, it will be set to 0. + feature_names (Optional[List[str]]): list of feature names for this table. 
+ output_dtype (Optional[DataType]): output dtype to be used by this table. + The default is FP32. If not None, the output dtype will also be used + by the planner to produce a more balanced plan. + device_group (Optional[str]): device group to be used by this table. It can be cpu + or cuda. This specifies if the table should be placed on a cpu device + or a gpu device. + key_value_params (Optional[KeyValueParams]): key value params for SSD TBE, either for + SSD or PS. """ sharding_types: Optional[List[str]] = None compute_kernels: Optional[List[str]] = None - min_partition: Optional[int] = None # CW sharding + min_partition: Optional[int] = None # CW sharding, min CW dim to shard pooling_factors: List[float] = field( default_factory=lambda: [POOLING_FACTOR] ) # average number of embedding lookups required per sample num_poolings: Optional[List[float]] = None # number of poolings per sample in batch batch_sizes: Optional[List[int]] = None # batch size per input feature is_weighted: bool = False + cache_params: Optional[CacheParams] = None + enforce_hbm: Optional[bool] = None + stochastic_rounding: Optional[bool] = None + bounds_check_mode: Optional[BoundsCheckMode] = None + feature_names: Optional[List[str]] = None + output_dtype: Optional[DataType] = None + device_group: Optional[str] = None + key_value_params: Optional[KeyValueParams] = None class PlannerErrorType(Enum): @@ -339,14 +741,12 @@ def reserve( module: nn.Module, sharders: List[ModuleSharder[nn.Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None, - ) -> Topology: - ... + ) -> Topology: ... class PerfModel(abc.ABC): @abc.abstractmethod - def rate(self, plan: List[ShardingOption]) -> float: - ... + def rate(self, plan: List[ShardingOption]) -> float: ... class ShardEstimator(abc.ABC): @@ -359,8 +759,7 @@ def __init__( self, topology: Topology, constraints: Optional[Dict[str, ParameterConstraints]] = None, - ) -> None: - ... + ) -> None: ... @abc.abstractmethod def estimate( @@ -385,8 +784,7 @@ def __init__( batch_size: int = BATCH_SIZE, constraints: Optional[Dict[str, ParameterConstraints]] = None, estimator: Optional[Union[ShardEstimator, List[ShardEstimator]]] = None, - ) -> None: - ... + ) -> None: ... @abc.abstractmethod def enumerate( @@ -399,6 +797,13 @@ def enumerate( """ ... + @abc.abstractmethod + def populate_estimates(self, sharding_options: List[ShardingOption]) -> None: + """ + See class description. + """ + ... + class Proposer(abc.ABC): """ @@ -410,7 +815,15 @@ class Proposer(abc.ABC): def load( self, search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, ) -> None: + """ + Load search space into proposer. + + Args: + search_space (List[ShardingOption]): search space to load. + enumerator (Enumerator): enumerator used to generate search space. + """ ... @abc.abstractmethod @@ -419,11 +832,27 @@ def feedback( partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, ) -> None: + """ + Provide feedback to proposer. + + Args: + partitionable (bool): whether the plan is partitionable. + plan (Optional[List[ShardingOption]]): plan to provide feedback on. + perf_rating (Optional[float]): performance rating of the plan. + storage_constraint (Optional[Topology]): storage constraint of the plan. + """ ... @abc.abstractmethod def propose(self) -> Optional[List[ShardingOption]]: + """ + Propose a sharding plan. + + Returns: + Optional[List[ShardingOption]]: proposed plan. + """ ... 
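The Proposer interface documented above is driven by the planner as a load/propose/feedback loop. The following is a minimal, illustrative sketch of a conforming implementation, not part of this patch: PassThroughProposer is a hypothetical name, and the imports assume the planner types defined in torchrec/distributed/planner/types.py above.

from typing import List, Optional

from torchrec.distributed.planner.types import (
    Enumerator,
    Proposer,
    ShardingOption,
    Topology,
)


class PassThroughProposer(Proposer):
    """Hypothetical proposer: proposes the enumerated search space once, then stops."""

    def __init__(self) -> None:
        self._proposal: Optional[List[ShardingOption]] = None
        self._done: bool = False

    def load(
        self,
        search_space: List[ShardingOption],
        enumerator: Optional[Enumerator] = None,
    ) -> None:
        # Keep the enumerated options; a real proposer would group/sort them first.
        self._proposal = list(search_space)
        self._done = False

    def propose(self) -> Optional[List[ShardingOption]]:
        return None if self._done else self._proposal

    def feedback(
        self,
        partitionable: bool,
        plan: Optional[List[ShardingOption]] = None,
        perf_rating: Optional[float] = None,
        storage_constraint: Optional[Topology] = None,
    ) -> None:
        # Single round: stop proposing regardless of whether the plan partitioned.
        self._done = True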
@@ -461,9 +890,19 @@ def log( run_time: float, best_plan: List[ShardingOption], constraints: Optional[Dict[str, ParameterConstraints]] = None, + sharders: Optional[List[ModuleSharder[nn.Module]]] = None, debug: bool = False, ) -> None: """ See class description """ ... + + +@dataclass +class CriticalPathEstimate: + comms_estimate: float + comp_estimate: float + + def total(self) -> float: + return self.comms_estimate + self.comp_estimate diff --git a/torchrec/distributed/planner/utils.py b/torchrec/distributed/planner/utils.py index 687541bef..5d0a03a27 100644 --- a/torchrec/distributed/planner/utils.py +++ b/torchrec/distributed/planner/utils.py @@ -5,11 +5,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + +import math import operator from functools import reduce -from typing import Any, Iterable, Type, Union +from typing import Any, cast, Dict, Iterable, List, Optional, Tuple, Type, Union import torch +from torchrec.distributed.planner.types import Perf, ShardingOption, Storage +from torchrec.distributed.types import ShardingType + # pyre-ignore[2] def sharder_name(t: Type[Any]) -> str: @@ -42,6 +48,216 @@ def placement( """ param_device = compute_device - if compute_device == "cuda": - param_device = torch.device("cuda", rank % local_size) + if compute_device in {"cuda", "mtia"}: + param_device = torch.device(compute_device, rank % local_size) return f"rank:{rank}/{param_device}" + + +def storage_repr_in_gb(storage: Optional[Storage]) -> str: + if storage is None: + return "" + return ( + f"Storage(hbm = {round(bytes_to_gb(storage.hbm), 3)} GB, " + f"ddr = {round(bytes_to_gb(storage.ddr), 3)} GB)" + ) + + +def reset_shard_rank(proposal: List[ShardingOption]) -> None: + for sharding_option in proposal: + for shard in sharding_option.shards: + shard.rank = None + + +def _find_imbalance_tables( + sharding_options: List[ShardingOption], target_imbalance: str = "perf" +) -> List[ShardingOption]: + """ + Find the tables that are causing the imbalance, and return their names. 
+ """ + rank_to_target_stats: Dict[int, float] = {} + + # populate rank_to_target_stats + for sharding_option in sharding_options: + for shard in sharding_option.shards: + rank = cast(int, shard.rank) + if rank not in rank_to_target_stats: + rank_to_target_stats[rank] = 0 + + if target_imbalance == "perf": + rank_to_target_stats[rank] += cast(Perf, shard.perf).total + elif target_imbalance == "hbm": + rank_to_target_stats[rank] += cast(Storage, shard.storage).hbm + else: + raise ValueError(f"Unknown target imbalance {target_imbalance}") + + if len(rank_to_target_stats.values()) <= 1: + # world_size is 1 + return [] + + max_value = max(rank_to_target_stats.values()) + max_value_ranks = { + rank for rank, value in rank_to_target_stats.items() if value == max_value + } + + # find tables + tables_in_max_value_ranks: List[ShardingOption] = [] + for sharding_option in sharding_options: + sharding_option_ranks = [shard.rank for shard in sharding_option.shards] + if set( + sharding_option_ranks + ) >= max_value_ranks and sharding_option.sharding_type not in [ + ShardingType.DATA_PARALLEL.value, + ShardingType.ROW_WISE.value, + ]: + tables_in_max_value_ranks.append(sharding_option) + + if target_imbalance == "perf": + # sort tables by total perf from largest to smallest + tables_in_max_value_ranks.sort( + key=lambda sharding_option: sharding_option.shards[0].perf.total, + reverse=True, + ) + elif target_imbalance == "hbm": + # sort tables by hbm from largest to smallest + tables_in_max_value_ranks.sort( + key=lambda sharding_option: sharding_option.shards[0].storage.hbm, + reverse=True, + ) + else: + raise ValueError(f"Unknown target imbalance {target_imbalance}") + + return tables_in_max_value_ranks + + +class BinarySearchPredicate: + """Generates values of X between A & B to invoke on an external predicate F(X) to + discover the largest X for which F(X) is true. Uses binary search to minimize the + number of invocations of F. Assumes F is a step function, i.e. if F(X) is false, + there is no point trying F(X+1).""" + + def __init__(self, A: int, B: int, tolerance: int) -> None: + """A = lower boundary (inclusive) + B = upper boundary (inclusive) + tolerance = stop search early if remaining search range is less than tolerance + """ + self.left = A + self.right = B + self.tolerance = tolerance + self.first = True + + def next(self, prior_result: bool) -> Optional[int]: + """next() returns the next value to probe, given the result of the prior probe. + The first time next() is invoked the prior_result is ignored. Returns None if + entire range explored or threshold reached.""" + if self.right - self.left < self.tolerance: + return None + + mid = self._mid() + if self.first: + self.first = False + return mid + + if prior_result: + self.left = mid + 1 + else: + self.right = mid - 1 + if self.right - self.left < self.tolerance: + return None + + return self._mid() + + def _mid(self) -> int: + return self.left + ((self.right - self.left) // 2) + + +class LuusJaakolaSearch: + """Implements a clamped variant of Luus Jaakola search. + + See https://en.wikipedia.org/wiki/Luus-Jaakola. 
+ """ + + def __init__( + self, + A: float, + B: float, + max_iterations: int, + seed: int = 42, + left_cost: Optional[float] = None, + ) -> None: + self.left = A + self.right = B + self.iteration = -1 + self.max_iterations = max_iterations + + self.gen = torch.Generator() + self.gen.manual_seed(seed) + + self.x: float = self.uniform(self.left, self.right) + self.fx: float = 0.0 + self.y: float = math.nan + self.fleft: Optional[float] = left_cost + self.fright: Optional[float] = None + self.d: float = self.right - self.left + + def shrink_right(self, B: float) -> None: + "Shrink right boundary given [B,infinity) -> infinity" + self.right = B + self.fright = math.inf + self.d = self.right - self.left + self.x = self.clamp(self.x) + + def clamp(self, x: float) -> float: + "Clamp x into range [left, right]" + if x < self.left: + return self.left + if x > self.right: + return self.right + return x + + def uniform(self, A: float, B: float) -> float: + "Return a random uniform position in range [A,B]." + u = torch.rand(1, generator=self.gen, device="cpu").item() + return A + (B - A) * u + + def next(self, fy: float) -> Optional[float]: + """Return the next probe point 'y' to evaluate, given the previous result. + + The first time around fy is ignored. Subsequent invocations should provide the + result of evaluating the function being minimized, i.e. f(y). + + Returns None when the maximum number of iterations has been reached. + """ + self.iteration += 1 + if self.iteration == 0: + return self.x + elif self.iteration == 1: + self.fx = fy + elif self.iteration == self.max_iterations: + return None + elif fy <= self.fx: + self.x = self.y + self.fx = fy + self.d = 0.95 * self.d + + if self.y == self.left: + self.fleft = fy + elif self.y == self.right: + self.fright = fy + + while True: + a = self.uniform(-self.d, self.d) + y = self.clamp(self.x + a) + # Unlike standard Luus-Jaakola, we don't want to explore outside of our bounds. + # Clamping can cause us to explore the boundary multiple times, so we + # remember if we already know the boundary cost and request a new sample if + # we do. + if y == self.left and self.fleft is not None: + continue + if y == self.right and self.fright is not None: + continue + self.y = y + return self.y + + def best(self) -> Tuple[float, float]: + "Return the best position so far, and its associated cost." + return self.x, self.fx diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py index 02b71dda8..792fdeb0a 100644 --- a/torchrec/distributed/quant_embedding.py +++ b/torchrec/distributed/quant_embedding.py @@ -5,50 +5,97 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
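The two search helpers added to torchrec/distributed/planner/utils.py above (BinarySearchPredicate and LuusJaakolaSearch) are both caller-driven: the caller repeatedly asks for the next probe and feeds back the result of evaluating it. The sketch below shows that driver loop; fits and cost are hypothetical stand-ins for a partition-feasibility check and a plan-cost evaluation, not functions from this patch.

from typing import Optional

from torchrec.distributed.planner.utils import BinarySearchPredicate, LuusJaakolaSearch


def fits(budget: int) -> bool:
    # Hypothetical stand-in for "does the plan still partition with this budget?"
    return budget <= 37


def cost(x: float) -> float:
    # Hypothetical stand-in for the estimated cost of a plan at parameter x.
    return (x - 0.3) ** 2


# Largest value in [0, 100] for which fits() holds, within a tolerance of 1.
search = BinarySearchPredicate(A=0, B=100, tolerance=1)
largest_ok: Optional[int] = None
probe = search.next(True)  # the first prior_result is ignored
while probe is not None:
    ok = fits(probe)
    if ok:
        largest_ok = probe
    probe = search.next(ok)

# Approximate minimizer of cost() over [0.0, 1.0] in at most 16 evaluations.
lj = LuusJaakolaSearch(A=0.0, B=1.0, max_iterations=16)
y = lj.next(0.0)  # the first fy is ignored
while y is not None:
    y = lj.next(cost(y))
best_x, best_cost = lj.best()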
+# pyre-strict + +import logging +from collections import defaultdict, deque from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Type +from typing import ( + Any, + cast, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) import torch +from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( + IntNBitTableBatchedEmbeddingBagsCodegen, +) from torch import nn +from torch.distributed._shard.sharding_spec import EnumerableShardingSpec from torchrec.distributed.embedding import ( - create_sharding_infos_by_sharding, + create_sharding_infos_by_sharding_device_group, EmbeddingShardingInfo, ) -from torchrec.distributed.embedding_sharding import ( - EmbeddingSharding, - ListOfSparseFeaturesListAwaitable, -) +from torchrec.distributed.embedding_sharding import EmbeddingSharding from torchrec.distributed.embedding_types import ( BaseQuantEmbeddingSharder, - ListOfSparseFeaturesList, + EmbeddingComputeKernel, + FeatureShardingMixIn, + GroupedEmbeddingConfig, + InputDistOutputs, + KJTList, + ListOfKJTList, ShardingType, - SparseFeatures, - SparseFeaturesList, ) -from torchrec.distributed.sharding.sequence_sharding import InferSequenceShardingContext +from torchrec.distributed.fused_params import ( + FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, + FUSED_PARAM_REGISTER_TBE_BOOL, + get_tbes_to_register_from_iterable, + is_fused_param_quant_state_dict_split_scale_bias, + is_fused_param_register_tbe, +) +from torchrec.distributed.global_settings import get_propogate_device +from torchrec.distributed.mc_modules import ( + InferManagedCollisionCollectionSharder, + ShardedMCCRemapper, + ShardedQuantManagedCollisionCollection, +) +from torchrec.distributed.quant_state import ShardedQuantEmbeddingModuleState +from torchrec.distributed.sharding.cw_sequence_sharding import ( + InferCwSequenceEmbeddingSharding, +) +from torchrec.distributed.sharding.rw_sequence_sharding import ( + InferRwSequenceEmbeddingSharding, +) +from torchrec.distributed.sharding.sequence_sharding import ( + InferSequenceShardingContext, + SequenceShardingContext, +) from torchrec.distributed.sharding.tw_sequence_sharding import ( InferTwSequenceEmbeddingSharding, ) -from torchrec.distributed.types import ( - Awaitable, - FeatureShardingMixIn, - LazyAwaitable, - ParameterSharding, - ShardedModule, - ShardingEnv, -) +from torchrec.distributed.types import ParameterSharding, ShardingEnv, ShardMetadata +from torchrec.distributed.utils import append_prefix from torchrec.modules.embedding_configs import ( data_type_to_sparse_type, dtype_to_data_type, EmbeddingConfig, ) +from torchrec.modules.utils import ( + _fx_trec_get_feature_length, + _get_batching_hinted_output, +) from torchrec.quant.embedding_modules import ( EmbeddingCollection as QuantEmbeddingCollection, + MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, + QuantManagedCollisionEmbeddingCollection, ) -from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor from torchrec.streamable import Multistreamable +torch.fx.wrap("len") +torch.fx.wrap("_get_batching_hinted_output") +torch.fx.wrap("_fx_trec_get_feature_length") + try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") @@ -56,15 +103,93 @@ pass +logger: logging.Logger = logging.getLogger(__name__) + + +ShrdCtx = TypeVar("ShrdCtx", bound=Multistreamable) + + 
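The torch.fx.wrap(...) registrations above mark helpers such as len, _get_batching_hinted_output and _fx_trec_get_feature_length as leaf functions, so symbolic tracing records calls to them as single call_function nodes rather than tracing through them. A small, self-contained illustration of that mechanism, using a made-up helper name (not part of this patch):

import torch
import torch.fx


@torch.fx.wrap  # same mechanism as the string-form registrations above
def clamp_lengths(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # Data-dependent helper that we want to appear as one opaque node in the graph.
    return torch.clamp(lengths, max=max_len)


class Toy(torch.nn.Module):
    def forward(self, lengths: torch.Tensor) -> torch.Tensor:
        return clamp_lengths(lengths, 10) + 1


gm = torch.fx.symbolic_trace(Toy())
print(gm.graph)  # clamp_lengths shows up as a single call_function node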
@dataclass class EmbeddingCollectionContext(Multistreamable): sharding_contexts: List[InferSequenceShardingContext] - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + def record_stream(self, stream: torch.Stream) -> None: for ctx in self.sharding_contexts: ctx.record_stream(stream) +class ManagedCollisionEmbeddingCollectionContext(EmbeddingCollectionContext): + + def __init__( + self, + sharding_contexts: Optional[List[SequenceShardingContext]] = None, + input_features: Optional[List[KeyedJaggedTensor]] = None, + reverse_indices: Optional[List[torch.Tensor]] = None, + evictions_per_table: Optional[Dict[str, Optional[torch.Tensor]]] = None, + remapped_kjt: Optional[KJTList] = None, + ) -> None: + # pyre-ignore + super().__init__(sharding_contexts) + self.evictions_per_table: Optional[Dict[str, Optional[torch.Tensor]]] = ( + evictions_per_table + ) + self.remapped_kjt: Optional[KJTList] = remapped_kjt + + def record_stream(self, stream: torch.Stream) -> None: + super().record_stream(stream) + if self.evictions_per_table: + # pyre-ignore + for value in self.evictions_per_table.values(): + if value is None: + continue + value.record_stream(stream) + if self.remapped_kjt is not None: + self.remapped_kjt.record_stream(stream) + + +def get_device_from_parameter_sharding( + ps: ParameterSharding, +) -> Union[str, Tuple[str, ...]]: + """ + Returns a list of device types (one per shard) if the table is sharded across different device types, + else returns a single device type for the table parameter + """ + if not isinstance(ps.sharding_spec, EnumerableShardingSpec): + raise ValueError("Expected EnumerableShardingSpec as input to the function") + + device_type_list: Tuple[str, ...] = tuple( + # pyre-fixme[16]: `Optional` has no attribute `device` + [shard.placement.device().type for shard in ps.sharding_spec.shards] + ) + if len(set(device_type_list)) == 1: + return device_type_list[0] + else: + assert ( + ps.sharding_type == "row_wise" + ), "Only row_wise sharding supports sharding across multiple device types for a table" + return device_type_list + + +def get_device_from_sharding_infos( + emb_shard_infos: List[EmbeddingShardingInfo], +) -> Union[str, Tuple[str, ...]]: + res = list( + { + get_device_from_parameter_sharding(ps.param_sharding) + for ps in emb_shard_infos + } + ) + assert len(res) == 1, "All shards should be on the same type of device" + return res[0] + + +def get_device_for_first_shard_from_sharding_infos( + emb_shard_infos: List[EmbeddingShardingInfo], +) -> str: + device_type = get_device_from_sharding_infos(emb_shard_infos) + return device_type[0] if isinstance(device_type, tuple) else device_type + + def create_infer_embedding_sharding( sharding_type: str, sharding_infos: List[EmbeddingShardingInfo], @@ -72,77 +197,321 @@ def create_infer_embedding_sharding( device: Optional[torch.device] = None, ) -> EmbeddingSharding[ InferSequenceShardingContext, - SparseFeaturesList, + InputDistOutputs, List[torch.Tensor], List[torch.Tensor], ]: - if sharding_type == ShardingType.TABLE_WISE.value: - return InferTwSequenceEmbeddingSharding(sharding_infos, env, device) + device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = ( + get_device_from_sharding_infos(sharding_infos) + ) + + if device_type_from_sharding_infos in ["cuda", "mtia"]: + if sharding_type == ShardingType.TABLE_WISE.value: + return InferTwSequenceEmbeddingSharding(sharding_infos, env, device) + elif sharding_type == ShardingType.COLUMN_WISE.value: + return InferCwSequenceEmbeddingSharding(sharding_infos, env, device)
+ elif sharding_type == ShardingType.ROW_WISE.value: + return InferRwSequenceEmbeddingSharding( + sharding_infos=sharding_infos, + env=env, + device=device, + device_type_from_sharding_infos=device_type_from_sharding_infos, + ) + else: + raise ValueError( + f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding" + ) + elif device_type_from_sharding_infos == "cpu" or isinstance( + device_type_from_sharding_infos, tuple + ): + if sharding_type == ShardingType.ROW_WISE.value: + return InferRwSequenceEmbeddingSharding( + sharding_infos=sharding_infos, + env=env, + device=device, + device_type_from_sharding_infos=device_type_from_sharding_infos, + ) + elif sharding_type == ShardingType.TABLE_WISE.value: + return InferTwSequenceEmbeddingSharding(sharding_infos, env, device) + else: + raise ValueError( + f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding" + ) else: - raise ValueError(f"Sharding type not supported {sharding_type}") + raise ValueError( + f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding" + ) -def _construct_jagged_tensors( - embeddings: torch.Tensor, - features: KeyedJaggedTensor, - need_indices: bool = False, +@torch.fx.wrap +def _fx_trec_unwrap_optional_tensor(optional: Optional[torch.Tensor]) -> torch.Tensor: + assert optional is not None, "Expected optional to be non-None Tensor" + return optional + + +@torch.fx.wrap +def _fx_trec_wrap_length_tolist(length: torch.Tensor) -> List[int]: + return length.long().tolist() + + +@torch.fx.wrap +def _get_unbucketize_tensor_via_length_alignment( + lengths: torch.Tensor, + bucketize_length: torch.Tensor, + bucketize_permute_tensor: torch.Tensor, + bucket_mapping_tensor: torch.Tensor, +) -> torch.Tensor: + return bucketize_permute_tensor + + +def _construct_jagged_tensors_tw( + embeddings: List[torch.Tensor], + embedding_names_per_rank: List[List[str]], + features: KJTList, + need_indices: bool, ) -> Dict[str, JaggedTensor]: - # ignore cw consideration for inference now. ret: Dict[str, JaggedTensor] = {} - lengths = features.lengths().view(-1, features.stride()) - values = features.values() - length_per_key = features.length_per_key() - values_list = torch.split(values, length_per_key) if need_indices else None - embeddings_list = torch.split(embeddings, length_per_key, dim=0) - stride = features.stride() - lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0) - for i, key in enumerate(features.keys()): - ret[key] = JaggedTensor( - lengths=lengths_tuple[i], - values=embeddings_list[i], - # pyre-fixme[16]: `Optional` has no attribute `__getitem__`. 
+ for i in range(len(embedding_names_per_rank)): + embeddings_i: torch.Tensor = embeddings[i] + features_i: KeyedJaggedTensor = features[i] + + lengths = features_i.lengths().view(-1, features_i.stride()) + values = features_i.values() + length_per_key = features_i.length_per_key() + + embeddings_list = torch.split(embeddings_i, length_per_key, dim=0) + stride = features_i.stride() + lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0) + if need_indices: + values_list = torch.split(values, length_per_key) + for j, key in enumerate(embedding_names_per_rank[i]): + ret[key] = JaggedTensor( + lengths=lengths_tuple[j], + values=embeddings_list[j], + weights=values_list[j], + ) + else: + for j, key in enumerate(embedding_names_per_rank[i]): + ret[key] = JaggedTensor( + lengths=lengths_tuple[j], + values=embeddings_list[j], + weights=None, + ) + return ret + + +@torch.fx.wrap +def _fx_marker_construct_jagged_tensor( + values: torch.Tensor, + lengths: torch.Tensor, + weights: Optional[torch.Tensor], +) -> JaggedTensor: + return JaggedTensor(values=values, lengths=lengths, weights=weights) + + +def _construct_jagged_tensors_rw( + embeddings: List[torch.Tensor], + feature_keys: List[str], + feature_length: torch.Tensor, + feature_indices: Optional[torch.Tensor], + need_indices: bool, + unbucketize_tensor: torch.Tensor, +) -> Dict[str, JaggedTensor]: + ret: Dict[str, JaggedTensor] = {} + unbucketized_embs = torch.concat(embeddings, dim=0).index_select( + 0, unbucketize_tensor + ) + feature_length_2d = feature_length.view(len(feature_keys), -1) + length_per_key: List[int] = _fx_trec_wrap_length_tolist( + torch.sum(feature_length_2d, dim=1) + ) + embs_split_per_key = unbucketized_embs.split(length_per_key, dim=0) + lengths_list = torch.unbind(feature_length_2d, dim=0) + values_list: List[torch.Tensor] = [] + if need_indices: + # pyre-ignore + values_list = torch.split( + _fx_trec_unwrap_optional_tensor(feature_indices), + length_per_key, + ) + for i, key in enumerate(feature_keys): + ret[key] = _fx_marker_construct_jagged_tensor( + values=embs_split_per_key[i], + lengths=lengths_list[i], weights=values_list[i] if need_indices else None, ) return ret -class EmbeddingCollectionAwaitable(LazyAwaitable[Dict[str, JaggedTensor]]): - def __init__( - self, - awaitables_per_sharding: List[Awaitable[List[torch.Tensor]]], - features_per_sharding: List[List[KeyedJaggedTensor]], - need_indices: bool = False, - ) -> None: - super().__init__() - self._awaitables_per_sharding: List[ - Awaitable[List[torch.Tensor]] - ] = awaitables_per_sharding - self._features_per_sharding: List[ - List[KeyedJaggedTensor] - ] = features_per_sharding - self._need_indices = need_indices - - def _wait_impl(self) -> Dict[str, JaggedTensor]: - jt_dict: Dict[str, JaggedTensor] = {} - for w_sharding, f_sharding in zip( - self._awaitables_per_sharding, - self._features_per_sharding, - ): - emb_sharding = w_sharding.wait() - for emb, f in zip(emb_sharding, f_sharding): - jt_dict.update( - _construct_jagged_tensors( - embeddings=emb, - features=f, - need_indices=self._need_indices, +@torch.fx.wrap +def _construct_jagged_tensors_cw( + embeddings: List[torch.Tensor], + features: KJTList, + embedding_names_per_rank: List[List[str]], + need_indices: bool, + features_to_permute_indices: Dict[str, torch.Tensor], + key_to_feature_permuted_coordinates: Dict[str, torch.Tensor], +) -> Dict[str, JaggedTensor]: + ret: Dict[str, JaggedTensor] = {} + stride = features[0].stride() + lengths_lists: List[List[torch.Tensor]] = [] + embeddings_lists: 
List[List[torch.Tensor]] = [] + values_lists: List[List[torch.Tensor]] = [] + for i in range(len(features)): + embedding = embeddings[i] + feature = features[i] + # pyre-fixme[6]: For 1st argument expected `List[Tensor]` but got + # `Tuple[Tensor, ...]`. + lengths_lists.append(torch.unbind(feature.lengths().view(-1, stride), dim=0)) + embeddings_lists.append( + list(torch.split(embedding, feature.length_per_key(), dim=0)) + ) + if need_indices: + for i in range(len(features)): + feature = features[i] + values_lists.append( + list(torch.split(feature.values(), feature.length_per_key())) + ) + + for key, permuted_coordinate_tensor in key_to_feature_permuted_coordinates.items(): + permuted_coordinates: List[List[int]] = permuted_coordinate_tensor.tolist() + + rank0, idx_in_rank0 = permuted_coordinates[0] + ret[key] = JaggedTensor( + lengths=lengths_lists[rank0][idx_in_rank0], + values=torch.cat( + [ + embeddings_lists[rank][idx_in_rank] + for rank, idx_in_rank in permuted_coordinates + ], + dim=1, + ), + weights=values_lists[rank0][idx_in_rank0] if need_indices else None, + ) + return ret + + +@torch.fx.wrap +def input_dist_permute( + features: KeyedJaggedTensor, + features_order: List[int], + features_order_tensor: torch.Tensor, +) -> KeyedJaggedTensor: + return features.permute( + features_order, + features_order_tensor, + ) + + +def _construct_jagged_tensors( + sharding_type: str, + embeddings: List[torch.Tensor], + features: KJTList, + embedding_names: List[str], + embedding_names_per_rank: List[List[str]], + features_before_input_dist: KeyedJaggedTensor, + need_indices: bool, + rw_unbucketize_tensor: Optional[torch.Tensor], + rw_bucket_mapping_tensor: Optional[torch.Tensor], + rw_feature_length_after_bucketize: Optional[torch.Tensor], + cw_features_to_permute_indices: Dict[str, torch.Tensor], + key_to_feature_permuted_coordinates: Dict[str, torch.Tensor], + device_type: Union[str, Tuple[str, ...]], +) -> Dict[str, JaggedTensor]: + + # Validating sharding type and parameters + valid_sharding_types = [ + ShardingType.ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.TABLE_WISE.value, + ] + if sharding_type not in valid_sharding_types: + raise ValueError(f"Unknown sharding type {sharding_type}") + + if sharding_type == ShardingType.ROW_WISE.value and rw_unbucketize_tensor is None: + raise ValueError("rw_unbucketize_tensor is required for row-wise sharding") + + if sharding_type == ShardingType.ROW_WISE.value: + features_before_input_dist_length = _fx_trec_get_feature_length( + features_before_input_dist, embedding_names + ) + input_embeddings = [] + for i in range(len(embedding_names_per_rank)): + if isinstance(device_type, tuple) and device_type[i] != "cpu": + # batching hint is already propagated and passed for this case + # upstream + input_embeddings.append(embeddings[i]) + else: + input_embeddings.append( + _get_batching_hinted_output( + _fx_trec_get_feature_length( + features[i], embedding_names_per_rank[i] + ), + embeddings[i], ) ) - return jt_dict + + return _construct_jagged_tensors_rw( + input_embeddings, + embedding_names, + features_before_input_dist_length, + features_before_input_dist.values() if need_indices else None, + need_indices, + _get_unbucketize_tensor_via_length_alignment( + features_before_input_dist_length, + rw_feature_length_after_bucketize, + rw_unbucketize_tensor, + rw_bucket_mapping_tensor, + ), + ) + + elif sharding_type == ShardingType.COLUMN_WISE.value: + return _construct_jagged_tensors_cw( + embeddings, + features, + 
embedding_names_per_rank, + need_indices, + cw_features_to_permute_indices, + key_to_feature_permuted_coordinates, + ) + else: # sharding_type == ShardingType.TABLE_WISE.value + return _construct_jagged_tensors_tw( + embeddings, embedding_names_per_rank, features, need_indices + ) + + +# Wrap the annotation in a separate function with input parameter so that it won't be dropped during symbolic trace. +# Please note the input parameter is necessary, though is not used, otherwise this function will be optimized. +@torch.fx.has_side_effect +@torch.fx.wrap +def annotate_embedding_names( + embedding_names: List[str], + dummy: List[List[torch.Tensor]], +) -> List[str]: + return torch.jit.annotate(List[str], embedding_names) + + +@torch.fx.wrap +def format_embedding_names_per_rank_per_sharding( + embedding_names_per_rank_per_sharding: List[List[List[str]]], + dummy: List[List[torch.Tensor]], +) -> List[List[List[str]]]: + annotated_embedding_names_per_rank_per_sharding: List[List[List[str]]] = [] + for embedding_names_per_rank in embedding_names_per_rank_per_sharding: + annotated_embedding_names_per_rank: List[List[str]] = [] + for embedding_names in embedding_names_per_rank: + annotated_embedding_names_per_rank.append( + annotate_embedding_names(embedding_names, dummy) + ) + annotated_embedding_names_per_rank_per_sharding.append( + annotated_embedding_names_per_rank + ) + return annotated_embedding_names_per_rank_per_sharding class ShardedQuantEmbeddingCollection( - ShardedModule[ - ListOfSparseFeaturesList, + ShardedQuantEmbeddingModuleState[ + ListOfKJTList, List[List[torch.Tensor]], Dict[str, JaggedTensor], EmbeddingCollectionContext, @@ -156,35 +525,93 @@ def __init__( self, module: QuantEmbeddingCollection, table_name_to_parameter_sharding: Dict[str, ParameterSharding], - env: ShardingEnv, + # TODO: Consolidate to use Dict[str, ShardingEnv] + env: Union[ + ShardingEnv, Dict[str, ShardingEnv] + ], # Support hybrid sharding for DI fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, ) -> None: super().__init__() self._embedding_configs: List[EmbeddingConfig] = module.embedding_configs() - sharding_type_to_sharding_infos = create_sharding_infos_by_sharding( + self._sharding_type_device_group_to_sharding_infos: Dict[ + Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo] + ] = create_sharding_infos_by_sharding_device_group( module, table_name_to_parameter_sharding, fused_params ) - self._sharding_type_to_sharding: Dict[ - str, + + self._sharding_type_device_group_to_sharding: Dict[ + Tuple[str, Union[str, Tuple[str, ...]]], EmbeddingSharding[ InferSequenceShardingContext, - SparseFeaturesList, + InputDistOutputs, List[torch.Tensor], List[torch.Tensor], ], ] = { - sharding_type: create_infer_embedding_sharding( - sharding_type, embedding_confings, env + (sharding_type, device_group): create_infer_embedding_sharding( + sharding_type, + embedding_configs, + ( + env + if not isinstance(env, Dict) + else env[ + get_device_for_first_shard_from_sharding_infos( + embedding_configs + ) + ] + ), + device if get_propogate_device() else None, ) - for sharding_type, embedding_confings in sharding_type_to_sharding_infos.items() + for ( + sharding_type, + device_group, + ), embedding_configs in self._sharding_type_device_group_to_sharding_infos.items() } + self._embedding_dim: int = module.embedding_dim() + self._local_embedding_dim: int = self._embedding_dim + self._all_embedding_names: Set[str] = set() + self._embedding_names_per_sharding: List[List[str]] = 
[] + self._embedding_names_per_rank_per_sharding: List[List[List[str]]] = [] + for sharding in self._sharding_type_device_group_to_sharding.values(): + self._embedding_names_per_sharding.append(sharding.embedding_names()) + self._all_embedding_names.update(sharding.embedding_names()) + self._embedding_names_per_rank_per_sharding.append( + sharding.embedding_names_per_rank() + ) + self._features_to_permute_indices: Dict[str, torch.Tensor] = {} + self._key_to_feature_permuted_coordinates_per_sharding: List[ + Dict[str, torch.Tensor] + ] = [{} for i in range(len(self._embedding_names_per_rank_per_sharding))] - self._input_dists: List[nn.Module] = [] + for ( + sharding_type, + device_group, + ) in self._sharding_type_device_group_to_sharding.keys(): + if sharding_type == ShardingType.COLUMN_WISE.value: + sharding = self._sharding_type_device_group_to_sharding[ + (sharding_type, device_group) + ] + # CW partition must be same for all CW sharded parameters + self._local_embedding_dim = cast( + ShardMetadata, sharding.embedding_shard_metadata()[0] + ).shard_sizes[1] + self._features_to_permute_indices = ( + self._generate_permute_indices_per_feature( + module.embedding_configs(), table_name_to_parameter_sharding + ) + ) + + self._generate_permute_coordinates_per_feature_per_sharding() + + self._device = device self._lookups: List[nn.Module] = [] - self._create_lookups(fused_params) - self._output_dists: List[nn.Module] = [] + self._create_lookups(fused_params, device) + + # Ensure output dist is set for post processing from an inference runtime (ie. setting device from runtime). + self._output_dists: torch.nn.ModuleList = torch.nn.ModuleList() self._feature_splits: List[int] = [] self._features_order: List[int] = [] @@ -195,140 +622,354 @@ def __init__( self._embedding_dim: int = module.embedding_dim() self._need_indices: bool = module.need_indices() - # This provides consistency between this class and the EmbeddingBagCollection's - # nn.Module API calls (state_dict, named_modules, etc) - # Currently, Sharded Quant EC only uses TW sharding, and returns non-sharded tensors as part of state dict - # TODO - revisit if we state_dict can be represented as sharded tensor - self.embeddings: nn.ModuleDict = nn.ModuleDict() - for table in self._embedding_configs: - self.embeddings[table.name] = torch.nn.Module() + self._fused_params = fused_params + + tbes: Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig] = ( + get_tbes_to_register_from_iterable(self._lookups) + ) + + self._tbes_configs: Dict[ + IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig + ] = tbes + + # Optional registration of TBEs for model post processing utilities + if is_fused_param_register_tbe(fused_params): + self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(tbes.keys()) + + quant_state_dict_split_scale_bias = ( + is_fused_param_quant_state_dict_split_scale_bias(fused_params) + ) + + if quant_state_dict_split_scale_bias: + self._initialize_torch_state( + tbes=tbes, + table_name_to_parameter_sharding=table_name_to_parameter_sharding, + tables_weights_prefix="embeddings", + ) + else: + assert not isinstance( + env, Dict + ), "CPU sharding currently only support RW sharding where split scale and bias is required" + + table_wise_sharded_only: bool = all( + sharding_type == ShardingType.TABLE_WISE.value + for ( + sharding_type, + _, + ) in self._sharding_type_device_group_to_sharding.keys() + ) + assert ( + table_wise_sharded_only + ), "ROW_WISE,COLUMN_WISE shardings can be used only in 
'quant_state_dict_split_scale_bias' mode, specify fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS]=True to __init__ argument" + + self.embeddings: nn.ModuleDict = nn.ModuleDict() + for table in self._embedding_configs: + self.embeddings[table.name] = torch.nn.Module() + + for _sharding_type, lookup in zip( + self._sharding_type_device_group_to_sharding.keys(), self._lookups + ): + lookup_state_dict = lookup.state_dict() + for key in lookup_state_dict: + if key.endswith(".weight"): + table_name = key[: -len(".weight")] + self.embeddings[table_name].register_buffer( + "weight", lookup_state_dict[key] + ) + + def tbes_configs( + self, + ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]: + return self._tbes_configs + + def sharding_type_device_group_to_sharding_infos( + self, + ) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]: + return self._sharding_type_device_group_to_sharding_infos - for _sharding_type, lookup in zip( - self._sharding_type_to_sharding.keys(), self._lookups + def embedding_configs(self) -> List[EmbeddingConfig]: + return self._embedding_configs + + def _generate_permute_indices_per_feature( + self, + embedding_configs: List[EmbeddingConfig], + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + ) -> Dict[str, torch.Tensor]: + ret: Dict[str, torch.Tensor] = {} + shared_feature: Dict[str, bool] = {} + for table in embedding_configs: + for feature_name in table.feature_names: + if feature_name not in shared_feature: + shared_feature[feature_name] = False + else: + shared_feature[feature_name] = True + + for table in embedding_configs: + sharding = table_name_to_parameter_sharding[table.name] + if sharding.sharding_type != ShardingType.COLUMN_WISE.value: + continue + ranks = cast(List[int], sharding.ranks) + rank_to_indices = defaultdict(deque) + for i, rank in enumerate(sorted(ranks)): + rank_to_indices[rank].append(i) + permute_indices = [rank_to_indices[rank].popleft() for rank in ranks] + tensor = torch.tensor(permute_indices, dtype=torch.int64) + for feature_name in table.feature_names: + if shared_feature[feature_name]: + ret[feature_name + "@" + table.name] = tensor + else: + ret[feature_name] = tensor + return ret + + def _generate_permute_coordinates_per_feature_per_sharding( + self, + ) -> None: + key_to_feature_permuted_coordinates_per_sharding: List[ + Dict[str, List[Tuple[int, int]]] + ] = [{} for i in range(len(self._embedding_names_per_rank_per_sharding))] + + for idx, embedding_names_per_rank in enumerate( + self._embedding_names_per_rank_per_sharding ): - lookup_state_dict = lookup.state_dict() - for key in lookup_state_dict: - if not key.endswith(".weight"): - continue - table_name = key[: -len(".weight")] - # Register as buffer because this is an inference model, and can potentially use uint8 types. 
- self.embeddings[table_name].register_buffer( - "weight", lookup_state_dict[key] + for rank, embedding_names in enumerate(embedding_names_per_rank): + for idx_in_rank, embedding_name in enumerate(embedding_names): + if ( + embedding_name + not in key_to_feature_permuted_coordinates_per_sharding[idx] + ): + key_to_feature_permuted_coordinates_per_sharding[idx][ + embedding_name + ] = torch.jit.annotate(List[Tuple[int, int]], []) + key_to_feature_permuted_coordinates_per_sharding[idx][ + embedding_name + ].append((rank, idx_in_rank)) + + for ( + key, + coordinates, + ) in key_to_feature_permuted_coordinates_per_sharding[idx].items(): + permuted_coordinates: List[Tuple[int, int]] = coordinates + + if key in self._features_to_permute_indices: + permuted_coordinates = [(-1, -1)] * len(coordinates) + permute_indices: List[int] = self._features_to_permute_indices[ + key + ].tolist() + for i, permute_idx in enumerate(permute_indices): + permuted_coordinates[i] = coordinates[permute_idx] + self._key_to_feature_permuted_coordinates_per_sharding[idx][key] = ( + torch.tensor(permuted_coordinates) ) - def _create_input_dist( + def _create_lookups( self, - input_feature_names: List[str], - device: torch.device, + fused_params: Optional[Dict[str, Any]], + device: Optional[torch.device] = None, ) -> None: - feature_names: List[str] = [] - self._feature_splits: List[int] = [] - for sharding in self._sharding_type_to_sharding.values(): - self._input_dists.append(sharding.create_input_dist()) - feature_names.extend(sharding.id_list_feature_names()) - self._feature_splits.append(len(sharding.id_list_feature_names())) - self._features_order: List[int] = [] - for f in feature_names: - self._features_order.append(input_feature_names.index(f)) - self._features_order = ( - [] - if self._features_order == list(range(len(self._features_order))) - else self._features_order - ) - self.register_buffer( - "_features_order_tensor", - torch.tensor(self._features_order, device=device, dtype=torch.int32), - ) - - def _create_lookups(self, fused_params: Optional[Dict[str, Any]]) -> None: - for sharding in self._sharding_type_to_sharding.values(): - self._lookups.append(sharding.create_lookup(fused_params=fused_params)) + for sharding in self._sharding_type_device_group_to_sharding.values(): + self._lookups.append( + sharding.create_lookup(fused_params=fused_params, device=device) + ) def _create_output_dist( self, device: Optional[torch.device] = None, ) -> None: - for sharding in self._sharding_type_to_sharding.values(): + for sharding in self._sharding_type_device_group_to_sharding.values(): self._output_dists.append(sharding.create_output_dist(device)) - # pyre-ignore [3, 14] + # pyre-ignore [14] + # pyre-ignore def input_dist( self, ctx: EmbeddingCollectionContext, features: KeyedJaggedTensor, - ) -> Awaitable[Any]: + ) -> ListOfKJTList: if self._has_uninitialized_input_dist: - self._create_input_dist( + # pyre-fixme[16]: `ShardedQuantEmbeddingCollection` has no attribute + # `_input_dist`. 
+ self._input_dist = ShardedQuantEcInputDist( input_feature_names=features.keys() if features is not None else [], - device=features.device(), + sharding_type_device_group_to_sharding=self._sharding_type_device_group_to_sharding, + device=self._device, + feature_device=features.device(), ) self._has_uninitialized_input_dist = False if self._has_uninitialized_output_dist: self._create_output_dist(features.device()) self._has_uninitialized_output_dist = False + + ( + input_dist_result_list, + features_by_sharding, + unbucketize_permute_tensor_list, + bucket_mapping_tensor_list, + bucketized_length_list, + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + ) = self._input_dist(features) + with torch.no_grad(): - features_by_sharding = [] - if self._features_order: - features = features.permute( - self._features_order, - # pyre-ignore [6] - self._features_order_tensor, - ) - features_by_sharding = features.split( - self._feature_splits, - ) - # save input splits and output splits in sharding context which - # will be reused in sequence embedding all2all - awaitables = [] - for module, features in zip(self._input_dists, features_by_sharding): - tensor_awaitable = module( - SparseFeatures( - id_list_features=features, - id_score_list_features=None, + for i in range(len(self._sharding_type_device_group_to_sharding)): + + ctx.sharding_contexts.append( + InferSequenceShardingContext( + features=input_dist_result_list[i], + features_before_input_dist=features_by_sharding[i], + unbucketize_permute_tensor=unbucketize_permute_tensor_list[i], + bucket_mapping_tensor=bucket_mapping_tensor_list[i], + bucketized_length=bucketized_length_list[i], + embedding_names_per_rank=self._embedding_names_per_rank_per_sharding[ + i + ], ) - ).wait() # a dummy wait since now length indices comm is splited - awaitables.append(tensor_awaitable) - return ListOfSparseFeaturesListAwaitable(awaitables) + ) + return input_dist_result_list + + def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int: + return ( + self._local_embedding_dim + if sharding_type == ShardingType.COLUMN_WISE.value + else self._embedding_dim + ) def compute( - self, ctx: EmbeddingCollectionContext, dist_input: ListOfSparseFeaturesList + self, ctx: EmbeddingCollectionContext, dist_input: ListOfKJTList ) -> List[List[torch.Tensor]]: ret: List[List[torch.Tensor]] = [] - for lookup, features in zip( - self._lookups, - dist_input, - ): - ctx.sharding_contexts.append( - InferSequenceShardingContext( - features=[feature.id_list_features for feature in features], - ) - ) - ret.append([o.view(-1, self._embedding_dim) for o in lookup(features)]) + + # for lookup, features in zip(self._lookups, dist_input): + for i in range(len(self._lookups)): + lookup = self._lookups[i] + features = dist_input[i] + ret.append(lookup.forward(features)) return ret + # pyre-ignore def output_dist( self, ctx: EmbeddingCollectionContext, output: List[List[torch.Tensor]] - ) -> LazyAwaitable[Dict[str, JaggedTensor]]: - awaitables_per_sharding: List[Awaitable[List[torch.Tensor]]] = [] - features_per_sharding: List[List[KeyedJaggedTensor]] = [] - for odist, embeddings, sharding_ctx in zip( + ) -> Dict[str, JaggedTensor]: + emb_per_sharding: List[List[torch.Tensor]] = [] + features_before_input_dist_per_sharding: List[KeyedJaggedTensor] = [] + features_per_sharding: List[KJTList] = [] + unbucketize_tensors: List[Optional[torch.Tensor]] = [] + bucket_mapping_tensors: List[Optional[torch.Tensor]] = [] + bucketized_lengths: List[Optional[torch.Tensor]] = [] + for 
sharding_output_dist, embeddings, sharding_ctx in zip( self._output_dists, output, ctx.sharding_contexts, ): - awaitables_per_sharding.append(odist(embeddings, sharding_ctx)) + sharding_output_dist_res: List[torch.Tensor] = sharding_output_dist.forward( + embeddings, sharding_ctx + ) + emb_per_sharding.append(sharding_output_dist_res) features_per_sharding.append(sharding_ctx.features) - return EmbeddingCollectionAwaitable( - awaitables_per_sharding=awaitables_per_sharding, + unbucketize_tensors.append( + sharding_ctx.unbucketize_permute_tensor + if sharding_ctx.unbucketize_permute_tensor is not None + else None + ) + bucket_mapping_tensors.append( + sharding_ctx.bucket_mapping_tensor + if sharding_ctx.bucket_mapping_tensor is not None + else None + ) + bucketized_lengths.append( + sharding_ctx.bucketized_length + if sharding_ctx.bucketized_length is not None + else None + ) + features_before_input_dist_per_sharding.append( + # pyre-ignore + sharding_ctx.features_before_input_dist + ) + return self.output_jt_dict( + emb_per_sharding=emb_per_sharding, features_per_sharding=features_per_sharding, - need_indices=self._need_indices, + features_before_input_dist_per_sharding=features_before_input_dist_per_sharding, + unbucketize_tensors=unbucketize_tensors, + bucket_mapping_tensors=bucket_mapping_tensors, + bucketized_lengths=bucketized_lengths, ) + def output_jt_dict( + self, + emb_per_sharding: List[List[torch.Tensor]], + features_per_sharding: List[KJTList], + features_before_input_dist_per_sharding: List[KeyedJaggedTensor], + unbucketize_tensors: List[Optional[torch.Tensor]], + bucket_mapping_tensors: List[Optional[torch.Tensor]], + bucketized_lengths: List[Optional[torch.Tensor]], + ) -> Dict[str, JaggedTensor]: + jt_dict_res: Dict[str, JaggedTensor] = {} + for ( + (sharding_type, device_type), + emb_sharding, + features_sharding, + embedding_names, + embedding_names_per_rank, + features_before_input_dist, + unbucketize_tensor, + bucket_mapping_tensor, + bucketized_length, + key_to_feature_permuted_coordinates, + ) in zip( + self._sharding_type_device_group_to_sharding.keys(), + emb_per_sharding, + features_per_sharding, + self._embedding_names_per_sharding, + self._embedding_names_per_rank_per_sharding, + features_before_input_dist_per_sharding, + unbucketize_tensors, + bucket_mapping_tensors, + bucketized_lengths, + self._key_to_feature_permuted_coordinates_per_sharding, + ): + jt_dict = _construct_jagged_tensors( + sharding_type=sharding_type, + embeddings=emb_sharding, + features=features_sharding, + embedding_names=embedding_names, + embedding_names_per_rank=embedding_names_per_rank, + features_before_input_dist=features_before_input_dist, + need_indices=self._need_indices, + rw_unbucketize_tensor=( + # this is batching hint for constructing alignment sparse features for batching + _fx_trec_unwrap_optional_tensor(unbucketize_tensor) + if sharding_type == ShardingType.ROW_WISE.value + else None + ), + rw_bucket_mapping_tensor=( + _fx_trec_unwrap_optional_tensor(bucket_mapping_tensor) + if sharding_type == ShardingType.ROW_WISE.value + else None + ), + rw_feature_length_after_bucketize=( + _fx_trec_unwrap_optional_tensor(bucketized_length) + if sharding_type == ShardingType.ROW_WISE.value + else None + ), + cw_features_to_permute_indices=self._features_to_permute_indices, + key_to_feature_permuted_coordinates=key_to_feature_permuted_coordinates, + device_type=device_type, + ) + for embedding_name in embedding_names: + jt_dict_res[embedding_name] = jt_dict[embedding_name] + + return 
jt_dict_res + + # pyre-ignore def compute_and_output_dist( - self, ctx: EmbeddingCollectionContext, input: ListOfSparseFeaturesList - ) -> LazyAwaitable[Dict[str, JaggedTensor]]: + self, ctx: EmbeddingCollectionContext, input: ListOfKJTList + ) -> Dict[str, JaggedTensor]: return self.output_dist(ctx, self.compute(ctx, input)) + # pyre-ignore + def forward(self, *input, **kwargs) -> Dict[str, JaggedTensor]: + ctx = self.create_context() + dist_input = self.input_dist(ctx, *input, **kwargs) + return self.compute_and_output_dist(ctx, dist_input) + def copy(self, device: torch.device) -> nn.Module: if self._has_uninitialized_output_dist: self._create_output_dist(device) @@ -339,9 +980,11 @@ def create_context(self) -> EmbeddingCollectionContext: return EmbeddingCollectionContext(sharding_contexts=[]) @property - def shardings(self) -> Dict[str, FeatureShardingMixIn]: + def shardings( + self, + ) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], FeatureShardingMixIn]: # pyre-ignore [7] - return self._sharding_type_to_sharding + return self._sharding_type_device_group_to_sharding class QuantEmbeddingCollectionSharder( @@ -355,15 +998,403 @@ def shard( self, module: QuantEmbeddingCollection, params: Dict[str, ParameterSharding], - env: ShardingEnv, + env: Union[ShardingEnv, Dict[str, ShardingEnv]], device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, ) -> ShardedQuantEmbeddingCollection: fused_params = self.fused_params if self.fused_params else {} fused_params["output_dtype"] = data_type_to_sparse_type( dtype_to_data_type(module.output_dtype()) ) - return ShardedQuantEmbeddingCollection(module, params, env, fused_params) + if FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS not in fused_params: + fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] = getattr( + module, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False + ) + if FUSED_PARAM_REGISTER_TBE_BOOL not in fused_params: + fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr( + module, FUSED_PARAM_REGISTER_TBE_BOOL, False + ) + return ShardedQuantEmbeddingCollection( + module=module, + table_name_to_parameter_sharding=params, + env=env, + fused_params=fused_params, + device=device, + ) @property def module_type(self) -> Type[QuantEmbeddingCollection]: return QuantEmbeddingCollection + + +class ShardedQuantEcInputDist(torch.nn.Module): + """ + This module implements distributed inputs of a ShardedQuantEmbeddingCollection. + + Args: + input_feature_names (List[str]): EmbeddingCollection feature names. + sharding_type_to_sharding (Dict[ + str, + EmbeddingSharding[ + InferSequenceShardingContext, + KJTList, + List[torch.Tensor], + List[torch.Tensor], + ], + ]): map from sharding type to EmbeddingSharding. + device (Optional[torch.device]): default compute device. + feature_device (Optional[torch.device]): runtime feature device. 
+ + Example:: + + sqec_input_dist = ShardedQuantEcInputDist( + sharding_type_to_sharding={ + ShardingType.TABLE_WISE: InferTwSequenceEmbeddingSharding( + [], + ShardingEnv( + world_size=2, + rank=0, + pg=0, + ), + torch.device("cpu") + ) + }, + device=torch.device("cpu"), + ) + + features = KeyedJaggedTensor( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + sqec_input_dist(features) + """ + + def __init__( + self, + input_feature_names: List[str], + sharding_type_device_group_to_sharding: Dict[ + Tuple[str, Union[str, Tuple[str, ...]]], + EmbeddingSharding[ + InferSequenceShardingContext, + InputDistOutputs, + List[torch.Tensor], + List[torch.Tensor], + ], + ], + device: Optional[torch.device] = None, + feature_device: Optional[torch.device] = None, + ) -> None: + super().__init__() + self._sharding_type_device_group_to_sharding = ( + sharding_type_device_group_to_sharding + ) + self._input_dists = torch.nn.ModuleList([]) + self._feature_splits: List[int] = [] + self._features_order: List[int] = [] + + feature_names: List[str] = [] + for sharding in sharding_type_device_group_to_sharding.values(): + self._input_dists.append(sharding.create_input_dist(device=device)) + feature_names.extend(sharding.feature_names()) + self._feature_splits.append(len(sharding.feature_names())) + for f in feature_names: + self._features_order.append(input_feature_names.index(f)) + + self._features_order = ( + [] + if self._features_order == list(range(len(self._features_order))) + else self._features_order + ) + self.register_buffer( + "_features_order_tensor", + torch.tensor( + self._features_order, device=feature_device, dtype=torch.int32 + ), + persistent=False, + ) + + def forward(self, features: KeyedJaggedTensor) -> Tuple[ + List[KJTList], + List[KeyedJaggedTensor], + List[Optional[torch.Tensor]], + List[Optional[torch.Tensor]], + List[Optional[torch.Tensor]], + ]: + + with torch.no_grad(): + ret: List[KJTList] = [] + unbucketize_permute_tensor = [] + bucket_mapping_tensor = [] + bucketized_lengths = [] + if self._features_order: + features = input_dist_permute( + features, + self._features_order, + self._features_order_tensor, + ) + features_by_sharding = ( + [features] + if len(self._feature_splits) == 1 + else features.split(self._feature_splits) + ) + + for i in range(len(self._input_dists)): + input_dist = self._input_dists[i] + input_dist_result = input_dist(features_by_sharding[i]) + + ret.append(input_dist_result.features) + + unbucketize_permute_tensor.append( + input_dist_result.unbucketize_permute_tensor + ) + bucket_mapping_tensor.append(input_dist_result.bucket_mapping_tensor) + bucketized_lengths.append(input_dist_result.bucketized_length) + + return ( + ret, + features_by_sharding, + unbucketize_permute_tensor, + bucket_mapping_tensor, + bucketized_lengths, + ) + + +class ShardedMCECLookup(torch.nn.Module): + """ + This module implements distributed compute of a ShardedQuantManagedCollisionEmbeddingCollection. 
+ + Args: + managed_collision_collection (ShardedQuantManagedCollisionCollection): managed collision collection + lookups (List[nn.Module]): embedding lookups + + Example:: + + """ + + def __init__( + self, + sharding: int, + rank: int, + mcc_remapper: ShardedMCCRemapper, + ec_lookup: nn.Module, + ) -> None: + super().__init__() + self._sharding = sharding + self._rank = rank + self._mcc_remapper = mcc_remapper + self._ec_lookup = ec_lookup + + def forward( + self, + features: KeyedJaggedTensor, + ) -> Tuple[ + Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor] + ]: + remapped_kjt = self._mcc_remapper(features) + return self._ec_lookup(remapped_kjt) + + +class ShardedQuantManagedCollisionEmbeddingCollection(ShardedQuantEmbeddingCollection): + def __init__( + self, + module: QuantManagedCollisionEmbeddingCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + mc_sharder: InferManagedCollisionCollectionSharder, + # TODO - maybe we need this to manage unsharded/sharded consistency/state consistency + env: Union[ShardingEnv, Dict[str, ShardingEnv]], + fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__( + module, table_name_to_parameter_sharding, env, fused_params, device + ) + + self._device = device + self._env = env + + # TODO: This is a hack since _embedding_module doesn't need input + # dist, so eliminating it so all fused a2a will ignore it. + # we're using ec input_dist directly, so this cannot be escaped. + # self._has_uninitialized_input_dist = False + embedding_shardings = list( + self._sharding_type_device_group_to_sharding.values() + ) + + self._managed_collision_collection: ShardedQuantManagedCollisionCollection = ( + mc_sharder.shard( + module._managed_collision_collection, + table_name_to_parameter_sharding, + env=env, + device=device, + # pyre-ignore + embedding_shardings=embedding_shardings, + ) + ) + self._return_remapped_features: bool = module._return_remapped_features + self._create_mcec_lookups() + + def _create_mcec_lookups(self) -> None: + mcec_lookups: List[nn.ModuleList] = [] + mcc_remappers: List[List[ShardedMCCRemapper]] = ( + self._managed_collision_collection.create_mcc_remappers() + ) + for sharding in range( + len(self._managed_collision_collection._embedding_shardings) + ): + ec_sharding_lookups = self._lookups[sharding] + sharding_mcec_lookups: List[ShardedMCECLookup] = [] + for j, ec_lookup in enumerate( + ec_sharding_lookups._embedding_lookups_per_rank # pyre-ignore + ): + sharding_mcec_lookups.append( + ShardedMCECLookup( + sharding, + j, + mcc_remappers[sharding][j], + ec_lookup, + ) + ) + mcec_lookups.append(nn.ModuleList(sharding_mcec_lookups)) + self._mcec_lookup: nn.ModuleList = nn.ModuleList(mcec_lookups) + + # For consistency with ShardedManagedCollisionEmbeddingCollection + @property + def _embedding_collection(self) -> ShardedQuantEmbeddingCollection: + return cast(ShardedQuantEmbeddingCollection, self) + + def input_dist( + self, + ctx: EmbeddingCollectionContext, + features: KeyedJaggedTensor, + ) -> ListOfKJTList: + # TODO: resolve incompatiblity with different contexts + if self._has_uninitialized_output_dist: + self._create_output_dist(features.device()) + self._has_uninitialized_output_dist = False + + return self._managed_collision_collection.input_dist( + # pyre-fixme [6] + ctx, + features, + ) + + def compute( + self, + ctx: ShrdCtx, + dist_input: ListOfKJTList, + ) -> List[List[torch.Tensor]]: + ret: List[List[torch.Tensor]] = 
[] + for i in range(len(self._managed_collision_collection._embedding_shardings)): + dist_input_i = dist_input[i] + lookups = self._mcec_lookup[i] + sharding_ret: List[torch.Tensor] = [] + for j, lookup in enumerate(lookups): + rank_ret = lookup( + features=dist_input_i[j], + ) + sharding_ret.append(rank_ret) + ret.append(sharding_ret) + return ret + + # pyre-ignore + def output_dist( + self, + ctx: ShrdCtx, + output: List[List[torch.Tensor]], + ) -> Tuple[ + Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor] + ]: + + # pyre-ignore [6] + ebc_out = super().output_dist(ctx, output) + + kjt_out: Optional[KeyedJaggedTensor] = None + + return ebc_out, kjt_out + + def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: + for fqn, _ in self.named_parameters(): + yield append_prefix(prefix, fqn) + for fqn, _ in self.named_buffers(): + yield append_prefix(prefix, fqn) + + +class QuantManagedCollisionEmbeddingCollectionSharder( + BaseQuantEmbeddingSharder[QuantManagedCollisionEmbeddingCollection] +): + """ + This implementation uses non-fused EmbeddingCollection + """ + + def __init__( + self, + e_sharder: QuantEmbeddingCollectionSharder, + mc_sharder: InferManagedCollisionCollectionSharder, + ) -> None: + super().__init__() + self._e_sharder: QuantEmbeddingCollectionSharder = e_sharder + self._mc_sharder: InferManagedCollisionCollectionSharder = mc_sharder + + def shardable_parameters( + self, module: QuantManagedCollisionEmbeddingCollection + ) -> Dict[str, torch.nn.Parameter]: + return self._e_sharder.shardable_parameters(module) + + def compute_kernels( + self, + sharding_type: str, + compute_device_type: str, + ) -> List[str]: + return [ + EmbeddingComputeKernel.QUANT.value, + ] + + def sharding_types(self, compute_device_type: str) -> List[str]: + return list( + set.intersection( + set(self._e_sharder.sharding_types(compute_device_type)), + set(self._mc_sharder.sharding_types(compute_device_type)), + ) + ) + + @property + def fused_params(self) -> Optional[Dict[str, Any]]: + # TODO: to be deprecate after planner get cache_load_factor from ParameterConstraints + return self._e_sharder.fused_params + + def shard( + self, + module: QuantManagedCollisionEmbeddingCollection, + params: Dict[str, ParameterSharding], + env: Union[ShardingEnv, Dict[str, ShardingEnv]], + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> ShardedQuantManagedCollisionEmbeddingCollection: + fused_params = self.fused_params if self.fused_params else {} + fused_params["output_dtype"] = data_type_to_sparse_type( + dtype_to_data_type(module.output_dtype()) + ) + if FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS not in fused_params: + fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] = getattr( + module, + MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, + False, + ) + if FUSED_PARAM_REGISTER_TBE_BOOL not in fused_params: + fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr( + module, FUSED_PARAM_REGISTER_TBE_BOOL, False + ) + return ShardedQuantManagedCollisionEmbeddingCollection( + module, + params, + self._mc_sharder, + env, + fused_params, + device, + ) + + @property + def module_type(self) -> Type[QuantManagedCollisionEmbeddingCollection]: + return QuantManagedCollisionEmbeddingCollection diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py index 17bbea4be..cc324d52a 100644 --- a/torchrec/distributed/quant_embedding_kernel.py +++ b/torchrec/distributed/quant_embedding_kernel.py @@ -5,6 +5,8 
@@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import copy import logging from typing import Any, Dict, Iterator, List, Optional, Tuple @@ -12,7 +14,7 @@ import torch import torch.distributed as dist from fbgemm_gpu.split_embedding_configs import SparseType -from fbgemm_gpu.split_table_batched_embeddings_ops import ( +from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( EmbeddingLocation, IntNBitTableBatchedEmbeddingBagsCodegen, PoolingMode, @@ -29,6 +31,15 @@ compute_kernel_to_embedding_location, GroupedEmbeddingConfig, ) +from torchrec.distributed.fused_params import ( + fused_param_bounds_check_mode, + fused_param_lengths_to_offsets_lookup, + is_fused_param_quant_state_dict_split_scale_bias, + is_fused_param_register_tbe, + tbe_fused_params, + TBEToRegisterMixIn, +) +from torchrec.distributed.types import BoundsCheckMode from torchrec.distributed.utils import append_prefix from torchrec.modules.embedding_configs import ( DATA_TYPE_NUM_BITS, @@ -108,13 +119,120 @@ def _quantize_weight( return quant_weight_list -class QuantBatchedEmbeddingBag(BaseBatchedEmbeddingBag): +def _get_runtime_device( + device: Optional[torch.device], + config: GroupedEmbeddingConfig, + shard_index: Optional[int] = None, +) -> torch.device: + index: int = 0 if shard_index is None else shard_index + if device is not None and device.type != "meta": + return device + else: + return ( + torch.device("cpu") + if all( + ( + table.local_metadata is not None + and table.local_metadata.placement is not None + and table.local_metadata.placement.device().type == "cpu" + ) + or ( + table.global_metadata is not None + and len(table.global_metadata.shards_metadata) + and table.global_metadata.shards_metadata[index].placement + is not None + # pyre-ignore: Undefined attribute [16] + and table.global_metadata.shards_metadata[index] + .placement.device() + .type + == "cpu" + ) + for table in config.embedding_tables + ) + else torch.device("cuda") + ) + + +@torch.fx.wrap +def _unwrap_kjt( + features: KeyedJaggedTensor, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Here it should always follow cuda path, runtime device cannot be meta + indices = features.values() + offsets = features.offsets() + return ( + indices.int(), # currently only support int32 indices + offsets.int(), + features.weights_or_none(), + ) + + +def _unwrap_kjt_for_cpu( + features: KeyedJaggedTensor, weighted: bool +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + indices = features.values() + offsets = features.offsets().type(indices.dtype) + if weighted: + return indices, offsets, features.weights() + else: + return indices, offsets, None + + +@torch.fx.wrap +def _unwrap_kjt_lengths( + features: KeyedJaggedTensor, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + indices = features.values() + lengths = features.lengths() + return ( + indices.int(), + lengths.int(), + features.weights_or_none(), + ) + + +@torch.fx.wrap +def _unwrap_optional_tensor( + tensor: Optional[torch.Tensor], +) -> torch.Tensor: + # Typing for TorchScript + assert tensor is not None + return tensor + + +class IntNBitTableBatchedEmbeddingBagsCodegenWithLength( + IntNBitTableBatchedEmbeddingBagsCodegen +): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + # pyre-ignore Inconsistent override [14] + def forward( + self, + indices: torch.Tensor, + lengths: torch.Tensor, + 
per_sample_weights: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self._forward_impl( + indices=indices, + offsets=(torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)), + per_sample_weights=per_sample_weights, + ) + + +class QuantBatchedEmbeddingBag( + BaseBatchedEmbeddingBag[ + Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] + ], + TBEToRegisterMixIn, +): def __init__( self, config: GroupedEmbeddingConfig, pg: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, + shard_index: Optional[int] = None, ) -> None: super().__init__(config, pg, device) @@ -126,28 +244,65 @@ def __init__( ) else: managed.append(EmbeddingLocation.HOST) - self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = ( - IntNBitTableBatchedEmbeddingBagsCodegen( - embedding_specs=[ + self._config: GroupedEmbeddingConfig = config + self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params) + self._is_weighted: Optional[bool] = config.is_weighted + self._quant_state_dict_split_scale_bias: bool = ( + is_fused_param_quant_state_dict_split_scale_bias(fused_params) + ) + bounds_check_mode: Optional[BoundsCheckMode] = fused_param_bounds_check_mode( + fused_params + ) + + self._runtime_device: torch.device = _get_runtime_device( + device, config, shard_index + ) + # 16 for CUDA, 1 for others like CPU and MTIA. + self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1 + embedding_specs = [] + for local_rows, local_cols, table, location in zip( + self._local_rows, + self._local_cols, + config.embedding_tables, + managed, + ): + embedding_specs.append( + ( + table.name, + local_rows, ( - "", - local_rows, - table.embedding_dim, - data_type_to_sparse_type(config.data_type), - location, - ) - for local_rows, table, location in zip( - self._local_rows, config.embedding_tables, managed - ) - ], - device=device, - pooling_mode=self._pooling, - feature_table_map=self._feature_table_map, - row_alignment=16, - **(fused_params or {}), + local_cols + if self._quant_state_dict_split_scale_bias + else table.embedding_dim + ), + data_type_to_sparse_type(table.data_type), + location, + ) ) + + self.lengths_to_tbe: bool = fused_param_lengths_to_offsets_lookup(fused_params) + + if self.lengths_to_tbe: + tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength + else: + tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen + + self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz( + embedding_specs=embedding_specs, + device=device, + pooling_mode=self._pooling, + feature_table_map=self._feature_table_map, + row_alignment=self._tbe_row_alignment, + uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue + bounds_check_mode=( + bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING + ), + feature_names_per_table=[ + table.feature_names for table in config.embedding_tables + ], + **(tbe_fused_params(fused_params) or {}), ) - if device is not None and device.type != "meta": + if device is not None: self._emb_module.initialize_weights() def init_parameters(self) -> None: @@ -159,11 +314,54 @@ def emb_module( ) -> IntNBitTableBatchedEmbeddingBagsCodegen: return self._emb_module + def get_tbes_to_register( + self, + ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]: + return {self._emb_module: self._config} + + def _emb_module_forward( + self, + indices: torch.Tensor, + lengths_or_offsets: torch.Tensor, + 
weights: Optional[torch.Tensor], + ) -> torch.Tensor: + kwargs = {"indices": indices} + + if self.lengths_to_tbe: + kwargs["lengths"] = lengths_or_offsets + else: + kwargs["offsets"] = lengths_or_offsets + + if self._is_weighted: + kwargs["per_sample_weights"] = _unwrap_optional_tensor(weights) + + if self._emb_module_registered: + # Conditional call of .forward function for FX: + # emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module) + # emb_module.forward() does not require registering emb_module in named_modules (FX node call_function) + # For some post processing that requires TBE emb_module copied in fx.GraphModule we need to be call_module, as it will copies this module inside fx.GraphModule unchanged. + return self._emb_module(**kwargs) + else: + return self._emb_module.forward(**kwargs) + def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: - return self.emb_module( - indices=features.values().int(), - offsets=features.offsets().int(), - per_sample_weights=features.weights_or_none(), + # Important: _unwrap_kjt regex for FX tracing TAGing + lengths, offsets = None, None + if self._runtime_device.type == "cpu": + if self.lengths_to_tbe: + indices, lengths, per_sample_weights = _unwrap_kjt_lengths(features) + else: + indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu( + features, self._config.is_weighted + ) + else: + if self.lengths_to_tbe: + indices, lengths, per_sample_weights = _unwrap_kjt_lengths(features) + else: + indices, offsets, per_sample_weights = _unwrap_kjt(features) + + return self._emb_module_forward( + indices, lengths if lengths is not None else offsets, per_sample_weights ) def named_buffers( @@ -172,53 +370,76 @@ def named_buffers( assert ( remove_duplicate ), "remove_duplicate=False not supported in QuantBatchedEmbeddingBag.named_split_embedding_weights" - for config, weight in zip( + for config, (weight, weight_qscale, weight_qbias) in zip( self._config.embedding_tables, - self.emb_module.split_embedding_weights(), + self.emb_module.split_embedding_weights_with_scale_bias( + split_scale_bias_mode=( + 2 if self._quant_state_dict_split_scale_bias else 0 + ) + ), ): - yield append_prefix(prefix, f"{config.name}.weight"), weight[0] - - def split_embedding_weights(self) -> List[torch.Tensor]: + yield append_prefix(prefix, f"{config.name}.weight"), weight + if self._quant_state_dict_split_scale_bias: + yield append_prefix( + prefix, f"{config.name}.weight_qscale" + ), weight_qscale + yield append_prefix(prefix, f"{config.name}.weight_qbias"), weight_qbias + + def split_embedding_weights( + self, + ) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]]: return [ - weight - for weight, _ in self.emb_module.split_embedding_weights( - split_scale_shifts=False + (weight, qscale, qbias) + for weight, qscale, qbias in self.emb_module.split_embedding_weights_with_scale_bias( + split_scale_bias_mode=( + 2 if self._quant_state_dict_split_scale_bias else 0 + ) ) ] @classmethod - def from_float(cls, module: BaseEmbedding) -> "QuantBatchedEmbeddingBag": + def from_float( + cls, module: BaseEmbedding, use_precomputed_fake_quant: bool = False + ) -> "QuantBatchedEmbeddingBag": assert hasattr( module, "qconfig" ), "BaseEmbedding input float module must have qconfig defined" - # pyre-ignore [16] + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `weight`. 
data_type = dtype_to_data_type(module.qconfig.weight().dtype) sparse_type = data_type_to_sparse_type(data_type) + # TODO Can we simplify this with state_dict = module.state_dict()? state_dict = ( dict(module.named_split_embedding_weights()) if isinstance(module, BatchedDenseEmbeddingBag) - else dict(module.named_buffers()) + else dict(module.named_parameters()) ) device = next(iter(state_dict.values())).device config = _copy_config(module.config, data_type, sparse_type, device) ret = QuantBatchedEmbeddingBag(config=config, device=device) + # pyre-ignore quant_weight_list = _quantize_weight(state_dict, data_type) ret.emb_module.assign_embedding_weights(quant_weight_list) return ret -class QuantBatchedEmbedding(BaseBatchedEmbedding): +class QuantBatchedEmbedding( + BaseBatchedEmbedding[ + Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] + ], + TBEToRegisterMixIn, +): def __init__( self, config: GroupedEmbeddingConfig, pg: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, + shard_index: Optional[int] = None, ) -> None: super().__init__(config, pg, device) @@ -230,28 +451,49 @@ def __init__( ) else: managed.append(EmbeddingLocation.HOST) + self._config: GroupedEmbeddingConfig = config + self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params) + self._quant_state_dict_split_scale_bias: bool = ( + is_fused_param_quant_state_dict_split_scale_bias(fused_params) + ) + self._runtime_device: torch.device = _get_runtime_device( + device, config, shard_index + ) + # 16 for CUDA, 1 for others like CPU and MTIA. + self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1 self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = ( IntNBitTableBatchedEmbeddingBagsCodegen( embedding_specs=[ ( - "", + table.name, local_rows, - table.embedding_dim, - data_type_to_sparse_type(config.data_type), + ( + local_cols + if self._quant_state_dict_split_scale_bias + else table.embedding_dim + ), + data_type_to_sparse_type(table.data_type), location, ) - for local_rows, table, location in zip( - self._local_rows, config.embedding_tables, managed + for local_rows, local_cols, table, location in zip( + self._local_rows, + self._local_cols, + config.embedding_tables, + managed, ) ], device=device, pooling_mode=PoolingMode.NONE, feature_table_map=self._feature_table_map, - row_alignment=16, - **(fused_params or {}), + row_alignment=self._tbe_row_alignment, + uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue + feature_names_per_table=[ + table.feature_names for table in config.embedding_tables + ], + **(tbe_fused_params(fused_params) or {}), ) ) - if device is not None and device.type != "meta": + if device is not None: self._emb_module.initialize_weights() @property @@ -260,49 +502,85 @@ def emb_module( ) -> IntNBitTableBatchedEmbeddingBagsCodegen: return self._emb_module - def split_embedding_weights(self) -> List[torch.Tensor]: + def get_tbes_to_register( + self, + ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]: + return {self._emb_module: self._config} + + def split_embedding_weights( + self, + ) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]]: return [ - weight - for weight, _ in self.emb_module.split_embedding_weights( - split_scale_shifts=False + (weight, qscale, qbias) + for weight, qscale, qbias in self.emb_module.split_embedding_weights_with_scale_bias( + split_scale_bias_mode=( + 2 
if self._quant_state_dict_split_scale_bias else 0 + ) ) ] def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: - return self.emb_module( - indices=features.values().int(), - offsets=features.offsets().int(), - ) + if self._runtime_device.type == "cpu": + # To distinguish fx tracing on CPU embedding. + values, offsets, _ = _unwrap_kjt_for_cpu( + features, weighted=self._config.is_weighted + ) + else: + values, offsets, _ = _unwrap_kjt(features) + + if self._emb_module_registered: + return self.emb_module( + indices=values, + offsets=offsets, + ) + else: + return self.emb_module.forward( + indices=values, + offsets=offsets, + ) def named_buffers( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[Tuple[str, torch.Tensor]]: - for config, weight in zip( + for config, (weight, weight_qscale, weight_qbias) in zip( self._config.embedding_tables, - self.emb_module.split_embedding_weights(), + self.emb_module.split_embedding_weights_with_scale_bias( + split_scale_bias_mode=( + 2 if self._quant_state_dict_split_scale_bias else 0 + ) + ), ): - yield append_prefix(prefix, f"{config.name}.weight"), weight[0] + yield append_prefix(prefix, f"{config.name}.weight"), weight + if self._quant_state_dict_split_scale_bias: + yield append_prefix( + prefix, f"{config.name}.weight_qscale" + ), weight_qscale + yield append_prefix(prefix, f"{config.name}.weight_qbias"), weight_qbias @classmethod - def from_float(cls, module: BaseEmbedding) -> "QuantBatchedEmbedding": + def from_float( + cls, module: BaseEmbedding, use_precomputed_fake_quant: bool = False + ) -> "QuantBatchedEmbedding": assert hasattr( module, "qconfig" ), "BaseEmbedding input float module must have qconfig defined" - # pyre-ignore [16] + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `weight`. data_type = dtype_to_data_type(module.qconfig.weight().dtype) sparse_type = data_type_to_sparse_type(data_type) + # TODO Can we simplify this with state_dict = module.state_dict()? state_dict = ( dict(module.named_split_embedding_weights()) if isinstance(module, BatchedDenseEmbedding) - else dict(module.named_buffers()) + else dict(module.named_parameters()) ) device = next(iter(state_dict.values())).device config = _copy_config(module.config, data_type, sparse_type, device) ret = QuantBatchedEmbedding(config=config, device=device) + # pyre-ignore quant_weight_list = _quantize_weight(state_dict, data_type) ret.emb_module.assign_embedding_weights(quant_weight_list) diff --git a/torchrec/distributed/quant_embeddingbag.py b/torchrec/distributed/quant_embeddingbag.py index 6001c94b1..e666841b9 100644 --- a/torchrec/distributed/quant_embeddingbag.py +++ b/torchrec/distributed/quant_embeddingbag.py @@ -5,66 +5,160 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Any, Dict, List, Optional, Type +# pyre-strict + +import copy +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch +from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( + IntNBitTableBatchedEmbeddingBagsCodegen, +) from torch import nn -from torch.nn.modules.module import _addindent + +from torch.distributed._shard.sharding_spec import EnumerableShardingSpec +from torchrec.distributed.embedding_lookup import EmbeddingComputeKernel from torchrec.distributed.embedding_sharding import ( EmbeddingSharding, EmbeddingShardingInfo, - ListOfSparseFeaturesListAwaitable, - NullShardingContext, ) from torchrec.distributed.embedding_types import ( BaseQuantEmbeddingSharder, - ListOfSparseFeaturesList, - SparseFeatures, - SparseFeaturesList, + FeatureShardingMixIn, + GroupedEmbeddingConfig, + InputDistOutputs, + KJTList, + ListOfKJTList, ) from torchrec.distributed.embeddingbag import ( - create_sharding_infos_by_sharding, - EmbeddingBagCollectionAwaitable, + construct_output_kt, + create_sharding_infos_by_sharding_device_group, +) +from torchrec.distributed.fused_params import ( + FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, + FUSED_PARAM_REGISTER_TBE_BOOL, + get_tbes_to_register_from_iterable, + is_fused_param_quant_state_dict_split_scale_bias, + is_fused_param_register_tbe, ) +from torchrec.distributed.global_settings import get_propogate_device +from torchrec.distributed.quant_state import ShardedQuantEmbeddingModuleState +from torchrec.distributed.sharding.cw_sharding import InferCwPooledEmbeddingSharding +from torchrec.distributed.sharding.rw_sharding import InferRwPooledEmbeddingSharding from torchrec.distributed.sharding.tw_sharding import InferTwEmbeddingSharding from torchrec.distributed.types import ( - Awaitable, - FeatureShardingMixIn, - LazyAwaitable, NullShardedModuleContext, + NullShardingContext, ParameterSharding, - ShardedModule, ShardingEnv, ShardingType, ) +from torchrec.distributed.utils import copy_to_device from torchrec.modules.embedding_configs import ( data_type_to_sparse_type, dtype_to_data_type, EmbeddingBagConfig, ) from torchrec.modules.embedding_modules import EmbeddingBagCollectionInterface +from torchrec.modules.feature_processor_ import FeatureProcessorsCollection +from torchrec.pt2.checks import is_torchdynamo_compiling from torchrec.quant.embedding_modules import ( EmbeddingBagCollection as QuantEmbeddingBagCollection, + FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection, + MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, ) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor +def get_device_from_parameter_sharding( + ps: ParameterSharding, +) -> Union[str, Tuple[str, ...]]: + """ + Returns list of device type per shard if table is sharded across + different device type, else reutrns single device type for the + table parameter. + """ + if not isinstance(ps.sharding_spec, EnumerableShardingSpec): + raise ValueError("Expected EnumerableShardingSpec as input to the function") + + device_type_list: Tuple[str, ...] 
= tuple( + # pyre-fixme[16]: `Optional` has no attribute `device` + [shard.placement.device().type for shard in ps.sharding_spec.shards] + ) + if len(set(device_type_list)) == 1: + return device_type_list[0] + else: + assert ( + ps.sharding_type == "row_wise" + ), "Only row_wise sharding supports sharding across multiple device types for a table" + return device_type_list + + +def get_device_from_sharding_infos( + emb_shard_infos: List[EmbeddingShardingInfo], +) -> Union[str, Tuple[str, ...]]: + res = list( + { + get_device_from_parameter_sharding(ps.param_sharding) + for ps in emb_shard_infos + } + ) + assert len(res) == 1, "All shards should be on the same type of device" + return res[0] + + +def get_device_for_first_shard_from_sharding_infos( + emb_shard_infos: List[EmbeddingShardingInfo], +) -> str: + device_type = get_device_from_sharding_infos(emb_shard_infos) + return device_type[0] if isinstance(device_type, tuple) else device_type + + +torch.fx.wrap("len") + + +@torch.fx.wrap +def flatten_feature_lengths(features: KeyedJaggedTensor) -> KeyedJaggedTensor: + return features.flatten_lengths() if features.lengths().dim() > 1 else features + + def create_infer_embedding_bag_sharding( sharding_type: str, sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, + device: Optional[torch.device] = None, ) -> EmbeddingSharding[ - NullShardingContext, SparseFeaturesList, List[torch.Tensor], torch.Tensor + NullShardingContext, InputDistOutputs, List[torch.Tensor], torch.Tensor ]: + propogate_device: bool = get_propogate_device() + device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = ( + get_device_from_sharding_infos(sharding_infos) + ) if sharding_type == ShardingType.TABLE_WISE.value: - return InferTwEmbeddingSharding(sharding_infos, env, device=None) + return InferTwEmbeddingSharding( + sharding_infos, env, device=device if propogate_device else None + ) + elif sharding_type == ShardingType.ROW_WISE.value: + return InferRwPooledEmbeddingSharding( + sharding_infos, + env, + device=device if propogate_device else None, + device_type_from_sharding_infos=device_type_from_sharding_infos, + ) + elif sharding_type == ShardingType.COLUMN_WISE.value: + return InferCwPooledEmbeddingSharding( + sharding_infos, + env, + device=device if propogate_device else None, + permute_embeddings=True, + ) else: raise ValueError(f"Sharding type not supported {sharding_type}") class ShardedQuantEmbeddingBagCollection( - ShardedModule[ - ListOfSparseFeaturesList, + ShardedQuantEmbeddingModuleState[ + ListOfKJTList, List[List[torch.Tensor]], KeyedTensor, NullShardedModuleContext, @@ -79,174 +173,193 @@ def __init__( self, module: EmbeddingBagCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], - env: ShardingEnv, + env: Union[ShardingEnv, Dict[str, ShardingEnv]], # support for Hybrid Sharding fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, ) -> None: super().__init__() - self._embedding_bag_configs: List[ - EmbeddingBagConfig - ] = module.embedding_bag_configs() - sharding_type_to_sharding_infos = create_sharding_infos_by_sharding( + self._embedding_bag_configs: List[EmbeddingBagConfig] = ( + module.embedding_bag_configs() + ) + self._sharding_type_device_group_to_sharding_infos: Dict[ + Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo] + ] = create_sharding_infos_by_sharding_device_group( module, table_name_to_parameter_sharding, "embedding_bags.", fused_params ) - self._sharding_type_to_sharding: Dict[ - str, 
+ self._sharding_type_device_group_to_sharding: Dict[ + Tuple[str, Union[str, Tuple[str, ...]]], EmbeddingSharding[ NullShardingContext, - SparseFeaturesList, + InputDistOutputs, List[torch.Tensor], torch.Tensor, ], ] = { - sharding_type: create_infer_embedding_bag_sharding( - sharding_type, embedding_confings, env + (sharding_type, device_group): create_infer_embedding_bag_sharding( + sharding_type, + embedding_configs, + ( + env + if not isinstance(env, Dict) + else env[ + get_device_for_first_shard_from_sharding_infos( + embedding_configs + ) + ] + ), + device if get_propogate_device() else None, ) - for sharding_type, embedding_confings in sharding_type_to_sharding_infos.items() + for ( + sharding_type, + device_group, + ), embedding_configs in self._sharding_type_device_group_to_sharding_infos.items() } - + self._device = device self._is_weighted: bool = module.is_weighted() - self._input_dists: List[nn.Module] = [] self._lookups: List[nn.Module] = [] - self._create_lookups(fused_params) - self._output_dists: List[nn.Module] = [] + self._create_lookups(fused_params, device) + + # Ensure output dist is set for post processing from an inference runtime (ie. setting device from runtime). + self._output_dists: torch.nn.ModuleList = torch.nn.ModuleList() + self._embedding_names: List[str] = [] self._embedding_dims: List[int] = [] - self._feature_splits: List[int] = [] - self._features_order: List[int] = [] # forward pass flow control - self._has_uninitialized_input_dist: bool = True self._has_uninitialized_output_dist: bool = True - self._has_features_permute: bool = True - # This provides consistency between this class and the EmbeddingBagCollection's - # nn.Module API calls (state_dict, named_modules, etc) - # Currently, Sharded Quant EBC only uses TW sharding, and returns non-sharded tensors as part of state dict - # TODO - revisit if we state_dict can be represented as sharded tensor - self.embedding_bags: nn.ModuleDict = nn.ModuleDict() - for table in self._embedding_bag_configs: - self.embedding_bags[table.name] = torch.nn.Module() - - for _sharding_type, lookup in zip( - self._sharding_type_to_sharding.keys(), self._lookups - ): - lookup_state_dict = lookup.state_dict() - for key in lookup_state_dict: - if not key.endswith(".weight"): - continue - table_name = key[: -len(".weight")] - # Register as buffer because this is an inference model, and can potentially use uint8 types. 
- self.embedding_bags[table_name].register_buffer( - "weight", lookup_state_dict[key] - ) + tbes: Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig] = ( + get_tbes_to_register_from_iterable(self._lookups) + ) - def _create_input_dist( - self, - input_feature_names: List[str], - device: torch.device, - ) -> None: - feature_names: List[str] = [] - for sharding in self._sharding_type_to_sharding.values(): - self._input_dists.append(sharding.create_input_dist()) - feature_names.extend( - sharding.id_score_list_feature_names() - if self._is_weighted - else sharding.id_list_feature_names() - ) - self._feature_splits.append( - len( - sharding.id_score_list_feature_names() - if self._is_weighted - else sharding.id_list_feature_names() - ) - ) + self._tbes_configs: Dict[ + IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig + ] = tbes - if feature_names == input_feature_names: - self._has_features_permute = False + # Optional registration of TBEs for model post processing utilities + if is_fused_param_register_tbe(fused_params): + self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(tbes.keys()) + + quant_state_dict_split_scale_bias = ( + is_fused_param_quant_state_dict_split_scale_bias(fused_params) + ) + + if quant_state_dict_split_scale_bias: + self._initialize_torch_state( + tbes=tbes, + table_name_to_parameter_sharding=table_name_to_parameter_sharding, + tables_weights_prefix="embedding_bags", + ) else: - for f in feature_names: - self._features_order.append(input_feature_names.index(f)) - self.register_buffer( - "_features_order_tensor", - torch.tensor(self._features_order, device=device, dtype=torch.int32), - persistent=False, + table_wise_sharded_only: bool = all( + sharding_type == ShardingType.TABLE_WISE.value + for ( + sharding_type, + _, + ) in self._sharding_type_device_group_to_sharding.keys() ) + assert ( + table_wise_sharded_only + ), "ROW_WISE,COLUMN_WISE shardings can be used only in 'quant_state_dict_split_scale_bias' mode, specify fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS]=True to __init__ argument" + + self.embedding_bags: nn.ModuleDict = nn.ModuleDict() + for table in self._embedding_bag_configs: + self.embedding_bags[table.name] = torch.nn.Module() + + for _, lookup in zip( + self._sharding_type_device_group_to_sharding.keys(), self._lookups + ): + lookup_state_dict = lookup.state_dict() + for key in lookup_state_dict: + if key.endswith(".weight"): + table_name = key[: -len(".weight")] + self.embedding_bags[table_name].register_buffer( + "weight", lookup_state_dict[key] + ) + + self._input_dist_module: ShardedQuantEbcInputDist = ShardedQuantEbcInputDist( + self._sharding_type_device_group_to_sharding, self._device + ) + + def tbes_configs( + self, + ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]: + return self._tbes_configs + + def sharding_type_device_group_to_sharding_infos( + self, + ) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]: + return self._sharding_type_device_group_to_sharding_infos + + def embedding_bag_configs(self) -> List[EmbeddingBagConfig]: + return self._embedding_bag_configs def _create_lookups( self, fused_params: Optional[Dict[str, Any]], + device: Optional[torch.device] = None, ) -> None: - for sharding in self._sharding_type_to_sharding.values(): - self._lookups.append(sharding.create_lookup(fused_params=fused_params)) + for sharding in self._sharding_type_device_group_to_sharding.values(): + self._lookups.append( + sharding.create_lookup( + 
device=device, + fused_params=fused_params, + ) + ) def _create_output_dist(self, device: Optional[torch.device] = None) -> None: - for sharding in self._sharding_type_to_sharding.values(): + for sharding in self._sharding_type_device_group_to_sharding.values(): self._output_dists.append(sharding.create_output_dist(device)) self._embedding_names.extend(sharding.embedding_names()) self._embedding_dims.extend(sharding.embedding_dims()) # pyre-ignore [14] + # pyre-ignore def input_dist( self, ctx: NullShardedModuleContext, features: KeyedJaggedTensor - ) -> Awaitable[ListOfSparseFeaturesList]: - if self._has_uninitialized_input_dist: - self._create_input_dist(features.keys(), features.device()) - self._has_uninitialized_input_dist = False + ) -> ListOfKJTList: + input_dist_outputs = self._input_dist_module(features) + if self._has_uninitialized_output_dist: self._create_output_dist(features.device()) self._has_uninitialized_output_dist = False - with torch.no_grad(): - if self._has_features_permute: - features = features.permute( - self._features_order, - # pyre-ignore [6] - self._features_order_tensor, - ) - features_by_shards = features.split( - self._feature_splits, - ) - awaitables = [ - module( - SparseFeatures( - id_list_features=None - if self._is_weighted - else features_by_shard, - id_score_list_features=features_by_shard - if self._is_weighted - else None, - ) - ).wait() # a dummy wait since now length indices comm is splited - for module, features_by_shard in zip( - self._input_dists, features_by_shards - ) - ] - return ListOfSparseFeaturesListAwaitable(awaitables) + + return input_dist_outputs def compute( self, ctx: NullShardedModuleContext, - dist_input: ListOfSparseFeaturesList, + dist_input: ListOfKJTList, ) -> List[List[torch.Tensor]]: - return [lookup(features) for lookup, features in zip(self._lookups, dist_input)] + # syntax for torchscript + return [lookup.forward(dist_input[i]) for i, lookup in enumerate(self._lookups)] + # pyre-ignore def output_dist( self, ctx: NullShardedModuleContext, output: List[List[torch.Tensor]], - ) -> LazyAwaitable[KeyedTensor]: - return EmbeddingBagCollectionAwaitable( - awaitables=[ - dist(embeddings) for dist, embeddings in zip(self._output_dists, output) + ) -> KeyedTensor: + return construct_output_kt( + embeddings=[ + dist.forward(output[i]) for i, dist in enumerate(self._output_dists) ], embedding_dims=self._embedding_dims, embedding_names=self._embedding_names, ) + # pyre-ignore def compute_and_output_dist( - self, ctx: NullShardedModuleContext, input: ListOfSparseFeaturesList - ) -> LazyAwaitable[KeyedTensor]: + self, ctx: NullShardedModuleContext, input: ListOfKJTList + ) -> KeyedTensor: return self.output_dist(ctx, self.compute(ctx, input)) + # pyre-ignore + def forward(self, *input, **kwargs) -> KeyedTensor: + ctx = self.create_context() + dist_input = self.input_dist(ctx, *input, **kwargs) + return self.compute_and_output_dist(ctx, dist_input) + def copy(self, device: torch.device) -> nn.Module: if self._has_uninitialized_output_dist: self._create_output_dist(device) @@ -254,11 +367,20 @@ def copy(self, device: torch.device) -> nn.Module: return super().copy(device) @property - def shardings(self) -> Dict[str, FeatureShardingMixIn]: + def shardings( + self, + ) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], FeatureShardingMixIn]: # pyre-ignore [7] - return self._sharding_type_to_sharding + return self._sharding_type_device_group_to_sharding def create_context(self) -> NullShardedModuleContext: + if is_torchdynamo_compiling(): + 
# Context creation is not supported by dynamo yet. + # Context is not needed for TW sharding => + # Unblocking dynamo TW with None. + # pyre-ignore + return None + return NullShardedModuleContext() @@ -269,15 +391,308 @@ def shard( self, module: QuantEmbeddingBagCollection, params: Dict[str, ParameterSharding], - env: ShardingEnv, + env: Union[ShardingEnv, Dict[str, ShardingEnv]], device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, ) -> ShardedQuantEmbeddingBagCollection: fused_params = self.fused_params if self.fused_params else {} fused_params["output_dtype"] = data_type_to_sparse_type( dtype_to_data_type(module.output_dtype()) ) - return ShardedQuantEmbeddingBagCollection(module, params, env, fused_params) + if FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS not in fused_params: + fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] = getattr( + module, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False + ) + if FUSED_PARAM_REGISTER_TBE_BOOL not in fused_params: + fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr( + module, FUSED_PARAM_REGISTER_TBE_BOOL, False + ) + + return ShardedQuantEmbeddingBagCollection( + module, params, env, fused_params, device=device + ) @property def module_type(self) -> Type[QuantEmbeddingBagCollection]: return QuantEmbeddingBagCollection + + +class ShardedQuantFeatureProcessedEmbeddingBagCollection( + ShardedQuantEmbeddingBagCollection, +): + def __init__( + self, + module: EmbeddingBagCollectionInterface, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + env: ShardingEnv, + fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + feature_processor: Optional[FeatureProcessorsCollection] = None, + ) -> None: + super().__init__( + module, + table_name_to_parameter_sharding, + env, + fused_params, + device, + ) + assert feature_processor is not None + device_type: str = self._device.type if self._device is not None else "cuda" + self.feature_processors_per_rank: nn.ModuleList = torch.nn.ModuleList() + feature_processor_device = None + for _, param in feature_processor.named_parameters(): + if feature_processor_device is None: + feature_processor_device = param.device + elif feature_processor_device != param.device: + raise RuntimeError( + f"Feature processor has inconsistent devices. Expected {feature_processor_device}, got {param.device}" + ) + + for _, buffer in feature_processor.named_buffers(): + if feature_processor_device is None: + feature_processor_device = buffer.device + elif feature_processor_device != buffer.device: + raise RuntimeError( + f"Feature processor has inconsistent devices. 
Expected {feature_processor_device}, got {buffer.device}"
+                )
+
+        if feature_processor_device is None:
+            for _ in range(env.world_size):
+                self.feature_processors_per_rank.append(feature_processor)
+        else:
+            for i in range(env.world_size):
+                # Generic copy, e.g. a feature processor initialized on cpu but sharded to meta
+                self.feature_processors_per_rank.append(
+                    copy.deepcopy(feature_processor)
+                    if device_type == "meta"
+                    else copy_to_device(
+                        feature_processor,
+                        feature_processor_device,
+                        (
+                            torch.device(f"{device_type}:{i}")
+                            if device_type == "cuda"
+                            else torch.device(f"{device_type}")
+                        ),
+                    )
+                )
+
+    def apply_feature_processor(
+        self,
+        kjt_list: KJTList,
+    ) -> KJTList:
+        l: List[KeyedJaggedTensor] = []
+        for i in range(len(self.feature_processors_per_rank)):
+            l.append(self.feature_processors_per_rank[i](kjt_list[i]))
+        return KJTList(l)
+
+    def compute(
+        self,
+        ctx: NullShardedModuleContext,
+        dist_input: ListOfKJTList,  # List_per_sharding[List_per_rank[KJT]]
+    ) -> List[List[torch.Tensor]]:
+        return [
+            lookup.forward(self.apply_feature_processor(dist_input[i]))
+            for i, lookup in enumerate(self._lookups)
+        ]
+
+
+class QuantFeatureProcessedEmbeddingBagCollectionSharder(
+    BaseQuantEmbeddingSharder[QuantFeatureProcessedEmbeddingBagCollection]
+):
+    def sharding_types(self, compute_device_type: str) -> List[str]:
+        return [
+            ShardingType.TABLE_WISE.value,
+            ShardingType.COLUMN_WISE.value,
+        ]
+
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
+        return [EmbeddingComputeKernel.QUANT.value]
+
+    def shard(
+        self,
+        module: QuantFeatureProcessedEmbeddingBagCollection,
+        params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+        module_fqn: Optional[str] = None,
+    ) -> ShardedQuantEmbeddingBagCollection:
+        qebc = module
+        assert isinstance(qebc, QuantEmbeddingBagCollection)
+        fused_params = self.fused_params if self.fused_params else {}
+        fused_params["output_dtype"] = data_type_to_sparse_type(
+            dtype_to_data_type(qebc.output_dtype())
+        )
+        if FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS not in fused_params:
+            fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] = getattr(
+                qebc, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False
+            )
+        if FUSED_PARAM_REGISTER_TBE_BOOL not in fused_params:
+            fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr(
+                qebc, FUSED_PARAM_REGISTER_TBE_BOOL, False
+            )
+
+        return ShardedQuantFeatureProcessedEmbeddingBagCollection(
+            qebc,
+            params,
+            env,
+            fused_params,
+            device=device,
+            feature_processor=module.feature_processor,
+        )
+
+    @property
+    def module_type(self) -> Type[QuantFeatureProcessedEmbeddingBagCollection]:
+        return QuantFeatureProcessedEmbeddingBagCollection
+
+
+class ShardedQuantEbcInputDist(torch.nn.Module):
+    """
+    This module implements distributed inputs of a ShardedQuantEmbeddingBagCollection.
+
+    Args:
+        sharding_type_device_group_to_sharding (Dict[
+            Tuple[str, str],
+            EmbeddingSharding[
+                NullShardingContext,
+                KJTList,
+                List[torch.Tensor],
+                torch.Tensor,
+            ],
+        ]): map from (sharding type, device group) to EmbeddingSharding.
+        device (Optional[torch.device]): default compute device. 
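+
+    Note:
+        Input dists are created lazily on the first forward() call. Features are
+        permuted into the shardings' feature order (when it differs from the input
+        order) and split per sharding before being distributed.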
+ + Example:: + + sqebc_input_dist = ShardedQuantEbcInputDist( + sharding_type_device_group_to_sharding={ + (ShardingType.TABLE_WISE, "cpu"): InferTwSequenceEmbeddingSharding( + [], + ShardingEnv( + world_size=2, + rank=0, + pg=0, + ), + torch.device("cpu") + ) + }, + device=torch.device("cpu"), + ) + + features = KeyedJaggedTensor( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + sqebc_input_dist(features) + """ + + def __init__( + self, + sharding_type_device_group_to_sharding: Dict[ + Tuple[str, Union[str, Tuple[str, ...]]], + EmbeddingSharding[ + NullShardingContext, + InputDistOutputs, + List[torch.Tensor], + torch.Tensor, + ], + ], + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + self._sharding_type_device_group_to_sharding = ( + sharding_type_device_group_to_sharding + ) + self._device = device + + self._shardings: List[ + EmbeddingSharding[ + NullShardingContext, + InputDistOutputs, + List[torch.Tensor], + torch.Tensor, + ] + ] = list(sharding_type_device_group_to_sharding.values()) + + self._input_dists: List[nn.Module] = [] + + self._features_order: List[int] = [] + self._feature_names: List[List[str]] = [ + sharding.feature_names() for sharding in self._shardings + ] + self._feature_splits: List[int] = [ + len(sharding) for sharding in self._feature_names + ] + + # forward pass flow control + self._has_uninitialized_input_dist: bool = True + self._has_features_permute: bool = True + + def _create_input_dist( + self, + input_feature_names: List[str], + features_device: torch.device, + input_dist_device: Optional[torch.device] = None, + ) -> None: + flat_feature_names: List[str] = [ + feature_name + for sharding_feature_name in self._feature_names + for feature_name in sharding_feature_name + ] + self._input_dists = [ + sharding.create_input_dist(device=input_dist_device) + for sharding in self._shardings + ] + + if flat_feature_names == input_feature_names: + self._has_features_permute = False + else: + for f in flat_feature_names: + self._features_order.append(input_feature_names.index(f)) + self.register_buffer( + "_features_order_tensor", + torch.tensor( + self._features_order, device=features_device, dtype=torch.int32 + ), + persistent=False, + ) + + def forward(self, features: KeyedJaggedTensor) -> ListOfKJTList: + """ + Args: + features (KeyedJaggedTensor): KJT of form [F X B X L]. + + Returns: + ListOfKJTList + """ + if self._has_uninitialized_input_dist: + self._create_input_dist( + features.keys(), + features.device(), + self._device, + ) + self._has_uninitialized_input_dist = False + with torch.no_grad(): + if self._has_features_permute: + features = features.permute( + self._features_order, + # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]` + # but got `Union[Module, Tensor]`. + self._features_order_tensor, + ) + else: + features = flatten_feature_lengths(features) + features_by_shards = ( + [features] + if len(self._feature_splits) == 1 + else features.split(self._feature_splits) + ) + return ListOfKJTList( + [ + self._input_dists[i].forward(features_by_shards[i]).features + for i in range(len(self._input_dists)) + ] + ) diff --git a/torchrec/distributed/quant_state.py b/torchrec/distributed/quant_state.py new file mode 100644 index 000000000..1de388e1b --- /dev/null +++ b/torchrec/distributed/quant_state.py @@ -0,0 +1,472 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import copy +from dataclasses import dataclass +from functools import partial +from typing import Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union + +import torch +from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( + IntNBitTableBatchedEmbeddingBagsCodegen, +) +from torch.distributed import _remote_device +from torch.distributed._shard.sharded_tensor import ( + Shard, + ShardedTensorBase, + ShardedTensorMetadata, + ShardMetadata, +) +from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo +from torchrec.distributed.embedding_types import ( + GroupedEmbeddingConfig, + ShardedEmbeddingModule, +) +from torchrec.distributed.types import ParameterSharding, ShardingType +from torchrec.modules.embedding_configs import DataType +from torchrec.streamable import Multistreamable +from torchrec.tensor_types import UInt2Tensor, UInt4Tensor + +Out = TypeVar("Out") +CompIn = TypeVar("CompIn") +DistOut = TypeVar("DistOut") +ShrdCtx = TypeVar("ShrdCtx", bound=Multistreamable) + + +def _append_table_shard( + d: Dict[str, List[Shard]], table_name: str, shard: Shard +) -> None: + if table_name not in d: + d[table_name] = [] + d[table_name].append(shard) + + +def post_state_dict_hook( + # Union["ShardedQuantEmbeddingBagCollection", "ShardedQuantEmbeddingCollection"] + # pyre-ignore [24] + module: ShardedEmbeddingModule, + destination: Dict[str, torch.Tensor], + prefix: str, + _local_metadata: Dict[str, Any], + tables_weights_prefix: str, # "embedding_bags" or "embeddings" +) -> None: + for ( + table_name, + sharded_t, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `items`. + ) in module._table_name_to_sharded_tensor.items(): + destination[f"{prefix}{tables_weights_prefix}.{table_name}.weight"] = sharded_t + + for sfx, dict_sharded_t, dict_t_list in [ + ( + "weight_qscale", + module._table_name_to_sharded_tensor_qscale, + module._table_name_to_tensors_list_qscale, + ), + ( + "weight_qbias", + module._table_name_to_sharded_tensor_qbias, + module._table_name_to_tensors_list_qbias, + ), + ]: + for ( + table_name, + sharded_t, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `items`. + ) in dict_sharded_t.items(): + destination[f"{prefix}{tables_weights_prefix}.{table_name}.{sfx}"] = ( + sharded_t + ) + for ( + table_name, + t_list, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `items`. + ) in dict_t_list.items(): + destination[f"{prefix}{tables_weights_prefix}.{table_name}.{sfx}"] = t_list + + +class ShardedQuantEmbeddingModuleState( + ShardedEmbeddingModule[CompIn, DistOut, Out, ShrdCtx] +): + def _initialize_torch_state( # noqa: C901 + # Union[ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection] + self, + tbes: Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig], + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + tables_weights_prefix: str, # "embedding_bags" or "embeddings" + ) -> None: # noqa + # State is prepared only in "quant_state_dict_split_scale_bias" mode + assert ( + tables_weights_prefix == "embedding_bags" + or tables_weights_prefix == "embeddings" + ) + + # weight + # pyre-fixme[16]: `ShardedQuantEmbeddingModuleState` has no attribute + # `_table_name_to_local_shards`. 
+ self._table_name_to_local_shards: Dict[str, List[Shard]] = {} + # pyre-fixme[16]: `ShardedQuantEmbeddingModuleState` has no attribute + # `_table_name_to_sharded_tensor`. + self._table_name_to_sharded_tensor: Dict[ + str, Union[torch.Tensor, ShardedTensorBase] + ] = {} + + # weight_qscale + # pyre-fixme[16]: `ShardedQuantEmbeddingModuleState` has no attribute + # `_table_name_to_local_shards_qscale`. + self._table_name_to_local_shards_qscale: Dict[str, List[Shard]] = {} + # pyre-fixme[16]: `ShardedQuantEmbeddingModuleState` has no attribute + # `_table_name_to_sharded_tensor_qscale`. + self._table_name_to_sharded_tensor_qscale: Dict[ + str, Union[torch.Tensor, ShardedTensorBase] + ] = {} + # pyre-fixme[16]: `ShardedQuantEmbeddingModuleState` has no attribute + # `_table_name_to_tensors_list_qscale`. + self._table_name_to_tensors_list_qscale: Dict[str, List[torch.Tensor]] = {} + + # weight_qbias + # pyre-fixme[16]: `ShardedQuantEmbeddingModuleState` has no attribute + # `_table_name_to_local_shards_qbias`. + self._table_name_to_local_shards_qbias: Dict[str, List[Shard]] = {} + # pyre-fixme[16]: `ShardedQuantEmbeddingModuleState` has no attribute + # `_table_name_to_sharded_tensor_qbias`. + self._table_name_to_sharded_tensor_qbias: Dict[ + str, Union[torch.Tensor, ShardedTensorBase] + ] = {} + # pyre-fixme[16]: `ShardedQuantEmbeddingModuleState` has no attribute + # `_table_name_to_tensors_list_qbias`. + self._table_name_to_tensors_list_qbias: Dict[str, List[torch.Tensor]] = {} + + for tbe, config in tbes.items(): + for (tbe_split_w, tbe_split_qscale, tbe_split_qbias), table in zip( + tbe.split_embedding_weights_with_scale_bias(split_scale_bias_mode=2), + config.embedding_tables, + ): + if table.data_type == DataType.INT4: + tbe_split_w = UInt4Tensor(tbe_split_w) + elif table.data_type == DataType.INT2: + tbe_split_w = UInt2Tensor(tbe_split_w) + + # weight shards section: + assert table.local_metadata + metadata: ShardMetadata = copy.deepcopy(table.local_metadata) + metadata.shard_sizes = [tbe_split_w.size(0), tbe_split_w.size(1)] + + # TODO(ivankobzarev): "meta" sharding support: cleanup when copy to "meta" moves all tensors to "meta" + # pyre-ignore + if metadata.placement.device != tbe_split_w.device: + metadata.placement = _remote_device(tbe_split_w.device) + _append_table_shard( + # pyre-fixme[6]: For 1st argument expected `Dict[str, + # List[Shard]]` but got `Union[Tensor, Module]`. 
+ self._table_name_to_local_shards, + table.name, + Shard(tensor=tbe_split_w, metadata=metadata), + ) + # end of weight shards section + + # weight_qscale & weight_qbias section: + # For RW - ShardedTensorBase + # For CW - List[Tensor] that logically corresponds to the same unsharded Tensor, but present on each sharded rank + for ( + tbe_split_qparam, + table_name_to_local_shards, + table_name_to_tensors_list, + ) in [ + ( + tbe_split_qscale, + self._table_name_to_local_shards_qscale, + self._table_name_to_tensors_list_qscale, + ), + ( + tbe_split_qbias, + self._table_name_to_local_shards_qbias, + self._table_name_to_tensors_list_qbias, + ), + ]: + assert table.local_metadata + metadata: ShardMetadata = copy.deepcopy(table.local_metadata) + shard_sizes = metadata.shard_sizes + shard_offsets = metadata.shard_offsets + + shard_sizes_cols = shard_sizes[1] + shard_offsets_cols = shard_offsets[1] + + parameter_sharding: ParameterSharding = ( + table_name_to_parameter_sharding[table.name] + ) + sharding_type: str = parameter_sharding.sharding_type + + if sharding_type == ShardingType.COLUMN_WISE.value: + # pyre-fixme[58]: `not in` is not supported for right + # operand type `Union[Tensor, Module]`. + if table.name not in table_name_to_tensors_list: + assert parameter_sharding.ranks + num_shards: int = len(parameter_sharding.ranks) + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Unio... + table_name_to_tensors_list[table.name] = [ + torch.empty([]) + ] * num_shards + + column_idx = int(shard_offsets_cols / shard_sizes_cols) + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[No... + table_name_to_tensors_list[table.name][ + column_idx + ] = tbe_split_qparam + else: + qmetadata = ShardMetadata( + shard_offsets=metadata.shard_offsets, + shard_sizes=[ + tbe_split_qparam.shape[0], + tbe_split_qparam.shape[1], + ], + # pyre-ignore + placement=table.local_metadata.placement, + ) + # TODO(ivankobzarev): "meta" sharding support: cleanup when copy to "meta" moves all tensors to "meta" + if qmetadata.placement.device != tbe_split_qparam.device: + qmetadata.placement = _remote_device( + tbe_split_qparam.device + ) + _append_table_shard( + # pyre-fixme[6]: For 1st argument expected `Dict[str, + # List[Shard]]` but got `Union[Tensor, Module]`. + table_name_to_local_shards, + table.name, + Shard(tensor=tbe_split_qparam, metadata=qmetadata), + ) + # end of weight_qscale & weight_qbias section + + for table_name_to_local_shards, table_name_to_sharded_tensor in [ + (self._table_name_to_local_shards, self._table_name_to_sharded_tensor), + ( + self._table_name_to_local_shards_qscale, + self._table_name_to_sharded_tensor_qscale, + ), + ( + self._table_name_to_local_shards_qbias, + self._table_name_to_sharded_tensor_qbias, + ), + ]: + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `items`. + for table_name, local_shards in table_name_to_local_shards.items(): + if len(local_shards) == 1: + # Single Tensor per table (TW sharding) + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ... 
+ table_name_to_sharded_tensor[table_name] = local_shards[0].tensor + continue + + # ShardedTensor per table + global_rows = max( + [ + ls.metadata.shard_offsets[0] + ls.metadata.shard_sizes[0] + for ls in local_shards + ] + ) + global_cols = max( + [ + ls.metadata.shard_offsets[1] + ls.metadata.shard_sizes[1] + for ls in local_shards + ] + ) + global_metadata: ShardedTensorMetadata = ShardedTensorMetadata( + shards_metadata=[ls.metadata for ls in local_shards], + size=torch.Size([global_rows, global_cols]), + ) + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _Nes... + table_name_to_sharded_tensor[table_name] = ( + ShardedTensorBase._init_from_local_shards_and_global_metadata( + local_shards=local_shards, + sharded_tensor_metadata=global_metadata, + ) + ) + + self._register_state_dict_hook( + partial(post_state_dict_hook, tables_weights_prefix=tables_weights_prefix) + ) + + def _load_from_state_dict( + # Union["ShardedQuantEmbeddingBagCollection", "ShardedQuantEmbeddingCollection"] + self, + state_dict: Mapping[str, Any], + prefix: str, + # pyre-ignore + local_metadata, + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + dst_state_dict = self.state_dict() + _missing_keys: List[str] = [] + _unexpected_keys: List[str] = list(state_dict.keys()) + for name, dst_tensor in dst_state_dict.items(): + src_state_dict_name = prefix + name + if src_state_dict_name not in state_dict: + _missing_keys.append(src_state_dict_name) + continue + + src_tensor = state_dict[src_state_dict_name] + if isinstance(dst_tensor, ShardedTensorBase) and isinstance( + src_tensor, ShardedTensorBase + ): + # sharded to sharded model, only identically sharded + for dst_local_shard in dst_tensor.local_shards(): + copied: bool = False + for src_local_shard in src_tensor.local_shards(): + if ( + dst_local_shard.metadata.shard_offsets + == src_local_shard.metadata.shard_offsets + and dst_local_shard.metadata.shard_sizes + == src_local_shard.metadata.shard_sizes + ): + dst_local_shard.tensor.copy_(src_local_shard.tensor) + copied = True + break + assert copied, "Incompatible state_dict" + elif isinstance(dst_tensor, ShardedTensorBase) and isinstance( + src_tensor, torch.Tensor + ): + # non_sharded to sharded model + for dst_local_shard in dst_tensor.local_shards(): + dst_tensor = dst_local_shard.tensor + assert src_tensor.ndim == dst_tensor.ndim + meta = dst_local_shard.metadata + t = src_tensor.detach() + rows_from = meta.shard_offsets[0] + rows_to = rows_from + meta.shard_sizes[0] + if t.ndim == 1: + dst_tensor.copy_(t[rows_from:rows_to]) + elif t.ndim == 2: + cols_from = meta.shard_offsets[1] + cols_to = cols_from + meta.shard_sizes[1] + dst_tensor.copy_( + t[ + rows_from:rows_to, + cols_from:cols_to, + ] + ) + else: + raise RuntimeError("Tensors with ndim > 2 are not supported") + elif isinstance(dst_tensor, list) and isinstance(src_tensor, torch.Tensor): + # non_sharded to CW columns qscale, qbias (one to many) + for t in dst_tensor: + assert isinstance(t, torch.Tensor) + t.copy_(src_tensor) + else: + dst_tensor.copy_(src_tensor) + + _unexpected_keys.remove(src_state_dict_name) + missing_keys.extend(_missing_keys) + unexpected_keys.extend(_unexpected_keys) + + +@dataclass +class WeightSpec: + fqn: str # "ebc.embedding_bags.table_0.weight" + shard_offsets: List[int] # shard offsets + shard_sizes: List[int] # shard sizes + sharding_type: Optional[str] # e.g. 
ShardingType.ROW_WISE.value=="row_wise" + + +def sharded_tbes_weights_spec( + sharded_model: torch.nn.Module, +) -> Dict[str, WeightSpec]: + # OUTPUT: + # Example: + # { + # tbes.0 + # table_0 in tbes.0 + # "ebc.tbes.0.0.table_0.weight": WeightSpec("ebc.embedding_bags.table_0.weight", [0, 0], [500, 192]) + # "ebc.tbes.0.0.table_0.weight_qscale":WeightSpec("ebc.embedding_bags.table_0.weight_qscale", [0, 0], [500, 2]) + # "ebc.tbes.0.0.table_0.weight_qbias":WeightSpec("ebc.embedding_bags.table_0.weight_qbias", [0, 0], [500, 2]) + # table_1 in tbes.0 + # "ebc.tbes.0.1.table_1.weight": WeightSpec("ebc.embedding_bags.table_1.weight", [0, 0], [500, 192]) + # "ebc.tbes.0.1.table_1.weight_qscale":WeightSpec("ebc.embedding_bags.table_1.weight_qscale", [0, 0], [500, 2]) + # "ebc.tbes.0.1.table_1.weight_qbias":WeightSpec("ebc.embedding_bags.table_1.weight_qbias", [0, 0], [500, 2]) + # tbes.1 + # table_0 in tbes.1 + # "ebc.tbes.1.0.table_0.weight": WeightSpec("ebc.embedding_bags.table_0.weight", [500, 0], [500, 192]) + # "ebc.tbes.1.0.table_0.weight_qscale":WeightSpec("ebc.embedding_bags.table_0.weight_qscale", [500, 0], [500, 2]) + # "ebc.tbes.1.0.table_0.weight_qbias":WeightSpec("ebc.embedding_bags.table_0.weight_qbias", [500, 0], [500, 2]) + # table_1 in tbes.1 + # "ebc.tbes.1.1.table_1.weight": WeightSpec("ebc.embedding_bags.table_1.weight", [500, 0], [500, 192]) + # "ebc.tbes.1.1.table_1.weight_qscale":WeightSpec("ebc.embedding_bags.table_1.weight_qscale", [500, 0], [500, 2]) + # "ebc.tbes.1.1.table_1.weight_qbias":WeightSpec("ebc.embedding_bags.table_1.weight_qbias", [500, 0], [500, 2]) + # } + # In the format of ebc.tbes.i.j.table_k.weight, where i is the index of the TBE, j is the index of the embedding bag within TBE i, k is the index of the original table set in the ebc embedding_configs + # e.g. ebc.tbes.1.1.table_1.weight, it represents second embedding bag within the second TBE. 
This part of weight is from a shard of table_1 + + ret: Dict[str, WeightSpec] = {} + for module_fqn, module in sharded_model.named_modules(): + type_name: str = type(module).__name__ + is_sqebc: bool = "ShardedQuantEmbeddingBagCollection" in type_name + is_sqec: bool = "ShardedQuantEmbeddingCollection" in type_name + is_sqmcec: bool = "ShardedQuantManagedCollisionEmbeddingCollection" in type_name + + if is_sqebc or is_sqec or is_sqmcec: + assert ( + is_sqec + is_sqebc + is_sqmcec == 1 + ), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection and ShardedQuantManagedCollisionEmbeddingCollection are true" + tbes_configs: Dict[ + IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig + ] = module.tbes_configs() + table_shardings: Dict[str, str] = {} + + sharding_type_device_group_to_sharding_infos: Dict[ + Tuple[str, str], List[EmbeddingShardingInfo] + ] = module.sharding_type_device_group_to_sharding_infos() + + for ( + (sharding_type, _), + sharding_infos, + ) in sharding_type_device_group_to_sharding_infos.items(): + for info in sharding_infos: + table_shardings[info.embedding_config.name] = sharding_type + + for tbe_idx, (_tbe, config) in enumerate(tbes_configs.items()): + tables = config.embedding_tables + for table_idx, table in enumerate(tables): + table_name: str = table.name + # pyre-ignore + table_metadata: ShardMetadata = table.local_metadata + # TODO(ivankobzarev) Switch to use table_metadata.shard_sizes when it works correctly with int4 quantized modules + shard_sizes: List[int] = [table.local_rows, table.local_cols] + shard_offsets: List[int] = table_metadata.shard_offsets + s: str = "embedding_bags" if is_sqebc else "embeddings" + s = ("_embedding_module." if is_sqmcec else "") + s + unsharded_fqn_weight: str = f"{module_fqn}.{s}.{table_name}.weight" + + sharded_fqn_weight: str = ( + f"{module_fqn}.tbes.{tbe_idx}.{table_idx}.{table_name}.weight" + ) + sharding_type: str = table_shardings[table_name] + ret[sharded_fqn_weight] = WeightSpec( + fqn=unsharded_fqn_weight, + shard_offsets=shard_offsets, + shard_sizes=shard_sizes, + sharding_type=sharding_type, + ) + + for qcomponent in ["qscale", "qbias"]: + qcomp_shard_offsets: List[int] = copy.deepcopy(shard_offsets) + # handling CW - no columns shift for qscale/qbias + qcomp_shard_offsets[1] = 0 + qcomp_shard_sizes: List[int] = copy.deepcopy(shard_sizes) + # Assuming qscale and qbias are always torch.half (float16), represented as tensor of byte type => sizeof(float16) == 2 (bytes) + qcomp_shard_sizes[1] = 2 + + ret[f"{sharded_fqn_weight}_{qcomponent}"] = WeightSpec( + fqn=f"{unsharded_fqn_weight}_{qcomponent}", + shard_offsets=qcomp_shard_offsets, + shard_sizes=qcomp_shard_sizes, + sharding_type=sharding_type, + ) + return ret diff --git a/torchrec/distributed/shard.py b/torchrec/distributed/shard.py index 1d2efbbd3..0a27711a7 100644 --- a/torchrec/distributed/shard.py +++ b/torchrec/distributed/shard.py @@ -5,36 +5,162 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
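# --- Illustrative sketch, not part of this patch: consuming a WeightSpec-style mapping
# such as the one produced by sharded_tbes_weights_spec() above. The WeightSpec fields and
# the example FQNs/offsets mirror the comment block above; the unsharded state_dict and the
# slice_unsharded_weight helper are hypothetical stand-ins added only for demonstration.
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch


@dataclass
class WeightSpec:
    fqn: str                      # unsharded weight FQN, e.g. "ebc.embedding_bags.table_0.weight"
    shard_offsets: List[int]      # [row_offset, col_offset]
    shard_sizes: List[int]        # [rows, cols]
    sharding_type: Optional[str]  # e.g. ShardingType.ROW_WISE.value == "row_wise"


def slice_unsharded_weight(
    unsharded_state_dict: Dict[str, torch.Tensor], spec: WeightSpec
) -> torch.Tensor:
    # Return the shard of the unsharded weight described by `spec`.
    full = unsharded_state_dict[spec.fqn]
    r0, c0 = spec.shard_offsets
    rows, cols = spec.shard_sizes
    return full[r0 : r0 + rows, c0 : c0 + cols]


# Two row-wise shards of a 1000 x 192 table, matching the example mapping above.
state = {"ebc.embedding_bags.table_0.weight": torch.randn(1000, 192)}
spec_tbe0 = WeightSpec("ebc.embedding_bags.table_0.weight", [0, 0], [500, 192], "row_wise")
spec_tbe1 = WeightSpec("ebc.embedding_bags.table_0.weight", [500, 0], [500, 192], "row_wise")
assert slice_unsharded_weight(state, spec_tbe0).shape == (500, 192)
assert slice_unsharded_weight(state, spec_tbe1).shape == (500, 192)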
-from typing import Dict, List, Optional, Tuple, Type +# pyre-strict + +from typing import Callable, Dict, List, Optional, Type, Union import torch import torch.distributed as dist from torch import nn +from torch.distributed._composable.contract import contract from torchrec.distributed.comm import get_local_size +from torchrec.distributed.global_settings import get_propogate_device from torchrec.distributed.model_parallel import get_default_sharders -from torchrec.distributed.planner import EmbeddingShardingPlanner -from torchrec.distributed.planner.types import Topology -from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology +from torchrec.distributed.sharding_plan import ( + get_module_to_default_sharders, + ParameterShardingGenerator, +) +from torchrec.distributed.types import ( + ModuleSharder, + ModuleShardingPlan, + ShardingEnv, + ShardingPlan, +) +from torchrec.distributed.utils import init_parameters +from torchrec.modules.utils import reset_module_states_post_sharding def _join_module_path(path: str, name: str) -> str: return (path + "." + name) if path else name +# pyre-ignore +@contract() def shard( + module: nn.Module, + plan: Union[ + ModuleShardingPlan, + Dict[str, ParameterShardingGenerator], + ParameterShardingGenerator, + ], + env: Optional[ShardingEnv] = None, + device: Optional[torch.device] = None, + sharder: Optional[ModuleSharder[nn.Module]] = None, +) -> nn.Module: + """ + Replaces this module with its sharded variant + + This will leave the other parts of the model unaffected. + + It returns the original module + + Args: + module (nn.Module): module to wrap. + env (Optional[ShardingEnv]): sharding environment that has the process group. + device (Optional[torch.device]): compute device, defaults to cpu. + plan (Union[ModuleShardingPlan, Dict[str, ParameterShardingGenerator], ParameterShardingGenerator]): + dict of ParameterSharding (materized plan) or ParameterShardingGenerator (which will be run to produce ParameterSharding). + If single ParameterShardingGenerator is supplied, it will be applied to all module parameters. + sharder (Optional[List[ModuleSharder[nn.Module]]]): sharder to use, default is picked from `get_default_sharders()`. 
+ + Example: + + ebc = EmbeddingBagCollection() + sharded_ebc = shard(ebc, table_row_wise(host_index=0)) + assert isinstance(sharded_ebc, ShardedEmbeddingBagCollection) + """ + torch._C._log_api_usage_once("torchrec.distributed.shard") + return _shard(module, plan, env, device, sharder) + + +def _shard( + module: nn.Module, + plan: Union[ + ModuleShardingPlan, + Dict[str, ParameterShardingGenerator], + ParameterShardingGenerator, + ], + env: Optional[ShardingEnv] = None, + device: Optional[torch.device] = None, + sharder: Optional[ModuleSharder[nn.Module]] = None, +) -> nn.Module: + """ + See shard + """ + if sharder is None: + sharder = get_module_to_default_sharders().get(type(module), None) + assert ( + sharder is not None + ), f"Could not find a valid sharder type for {type(module)}" + + if env is None: + pg = dist.GroupMember.WORLD + assert pg is not None, "Process group is not initialized" + env = ShardingEnv.from_process_group(pg) + + if device is None: + if get_propogate_device(): + device = torch.device( + "cpu" + ) # TODO: replace hardcoded cpu with DEFAULT_DEVICE_TYPE in torchrec.distributed.types when torch package issue resolved + else: + if torch.cuda.is_available(): + device = torch.device(torch.cuda.current_device()) + else: + device = torch.device("cpu") + + if isinstance(plan, ModuleShardingPlan): + return sharder.shard(module, plan, env, device) + + # Run sharding generators. + shardable_parameters = sharder.shardable_parameters(module) + if isinstance(plan, Callable): + gen = plan + plan = {} + for table_name, param in shardable_parameters.items(): + plan[table_name] = gen( + param, + get_local_size(env.world_size), + env.world_size, + device.type, + sharder, + ) + else: + for table_name, sharding in plan.items(): + if isinstance(sharding, Callable): + param = shardable_parameters[table_name] + # pyre-fixme[6]: For 2nd argument expected `(Parameter, int, int, + # str, ModuleSharder[Module]) -> ParameterSharding` but got + # `ParameterSharding`. + plan[table_name] = sharding( + param, + get_local_size(env.world_size), + env.world_size, + device.type, + sharder, + ) + + return sharder.shard(module, plan, env, device) + + +# pyre-ignore +@contract() +def shard_modules( module: nn.Module, env: Optional[ShardingEnv] = None, device: Optional[torch.device] = None, plan: Optional[ShardingPlan] = None, sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None, -) -> Tuple[nn.Module, List[str]]: + init_params: bool = False, +) -> nn.Module: """ Replaces all sub_modules that are embedding modules with their sharded variants. This embedding_module -> sharded_embedding_module mapping is derived from the passed in sharders. This will leave the other parts of the model unaffected. - It returns the module (with replacements), as well as parameter names of the modules that were swapped out. + It returns the original module Args: module (nn.Module): module to wrap. @@ -43,7 +169,11 @@ def shard( plan (Optional[ShardingPlan]): plan to use when sharding, defaults to `EmbeddingShardingPlanner.collective_plan()`. sharders (Optional[List[ModuleSharder[nn.Module]]]): `ModuleSharders` available - to shard with, defaults to `EmbeddingBagCollectionSharder()`. + to shard with, defaults to `get_default_sharders()`. + init_params: (Optional[bool]): If ``True``, will materialize parameters and + buffers that are on meta device, and will move module to ``device``. Note that + this only applies if `device.type != "meta"``. Default: `False`. 
+ Example:: @@ -60,6 +190,28 @@ def init_weights(m): assert isinstance(m.embedding_bag_collection, ShardedEmbeddingBagCollection) """ + torch._C._log_api_usage_once("torchrec.distributed.shard_modules") + return _shard_modules(module, env, device, plan, sharders, init_params) + + +def _shard_modules( # noqa: C901 + module: nn.Module, + # TODO: Consolidate to using Dict[str, ShardingEnv] + env: Optional[ + Union[ShardingEnv, Dict[str, ShardingEnv]] + ] = None, # Support hybrid sharding + device: Optional[torch.device] = None, + plan: Optional[ShardingPlan] = None, + sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None, + init_params: Optional[bool] = False, + planner: Optional[EmbeddingShardingPlanner] = None, +) -> nn.Module: + """ + See shard_modules + """ + + torch._C._log_api_usage_once("torchrec.distributed.shard_modules") + if sharders is None: sharders = get_default_sharders() @@ -69,38 +221,44 @@ def init_weights(m): env = ShardingEnv.from_process_group(pg) if device is None: - device = torch.device("cpu") + if get_propogate_device(): + device = torch.device( + "cpu" + ) # TODO: replace hardcoded cpu with DEFAULT_DEVICE_TYPE in torchrec.distributed.types when torch package issue resolved + else: + if torch.cuda.is_available(): + device = torch.device(torch.cuda.current_device()) + else: + device = torch.device("cpu") sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]] = { sharder.module_type: sharder for sharder in sharders } if plan is None: - planner = EmbeddingShardingPlanner( - topology=Topology( - local_world_size=get_local_size(env.world_size), - world_size=env.world_size, - compute_device=device.type, + assert isinstance( + env, ShardingEnv + ), "Currently hybrid sharding only support use manual sharding plan" + if planner is None: + planner = EmbeddingShardingPlanner( + topology=Topology( + local_world_size=get_local_size(env.world_size), + world_size=env.world_size, + compute_device=device.type, + ) ) - ) pg = env.process_group if pg is not None: plan = planner.collective_plan(module, sharders, pg) else: plan = planner.plan(module, sharders) - sharded_param_names: List[str] = [] - if type(module) in sharder_map: # If the top level module is itself a shardable module, return the sharded variant. # Note, we cannot do an inplace replacement in this case. 
- sharded_params = plan.get_plan_for_module("") - if sharded_params is not None: - sharded_module = sharder_map[type(module)].shard( - module, sharded_params, env, device - ) - sharded_param_names.extend([name for name, _ in module.named_parameters()]) - return sharded_module, sharded_param_names + return sharder_map[type(module)].shard( + module, plan.get_plan_for_module(""), env, device, "" + ) def _replace(_model: nn.Module, path: str = "") -> None: for child_name, child in _model.named_children(): @@ -110,22 +268,24 @@ def _replace(_model: nn.Module, path: str = "") -> None: sharded_params = plan.get_plan_for_module(child_path) if sharded_params is not None: sharded_module = sharder_map[type(child)].shard( - child, sharded_params, env, device + child, + sharded_params, + env, + device, + child_path, ) _model.register_module( child_name, sharded_module, ) - - sharded_param_names.extend( - [ - _join_module_path(child_path, name) - for name, _ in child.named_parameters() - ] - ) else: _replace(child, child_path) _replace(module) + if init_params and device is not None and device.type != "meta": + init_parameters(module, device) + module = module.to(device) + + reset_module_states_post_sharding(module) - return module, sharded_param_names + return module diff --git a/torchrec/distributed/sharding/cw_sequence_sharding.py b/torchrec/distributed/sharding/cw_sequence_sharding.py index 01a453ae7..643e1d815 100644 --- a/torchrec/distributed/sharding/cw_sequence_sharding.py +++ b/torchrec/distributed/sharding/cw_sequence_sharding.py @@ -5,10 +5,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, Optional +# pyre-strict + +from typing import Any, Dict, List, Optional import torch -from torchrec.distributed.embedding_lookup import GroupedEmbeddingsLookup +from torchrec.distributed.dist_data import SeqEmbeddingsAllToOne +from torchrec.distributed.embedding_lookup import ( + GroupedEmbeddingsLookup, + InferGroupedEmbeddingsLookup, +) from torchrec.distributed.embedding_sharding import ( BaseEmbeddingDist, BaseEmbeddingLookup, @@ -16,17 +22,24 @@ ) from torchrec.distributed.embedding_types import ( BaseGroupedFeatureProcessor, - SparseFeatures, + InputDistOutputs, ) from torchrec.distributed.sharding.cw_sharding import BaseCwEmbeddingSharding -from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext +from torchrec.distributed.sharding.sequence_sharding import ( + InferSequenceShardingContext, + SequenceShardingContext, +) from torchrec.distributed.sharding.tw_sequence_sharding import TwSequenceEmbeddingDist -from torchrec.distributed.sharding.tw_sharding import TwSparseFeaturesDist +from torchrec.distributed.sharding.tw_sharding import ( + InferTwSparseFeaturesDist, + TwSparseFeaturesDist, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor class CwSequenceEmbeddingSharding( BaseCwEmbeddingSharding[ - SequenceShardingContext, SparseFeatures, torch.Tensor, torch.Tensor + SequenceShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor ] ): """ @@ -37,14 +50,11 @@ class CwSequenceEmbeddingSharding( def create_input_dist( self, device: Optional[torch.device] = None, - ) -> BaseSparseFeaturesDist[SparseFeatures]: + ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: assert self._pg is not None return TwSparseFeaturesDist( self._pg, - self.id_list_features_per_rank(), - self.id_score_list_features_per_rank(), - device if device is not None else 
self._device, - variable_batch_size=self._variable_batch_size, + self.features_per_rank(), ) def create_lookup( @@ -67,7 +77,75 @@ def create_output_dist( assert self._pg is not None return TwSequenceEmbeddingDist( self._pg, - self.id_list_features_per_rank(), + self.features_per_rank(), device if device is not None else self._device, qcomm_codecs_registry=self.qcomm_codecs_registry, ) + + +class InferCwSequenceEmbeddingSharding( + BaseCwEmbeddingSharding[ + InferSequenceShardingContext, + InputDistOutputs, + List[torch.Tensor], + List[torch.Tensor], + ] +): + def create_input_dist( + self, device: Optional[torch.device] = None + ) -> BaseSparseFeaturesDist[InputDistOutputs]: + return InferTwSparseFeaturesDist( + features_per_rank=self.features_per_rank(), + world_size=self._world_size, + device=device if device is not None else self._device, + ) + + def create_lookup( + self, + device: Optional[torch.device] = None, + fused_params: Optional[Dict[str, Any]] = None, + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]]: + return InferGroupedEmbeddingsLookup( + grouped_configs_per_rank=self._grouped_embedding_configs_per_rank, + world_size=self._world_size, + fused_params=fused_params, + device=device if device is not None else self._device, + ) + + def create_output_dist( + self, device: Optional[torch.device] = None + ) -> BaseEmbeddingDist[ + InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor] + ]: + device = device if device is not None else self._device + assert device is not None + + dist_out = InferCwSequenceEmbeddingDist( + device, + self._world_size, + ) + return dist_out + + +class InferCwSequenceEmbeddingDist( + BaseEmbeddingDist[ + InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor] + ] +): + def __init__( + self, + device: torch.device, + world_size: int, + ) -> None: + super().__init__() + self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne( + device=device, world_size=world_size + ) + + def forward( + self, + local_embs: List[torch.Tensor], + sharding_ctx: Optional[InferSequenceShardingContext] = None, + ) -> List[torch.Tensor]: + return self._dist(local_embs) diff --git a/torchrec/distributed/sharding/cw_sharding.py b/torchrec/distributed/sharding/cw_sharding.py index d6e48b475..90a2e6bef 100644 --- a/torchrec/distributed/sharding/cw_sharding.py +++ b/torchrec/distributed/sharding/cw_sharding.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
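# --- Illustrative sketch, not part of this patch: the inference "all-to-one" output step
# that InferCwSequenceEmbeddingDist delegates to SeqEmbeddingsAllToOne above. Assumption
# (hedged): for sequence embeddings the op co-locates each per-rank lookup output on a
# single target device without concatenating them; the real module may use optimized
# device-to-device copies. The helper name below is hypothetical.
from typing import List

import torch


def seq_embeddings_all_to_one(
    local_embs_per_rank: List[torch.Tensor], device: torch.device
) -> List[torch.Tensor]:
    # One tensor per rank (one CW shard each); keep them separate, just move to `device`.
    return [t.to(device, non_blocking=True) for t in local_embs_per_rank]


# Two CW shards of a sequence embedding lookup: 4 ids in the batch, 8-dim columns each.
outs = seq_embeddings_all_to_one(
    [torch.randn(4, 8), torch.randn(4, 8)], torch.device("cpu")
)
assert all(t.device.type == "cpu" for t in outs)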
+# pyre-strict + from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar import torch @@ -12,7 +14,12 @@ from fbgemm_gpu.permute_pooled_embedding_modules_split import ( PermutePooledEmbeddingsSplit, ) -from torchrec.distributed.embedding_lookup import GroupedPooledEmbeddingsLookup +from torch.distributed._tensor import Replicate, Shard +from torchrec.distributed.dist_data import EmbeddingsAllToOne +from torchrec.distributed.embedding_lookup import ( + GroupedPooledEmbeddingsLookup, + InferGroupedPooledEmbeddingsLookup, +) from torchrec.distributed.embedding_sharding import ( BaseEmbeddingDist, BaseEmbeddingLookup, @@ -22,21 +29,27 @@ ) from torchrec.distributed.embedding_types import ( BaseGroupedFeatureProcessor, + DTensorMetadata, EmbeddingComputeKernel, + InputDistOutputs, ShardedEmbeddingTable, - SparseFeatures, ) from torchrec.distributed.sharding.tw_sharding import ( BaseTwEmbeddingSharding, + InferTwSparseFeaturesDist, TwPooledEmbeddingDist, TwSparseFeaturesDist, ) from torchrec.distributed.types import ( + NullShardingContext, QuantizedCommCodecs, ShardedTensorMetadata, ShardingEnv, + ShardingType, ShardMetadata, ) +from torchrec.distributed.utils import none_throws +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.streamable import Multistreamable C = TypeVar("C", bound=Multistreamable) @@ -57,14 +70,12 @@ def __init__( device: Optional[torch.device] = None, permute_embeddings: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, - variable_batch_size: bool = False, ) -> None: super().__init__( sharding_infos, env, device, qcomm_codecs_registry=qcomm_codecs_registry, - variable_batch_size=variable_batch_size, ) self._permute_embeddings = permute_embeddings if self._permute_embeddings: @@ -82,9 +93,9 @@ def _init_combined_embeddings(self) -> None: embedding_names: List[str] = super().embedding_names() embedding_dims: List[int] = super().embedding_dims() - embedding_shard_metadata: List[ - Optional[ShardMetadata] - ] = super().embedding_shard_metadata() + embedding_shard_metadata: List[Optional[ShardMetadata]] = ( + super().embedding_shard_metadata() + ) embedding_name_to_index_offset_tuples: Dict[str, List[Tuple[int, int]]] = {} for i, (name, metadata) in enumerate( @@ -134,10 +145,9 @@ def _shard( self, sharding_infos: List[EmbeddingShardingInfo], ) -> List[List[ShardedEmbeddingTable]]: - # pyre-fixme[16]: `Optional` has no attribute `size`. 
- world_size = self._pg.size() + world_size: int = self._world_size tables_per_rank: List[List[ShardedEmbeddingTable]] = [ - [] for i in range(world_size) + [] for _ in range(world_size) ] for info in sharding_infos: # pyre-fixme [16] @@ -148,14 +158,48 @@ def _shard( shards_metadata=shards, size=torch.Size( [ - info.embedding_config.num_embeddings, + ( + info.embedding_config.num_embeddings_post_pruning + if info.embedding_config.num_embeddings_post_pruning + is not None + else info.embedding_config.num_embeddings + ), info.embedding_config.embedding_dim, ] ), ) + dtensor_metadata = None + if self._env.output_dtensor: + dtensor_metadata = DTensorMetadata( + mesh=self._env.device_mesh, + placements=( + (Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),) + ), + size=( + ( + info.embedding_config.num_embeddings_post_pruning + if info.embedding_config.num_embeddings_post_pruning + is not None + else info.embedding_config.num_embeddings + ), + info.embedding_config.embedding_dim, + ), + stride=info.param.stride(), + ) + # pyre-fixme [6] for i, rank in enumerate(info.param_sharding.ranks): + # Remap rank by number of replica groups if 2D parallelism is enabled + rank = ( + # pyre-ignore[16] + self._env.remap_rank( + rank, + ShardingType.COLUMN_WISE, + ) + if self._is_2D_parallel + else rank + ) tables_per_rank[rank].append( ShardedEmbeddingTable( num_embeddings=info.embedding_config.num_embeddings, @@ -167,16 +211,25 @@ def _shard( pooling=info.embedding_config.pooling, is_weighted=info.embedding_config.is_weighted, has_feature_processor=info.embedding_config.has_feature_processor, - local_rows=info.embedding_config.num_embeddings, + local_rows=( + none_throws( + info.embedding_config.num_embeddings_post_pruning + ) + if info.embedding_config.num_embeddings_post_pruning + is not None + else info.embedding_config.num_embeddings + ), local_cols=shards[i].shard_sizes[1], compute_kernel=EmbeddingComputeKernel( info.param_sharding.compute_kernel ), local_metadata=shards[i], global_metadata=global_metadata, + dtensor_metadata=dtensor_metadata, fused_params=info.fused_params, weight_init_max=info.embedding_config.weight_init_max, weight_init_min=info.embedding_config.weight_init_min, + num_embeddings_post_pruning=info.embedding_config.num_embeddings_post_pruning, ) ) @@ -186,20 +239,26 @@ def embedding_dims(self) -> List[int]: return ( self._combined_embedding_dims if self._permute_embeddings - else super().embedding_dims() + else self.uncombined_embedding_dims() ) def embedding_names(self) -> List[str]: return ( self._combined_embedding_names if self._permute_embeddings - else super().embedding_names() + else self.uncombined_embedding_names() ) + def uncombined_embedding_dims(self) -> List[int]: + return super().embedding_dims() + + def uncombined_embedding_names(self) -> List[str]: + return super().embedding_names() + class CwPooledEmbeddingSharding( BaseCwEmbeddingSharding[ - EmbeddingShardingContext, SparseFeatures, torch.Tensor, torch.Tensor + EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor ] ): """ @@ -210,14 +269,11 @@ class CwPooledEmbeddingSharding( def create_input_dist( self, device: Optional[torch.device] = None, - ) -> BaseSparseFeaturesDist[SparseFeatures]: + ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: assert self._pg is not None return TwSparseFeaturesDist( self._pg, - self.id_list_features_per_rank(), - self.id_score_list_features_per_rank(), - device if device is not None else self._device, - self._variable_batch_size, + 
self.features_per_rank(), ) def create_lookup( @@ -228,7 +284,6 @@ def create_lookup( ) -> BaseEmbeddingLookup: return GroupedPooledEmbeddingsLookup( grouped_configs=self._grouped_embedding_configs, - grouped_score_configs=self._score_grouped_embedding_configs, pg=self._pg, device=device if device is not None else self._device, feature_processor=feature_processor, @@ -246,15 +301,122 @@ def create_output_dist( ): assert len(self._embedding_order) == len(self._embedding_dims) embedding_permute_op = PermutePooledEmbeddingsSplit( - self._embedding_dims, - self._embedding_order, - ).to(device=device) + self._embedding_dims, self._embedding_order, device=device + ) callbacks = [embedding_permute_op] assert self._pg is not None return TwPooledEmbeddingDist( - self._pg, - self._dim_sum_per_rank(), - device, - callbacks, + pg=self._pg, + dim_sum_per_rank=self._dim_sum_per_rank(), + emb_dim_per_rank_per_feature=self._emb_dim_per_rank_per_feature(), + device=device, + callbacks=callbacks, qcomm_codecs_registry=self.qcomm_codecs_registry, ) + + +class InferCwPooledEmbeddingSharding( + BaseCwEmbeddingSharding[ + NullShardingContext, InputDistOutputs, List[torch.Tensor], torch.Tensor + ] +): + def create_input_dist( + self, device: Optional[torch.device] = None + ) -> BaseSparseFeaturesDist[InputDistOutputs]: + return InferTwSparseFeaturesDist( + self.features_per_rank(), + self._world_size, + device if device is not None else self._device, + ) + + def create_lookup( + self, + device: Optional[torch.device] = None, + fused_params: Optional[Dict[str, Any]] = None, + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]]: + return InferGroupedPooledEmbeddingsLookup( + grouped_configs_per_rank=self._grouped_embedding_configs_per_rank, + world_size=self._world_size, + fused_params=fused_params, + device=device if device is not None else self._device, + ) + + def create_output_dist( + self, + device: Optional[torch.device] = None, + ) -> BaseEmbeddingDist[NullShardingContext, List[torch.Tensor], torch.Tensor]: + device = device if device is not None else self._device + assert device is not None + + dist_out = InferCwPooledEmbeddingDist( + device, + self._world_size, + ) + + if self._permute_embeddings and self._embedding_order != list( + range(len(self._embedding_order)) + ): + return InferCwPooledEmbeddingDistWithPermute( + device, self._world_size, self._embedding_dims, self._embedding_order + ) + + return dist_out + + +class InferCwPooledEmbeddingDist( + BaseEmbeddingDist[NullShardingContext, List[torch.Tensor], torch.Tensor] +): + def __init__( + self, + device: torch.device, + world_size: int, + ) -> None: + super().__init__() + self._dist: EmbeddingsAllToOne = EmbeddingsAllToOne( + device=device, world_size=world_size, cat_dim=1 + ) + + def forward( + self, + local_embs: List[torch.Tensor], + sharding_ctx: Optional[NullShardingContext] = None, + ) -> torch.Tensor: + return self._dist.forward( + local_embs, + ) + + +@torch.fx.wrap +def _fx_wrap_permute( + permute_module: PermutePooledEmbeddingsSplit, input: torch.Tensor +) -> torch.Tensor: + return permute_module.forward(input) + + +class InferCwPooledEmbeddingDistWithPermute( + BaseEmbeddingDist[NullShardingContext, List[torch.Tensor], torch.Tensor] +): + def __init__( + self, + device: torch.device, + world_size: int, + embedding_dims: List[int], + permute: List[int], + ) -> None: + super().__init__() + self._dist: EmbeddingsAllToOne = EmbeddingsAllToOne( + device=device, 
world_size=world_size, cat_dim=1 + ) + self._permute: PermutePooledEmbeddingsSplit = PermutePooledEmbeddingsSplit( + embs_dims=embedding_dims, + permute=permute, + device=device, + ) + + def forward( + self, + local_embs: List[torch.Tensor], + sharding_ctx: Optional[NullShardingContext] = None, + ) -> torch.Tensor: + return self._permute(self._dist(local_embs)) diff --git a/torchrec/distributed/sharding/dp_sequence_sharding.py b/torchrec/distributed/sharding/dp_sequence_sharding.py index 0ae976c82..2ad11b247 100644 --- a/torchrec/distributed/sharding/dp_sequence_sharding.py +++ b/torchrec/distributed/sharding/dp_sequence_sharding.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Any, Dict, Optional import torch @@ -14,16 +16,14 @@ BaseEmbeddingLookup, BaseSparseFeaturesDist, ) -from torchrec.distributed.embedding_types import ( - BaseGroupedFeatureProcessor, - SparseFeatures, -) +from torchrec.distributed.embedding_types import BaseGroupedFeatureProcessor from torchrec.distributed.sharding.dp_sharding import ( BaseDpEmbeddingSharding, DpSparseFeaturesDist, ) from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext from torchrec.distributed.types import Awaitable, NoWait +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor class DpSequenceEmbeddingDist( @@ -56,7 +56,7 @@ def forward( class DpSequenceEmbeddingSharding( BaseDpEmbeddingSharding[ - SequenceShardingContext, SparseFeatures, torch.Tensor, torch.Tensor + SequenceShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor ] ): """ @@ -66,7 +66,7 @@ class DpSequenceEmbeddingSharding( def create_input_dist( self, device: Optional[torch.device] = None - ) -> BaseSparseFeaturesDist[SparseFeatures]: + ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: return DpSparseFeaturesDist() def create_lookup( diff --git a/torchrec/distributed/sharding/dp_sharding.py b/torchrec/distributed/sharding/dp_sharding.py index e6875173f..6ffb52e4c 100644 --- a/torchrec/distributed/sharding/dp_sharding.py +++ b/torchrec/distributed/sharding/dp_sharding.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
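# --- Illustrative sketch, not part of this patch: what the CW pooled inference output path
# above does conceptually. Per-rank pooled outputs (one column-wise shard each) are
# concatenated along dim 1 ("all-to-one" with cat_dim=1) and then permuted so the columns
# of each logical table end up contiguous and in the original embedding order. The
# permute_pooled_by_dims helper is a hypothetical stand-in for PermutePooledEmbeddingsSplit.
from typing import List

import torch


def permute_pooled_by_dims(
    x: torch.Tensor, dims: List[int], order: List[int]
) -> torch.Tensor:
    # Split the concatenated pooled tensor into per-shard column groups, then reorder them.
    groups = list(torch.split(x, dims, dim=1))
    return torch.cat([groups[i] for i in order], dim=1)


# table_0 is CW-sharded into two 4-dim column shards on ranks 0 and 1; table_1 (8-dim)
# lives entirely on rank 0, so the concatenation order is:
# [table_0 shard 0, table_1, table_0 shard 1].
rank0_out = torch.randn(2, 4 + 8)  # batch size 2
rank1_out = torch.randn(2, 4)
cat = torch.cat([rank0_out, rank1_out], dim=1)  # the cat_dim=1 "all-to-one" step
dims, order = [4, 8, 4], [0, 2, 1]              # regroup table_0's two shards together
out = permute_pooled_by_dims(cat, dims, order)
assert out.shape == (2, 16)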
+# pyre-strict + from typing import Any, cast, Dict, List, Optional, TypeVar import torch @@ -23,9 +25,9 @@ EmbeddingComputeKernel, GroupedEmbeddingConfig, ShardedEmbeddingTable, - SparseFeatures, ) from torchrec.distributed.types import Awaitable, NoWait, ShardingEnv, ShardMetadata +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.streamable import Multistreamable @@ -52,22 +54,13 @@ def __init__( self._rank: int = self._env.rank self._world_size: int = self._env.world_size sharded_tables_per_rank = self._shard(sharding_infos) - self._grouped_embedding_configs_per_rank: List[ - List[GroupedEmbeddingConfig] - ] = [] - self._score_grouped_embedding_configs_per_rank: List[ - List[GroupedEmbeddingConfig] - ] = [] - ( - self._grouped_embedding_configs_per_rank, - self._score_grouped_embedding_configs_per_rank, - ) = group_tables(sharded_tables_per_rank) - self._grouped_embedding_configs: List[ - GroupedEmbeddingConfig - ] = self._grouped_embedding_configs_per_rank[env.rank] - self._score_grouped_embedding_configs: List[ - GroupedEmbeddingConfig - ] = self._score_grouped_embedding_configs_per_rank[env.rank] + self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = ( + [] + ) + self._grouped_embedding_configs_per_rank = group_tables(sharded_tables_per_rank) + self._grouped_embedding_configs: List[GroupedEmbeddingConfig] = ( + self._grouped_embedding_configs_per_rank[env.rank] + ) def _shard( self, @@ -108,16 +101,12 @@ def embedding_dims(self) -> List[int]: embedding_dims = [] for grouped_config in self._grouped_embedding_configs: embedding_dims.extend(grouped_config.embedding_dims()) - for grouped_config in self._score_grouped_embedding_configs: - embedding_dims.extend(grouped_config.embedding_dims()) return embedding_dims def embedding_names(self) -> List[str]: embedding_names = [] for grouped_config in self._grouped_embedding_configs: embedding_names.extend(grouped_config.embedding_names()) - for grouped_config in self._score_grouped_embedding_configs: - embedding_names.extend(grouped_config.embedding_names()) return embedding_names def embedding_names_per_rank(self) -> List[List[str]]: @@ -127,24 +116,22 @@ def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: embedding_shard_metadata = [] for grouped_config in self._grouped_embedding_configs: embedding_shard_metadata.extend(grouped_config.embedding_shard_metadata()) - for grouped_config in self._score_grouped_embedding_configs: - embedding_shard_metadata.extend(grouped_config.embedding_shard_metadata()) return embedding_shard_metadata - def id_list_feature_names(self) -> List[str]: - id_list_feature_names = [] + def feature_names(self) -> List[str]: + feature_names = [] for grouped_config in self._grouped_embedding_configs: - id_list_feature_names.extend(grouped_config.feature_names()) - return id_list_feature_names + feature_names.extend(grouped_config.feature_names()) + return feature_names - def id_score_list_feature_names(self) -> List[str]: - id_score_list_feature_names = [] - for grouped_config in self._score_grouped_embedding_configs: - id_score_list_feature_names.extend(grouped_config.feature_names()) - return id_score_list_feature_names + def embedding_tables(self) -> List[ShardedEmbeddingTable]: + embedding_tables = [] + for grouped_config in self._grouped_embedding_configs: + embedding_tables.extend(grouped_config.embedding_tables) + return embedding_tables -class DpSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeatures]): +class 
DpSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]): """ Distributes sparse features (input) to be data-parallel. """ @@ -154,8 +141,8 @@ def __init__(self) -> None: def forward( self, - sparse_features: SparseFeatures, - ) -> Awaitable[Awaitable[SparseFeatures]]: + sparse_features: KeyedJaggedTensor, + ) -> Awaitable[Awaitable[KeyedJaggedTensor]]: """ No-op as sparse features are already distributed in data-parallel fashion. @@ -166,7 +153,7 @@ def forward( Awaitable[Awaitable[SparseFeatures]]: awaitable of awaitable of SparseFeatures. """ - return NoWait(cast(Awaitable[SparseFeatures], NoWait(sparse_features))) + return NoWait(cast(Awaitable[KeyedJaggedTensor], NoWait(sparse_features))) class DpPooledEmbeddingDist( @@ -199,7 +186,7 @@ def forward( class DpPooledEmbeddingSharding( BaseDpEmbeddingSharding[ - EmbeddingShardingContext, SparseFeatures, torch.Tensor, torch.Tensor + EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor ] ): """ @@ -209,7 +196,7 @@ class DpPooledEmbeddingSharding( def create_input_dist( self, device: Optional[torch.device] = None - ) -> BaseSparseFeaturesDist[SparseFeatures]: + ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: return DpSparseFeaturesDist() def create_lookup( @@ -220,10 +207,12 @@ def create_lookup( ) -> BaseEmbeddingLookup: return GroupedPooledEmbeddingsLookup( grouped_configs=self._grouped_embedding_configs, - grouped_score_configs=self._score_grouped_embedding_configs, pg=self._env.process_group, device=device if device is not None else self._device, feature_processor=feature_processor, + # For data parallel we need to turn always gradient scaling in for weights + # because get_gradient_scaling from comm_ops only affects model_parallel tables, not DP + scale_weight_gradients=False, ) def create_output_dist( diff --git a/torchrec/distributed/sharding/dynamic_sharding.py b/torchrec/distributed/sharding/dynamic_sharding.py new file mode 100644 index 000000000..420c0ea24 --- /dev/null +++ b/torchrec/distributed/sharding/dynamic_sharding.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, Callable, Dict, List, Tuple + +import torch +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor import Shard +from torchrec.distributed.types import ( + ParameterSharding, + ShardedModule, + ShardedTensor, + ShardingEnv, +) + + +def shards_all_to_all( + module: ShardedModule[Any, Any, Any, Any], # pyre-ignore + state_dict: Dict[str, ShardedTensor], + device: torch.device, + changed_sharding_params: Dict[str, ParameterSharding], + env: ShardingEnv, + extend_shard_name: Callable[[str], str] = lambda x: x, +) -> Tuple[List[Tuple[str, int]], torch.Tensor]: + """ + Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters. + Assumes ranks are ordered in ParameterSharding.ranks. + + Args: + module (ShardedModule[Any, Any, Any, Any]): The module containing sharded tensors to be redistributed. + TODO: Update to support more modules, currently only supports ShardedEmbeddingBagCollection. + + state_dict (Dict[str, ShardedTensor]): The state dictionary containing the current sharded tensors. + + device (torch.device): The device on which the output tensors will be placed. 
+ + changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping shard names to their new sharding parameters. + + env (ShardingEnv): The sharding environment containing world size and other distributed information. + + extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict. + + Returns: + Tuple[List[Tuple[str, int]], torch.Tensor]: A tuple containing: + - A list of shard name and the corresponding shard_size in dim 1 that were sent to the current rank. + This is a flattened and pruned nested list, which orders the shards names and sizes by rank, then shard order. + - The tensor containing all shards received by the current rank after the all-to-all operation. + """ + if env.output_dtensor: + raise RuntimeError("We do not yet support DTensor for resharding yet") + return + + # Module sharding plan is used to get the source ranks for each shard + assert hasattr(module, "module_sharding_plan") + + world_size = env.world_size + rank = dist.get_rank() + input_splits_per_rank = [[0] * world_size for _ in range(world_size)] + output_splits_per_rank = [[0] * world_size for _ in range(world_size)] + + # 0 by default, as current rank may be recieving 0 shards + num_embeddings_received = 0 + output_tensor_tensor_count = 0 + shard_names_to_lengths_by_src_rank = [[] for _ in range(world_size)] + local_table_to_input_tensor_by_dst_rank = [[] for _ in range(world_size)] + for shard_name, param in changed_sharding_params.items(): + sharded_t = state_dict[extend_shard_name(shard_name)] + assert param.ranks is not None + dst_ranks = param.ranks + # pyre-ignore + src_ranks = module.module_sharding_plan[shard_name].ranks + + # TODO: Implement changing rank sizes for beyond TW sharding + assert len(dst_ranks) == len(src_ranks) + + # index needed to distinguish between multiple shards + # within the same shardedTensor for each table + for i in range(len(src_ranks)): + dst_rank = dst_ranks[i] + src_rank = src_ranks[i] + + shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes + shard_size_dim_1 = shard_size[1] + input_splits_per_rank[src_rank][dst_rank] += shard_size_dim_1 + output_splits_per_rank[dst_rank][src_rank] += shard_size_dim_1 + if src_rank == rank: + local_shards = sharded_t.local_shards() + assert len(local_shards) == 1 + local_table_to_input_tensor_by_dst_rank[dst_rank].append( + sharded_t.local_shards()[0].tensor + ) + if dst_rank == rank: + shard_names_to_lengths_by_src_rank[src_rank].append( + (shard_name, shard_size_dim_1) + ) + # NOTE: Only need to update num_embeddings_received to be the + # num_embeddings of shards if this rank is actually recieving + # any tensors + if num_embeddings_received == 0: + num_embeddings_received = shard_size[0] + else: + # TODO: for 2D and row-wise, shard_sizes in dim 0 may be variable + # For now, assume that shard_sizes in dim 0 are all the same + assert num_embeddings_received == shard_size[0] + output_tensor_tensor_count += shard_size[1] + + local_input_splits = input_splits_per_rank[rank] + local_output_splits = output_splits_per_rank[rank] + + local_input_tensor = torch.empty([0], device=device) + for sub_l in local_table_to_input_tensor_by_dst_rank: + for shard_info in sub_l: + local_input_tensor = torch.cat( + ( + local_input_tensor, + shard_info, + ), + dim=1, + ) + + # Transposing the Tensors - because we are concatenating them along dimension 1 + # This is because dim 0 size may be different for different shards + # whereas dim 1 size is the same for all shards as 
dim 1 size = num_embeddings per table + local_output_tensor = torch.empty( + [output_tensor_tensor_count, num_embeddings_received], device=device + ) + local_input_tensor = local_input_tensor.T.contiguous() + + assert sum(local_output_splits) == len(local_output_tensor) + assert sum(local_input_splits) == len(local_input_tensor) + dist.all_to_all_single( + output=local_output_tensor, + input=local_input_tensor, + output_split_sizes=local_output_splits, + input_split_sizes=local_input_splits, + group=env.process_group, # TODO: 2D uses env.sharding_pg + ) + + flattened_output_names_lengths = [ + shard_info + for sub_l in shard_names_to_lengths_by_src_rank + for shard_info in sub_l + ] + + return flattened_output_names_lengths, local_output_tensor + + +def update_state_dict_post_resharding( + state_dict: Dict[str, ShardedTensor], + ordered_shard_names_and_lengths: List[Tuple[str, int]], + output_tensor: torch.Tensor, + new_sharding_params: Dict[str, ParameterSharding], + curr_rank: int, + extend_shard_name: Callable[[str], str] = lambda x: x, +) -> Dict[str, ShardedTensor]: + """ + Updates and returns the given state_dict with new placements and + local_shards based on the output tensor of the AllToAll collective. + + Args: + state_dict (Dict[str, Any]): The state dict to be updated with new shard placements and local shards. + + shard_names_by_src_rank (List[Tuple[str, int]]): A list of shard name and the corresponding shard_size in dim 1 + that were sent to the current rank. This is a flattened and pruned nested list, which orders the shards names and + sizes by rank, then shard order. + + output_tensor (torch.Tensor): The tensor containing the output data from the AllToAll operation. + + new_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping shard names to their new sharding parameters. + This should only contain shard names that were updated during the AllToAll operation. + + curr_rank (int): The current rank of the process in the distributed environment. + + extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict. + + Returns: + Dict[str, ShardedTensor]: The updated state dictionary with new shard placements and local shards. + """ + slice_index = 0 + + shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {} + + for shard_name, shard_size in ordered_shard_names_and_lengths: + end_slice_index = slice_index + shard_size + shard_name_to_local_output_tensor[shard_name] = output_tensor[ + slice_index:end_slice_index + ].T + slice_index = end_slice_index + + for shard_name, param in new_sharding_params.items(): + extended_name = extend_shard_name(shard_name) + # pyre-ignore + for i in range(len(param.ranks)): + # pyre-ignore + r = param.ranks[i] + sharded_t = state_dict[extended_name] + # Update placements + sharded_t.metadata().shards_metadata[i].placement = ( + torch.distributed._remote_device(f"rank:{r}/cuda:{r}") + ) + if r == curr_rank: + assert len(output_tensor) > 0 + # slice output tensor for correct size. 
+ sharded_t._local_shards = [ + Shard( + tensor=shard_name_to_local_output_tensor[shard_name], + metadata=state_dict[extended_name] + .metadata() + .shards_metadata[i], + ) + ] + break + else: + sharded_t._local_shards = [] + + return state_dict + + +def update_module_sharding_plan( + module: ShardedModule[Any, Any, Any, Any], # pyre-ignore + changed_sharding_params: Dict[str, ParameterSharding], +) -> None: + if not hasattr(module, "module_sharding_plan"): + return + + # pyre-ignore + current_plan: Dict[str, ParameterSharding] = module.module_sharding_plan + for table_name, param_sharding in changed_sharding_params.items(): + current_plan[table_name] = param_sharding + return diff --git a/torchrec/distributed/sharding/grid_sharding.py b/torchrec/distributed/sharding/grid_sharding.py new file mode 100644 index 000000000..88edbbe87 --- /dev/null +++ b/torchrec/distributed/sharding/grid_sharding.py @@ -0,0 +1,534 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, TypeVar, Union + +import torch +import torch.distributed as dist +from fbgemm_gpu.permute_pooled_embedding_modules_split import ( + PermutePooledEmbeddingsSplit, +) +from torch.distributed._tensor import Replicate, Shard +from torchrec.distributed.comm import ( + get_local_size, + intra_and_cross_node_pg, + intra_and_cross_node_pg_2D, +) +from torchrec.distributed.dist_data import ( + PooledEmbeddingsAllToAll, + PooledEmbeddingsReduceScatter, + VariableBatchPooledEmbeddingsAllToAll, + VariableBatchPooledEmbeddingsReduceScatter, +) +from torchrec.distributed.embedding_lookup import GroupedPooledEmbeddingsLookup +from torchrec.distributed.embedding_sharding import ( + BaseEmbeddingDist, + BaseEmbeddingLookup, + BaseSparseFeaturesDist, + EmbeddingSharding, + EmbeddingShardingContext, + EmbeddingShardingInfo, + group_tables, +) +from torchrec.distributed.embedding_types import ( + BaseGroupedFeatureProcessor, + DTensorMetadata, + EmbeddingComputeKernel, + GroupedEmbeddingConfig, + ShardedEmbeddingTable, +) +from torchrec.distributed.sharding.twrw_sharding import TwRwSparseFeaturesDist +from torchrec.distributed.types import ( + Awaitable, + CommOp, + QuantizedCommCodecs, + ShardedTensorMetadata, + ShardingEnv, + ShardingEnv2D, + ShardingType, + ShardMetadata, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.streamable import Multistreamable + +C = TypeVar("C", bound=Multistreamable) +F = TypeVar("F", bound=Multistreamable) +T = TypeVar("T") +W = TypeVar("W") + + +class BaseGridEmbeddingSharding(EmbeddingSharding[C, F, T, W]): + """ + Base class for grid sharding. 
+ """ + + def __init__( + self, + sharding_infos: List[EmbeddingShardingInfo], + env: ShardingEnv, + device: Optional[torch.device] = None, + need_pos: bool = False, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + self._env: ShardingEnv = env + self._is_2D_parallel: bool = isinstance(env, ShardingEnv2D) + self._pg: Optional[dist.ProcessGroup] = ( + # pyre-ignore[16] + self._env.sharding_pg + if self._is_2D_parallel + else self._env.process_group + ) + self._world_size: int = self._env.world_size + self._rank: int = self._env.rank + self._device = device + self._need_pos = need_pos + self._embedding_names: List[str] = [] + self._embedding_dims: List[int] = [] + self._embedding_order: List[int] = [] + + self._combined_embedding_names: List[str] = [] + self._combined_embedding_dims: List[int] = [] + + if self._is_2D_parallel: + intra_pg, cross_pg = intra_and_cross_node_pg_2D( + # pyre-fixme[6] + self._env, + device=device, + ) + else: + intra_pg, cross_pg = intra_and_cross_node_pg( + device, backend=dist.get_backend(self._pg) + ) + self._intra_pg: Optional[dist.ProcessGroup] = intra_pg + self._cross_pg: Optional[dist.ProcessGroup] = cross_pg + self._local_size: int = ( + intra_pg.size() if intra_pg else get_local_size(self._world_size) + ) + + sharded_tables_per_rank = self._shard(sharding_infos) + self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = ( + [] + ) + self._grouped_embedding_configs_per_node: List[List[GroupedEmbeddingConfig]] = ( + [] + ) + self._grouped_embedding_configs_per_rank = group_tables(sharded_tables_per_rank) + self._grouped_embedding_configs_per_node = [ + self._grouped_embedding_configs_per_rank[rank] + for rank in range(self._world_size) + if rank % self._local_size == 0 + ] + self._has_feature_processor: bool = False + for group_config in self._grouped_embedding_configs_per_rank[ + self._rank // self._local_size + ]: + if group_config.has_feature_processor: + self._has_feature_processor = True + + self._init_combined_embeddings() + + def _init_combined_embeddings(self) -> None: + """ + Initializes combined embeddings, similar to the CW sharding implementation, + but in this case the CW shard is treated on a per node basis and not per rank. 
+ """ + embedding_names = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: + for grouped_config in grouped_embedding_configs: + embedding_names.extend(grouped_config.embedding_names()) + + embedding_dims = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: + for grouped_config in grouped_embedding_configs: + embedding_dims.extend(grouped_config.embedding_dims()) + + embedding_shard_metadata = self.embedding_shard_metadata() + + embedding_name_to_index_offset_tuples: Dict[str, List[Tuple[int, int]]] = {} + for i, (name, metadata) in enumerate( + zip(embedding_names, embedding_shard_metadata) + ): + if name not in embedding_name_to_index_offset_tuples: + embedding_name_to_index_offset_tuples[name] = [] + # find index of each of the offset by column (CW sharding so only col dim changes) + embedding_name_to_index_offset_tuples[name].append( + (i, metadata.shard_offsets[1] if metadata is not None else 0) + ) + + # sort the index offset tuples by offset and then grab the associated index + embedding_name_to_index: Dict[str, List[int]] = {} + for name, index_offset_tuples in embedding_name_to_index_offset_tuples.items(): + embedding_name_to_index[name] = [ + idx_off_tuple[0] + for idx_off_tuple in sorted( + index_offset_tuples, + key=lambda idx_off_tuple: idx_off_tuple[1], + ) + ] + + combined_embedding_names: List[str] = [] + seen_embedding_names: Set[str] = set() + + for name in embedding_names: + if name not in seen_embedding_names: + combined_embedding_names.append(name) + seen_embedding_names.add(name) + + combined_embedding_dims: List[int] = [] + + embedding_order: List[int] = [] + for name in combined_embedding_names: + combined_embedding_dims.append( + sum([embedding_dims[idx] for idx in embedding_name_to_index[name]]) + ) + embedding_order.extend(embedding_name_to_index[name]) + + self._embedding_names: List[str] = embedding_names + self._embedding_dims: List[int] = embedding_dims + self._embedding_order: List[int] = embedding_order + + self._combined_embedding_names: List[str] = combined_embedding_names + self._combined_embedding_dims: List[int] = combined_embedding_dims + + def _shard( + self, + sharding_infos: List[EmbeddingShardingInfo], + ) -> List[List[ShardedEmbeddingTable]]: + """ + Shards the embedding tables. + This method takes the sharding infos and returns a list of lists of + sharded embedding tables, where each inner list represents the tables + for a specific rank. + + Args: + sharding_infos (List[EmbeddingShardingInfo]): The sharding infos. + Returns: + List[List[ShardedEmbeddingTable]]: The sharded embedding tables. 
+ """ + world_size = self._world_size + tables_per_rank: List[List[ShardedEmbeddingTable]] = [ + [] for _ in range(world_size) + ] + for info in sharding_infos: + # pyre-fixme [16] + shards = info.param_sharding.sharding_spec.shards + + # construct the global sharded_tensor_metadata + global_metadata = ShardedTensorMetadata( + shards_metadata=shards, + size=torch.Size( + [ + info.embedding_config.num_embeddings, + info.embedding_config.embedding_dim, + ] + ), + ) + + dtensor_metadata = None + if self._env.output_dtensor: + placements = ( + (Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),) + ) + dtensor_metadata = DTensorMetadata( + mesh=self._env.device_mesh, + placements=placements, + size=( + info.embedding_config.num_embeddings, + info.embedding_config.embedding_dim, + ), + stride=info.param.stride(), + ) + + # Expectation is planner CW shards across a node, so each CW shard will have local_size number of row shards + # pyre-fixme [6] + for i, rank in enumerate(info.param_sharding.ranks): + rank = ( + # pyre-ignore[16] + self._env.remap_rank(rank, ShardingType.GRID_SHARD) + if self._is_2D_parallel + else rank + ) + tables_per_rank[rank].append( + ShardedEmbeddingTable( + num_embeddings=info.embedding_config.num_embeddings, + embedding_dim=info.embedding_config.embedding_dim, + name=info.embedding_config.name, + embedding_names=info.embedding_config.embedding_names, + data_type=info.embedding_config.data_type, + feature_names=info.embedding_config.feature_names, + pooling=info.embedding_config.pooling, + is_weighted=info.embedding_config.is_weighted, + has_feature_processor=info.embedding_config.has_feature_processor, + local_rows=shards[i].shard_sizes[0], + local_cols=shards[i].shard_sizes[1], + compute_kernel=EmbeddingComputeKernel( + info.param_sharding.compute_kernel + ), + local_metadata=shards[i], + global_metadata=global_metadata, + dtensor_metadata=dtensor_metadata, + weight_init_max=info.embedding_config.weight_init_max, + weight_init_min=info.embedding_config.weight_init_min, + fused_params=info.fused_params, + ) + ) + + return tables_per_rank + + def embedding_dims(self) -> List[int]: + return self._combined_embedding_dims + + def embedding_names(self) -> List[str]: + return self._combined_embedding_names + + def embedding_names_per_rank(self) -> List[List[str]]: + raise NotImplementedError + + def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: + embedding_shard_metadata = [] + for grouped_config in self._grouped_embedding_configs_per_node: + for config in grouped_config: + embedding_shard_metadata.extend(config.embedding_shard_metadata()) + return embedding_shard_metadata + + def feature_names(self) -> List[str]: + feature_names = [] + for grouped_config in self._grouped_embedding_configs_per_node: + for config in grouped_config: + feature_names.extend(config.feature_names()) + return feature_names + + def _get_feature_hash_sizes(self) -> List[int]: + feature_hash_sizes: List[int] = [] + for grouped_config in self._grouped_embedding_configs_per_node: + for config in grouped_config: + feature_hash_sizes.extend(config.feature_hash_sizes()) + return feature_hash_sizes + + def _dim_sum_per_node(self) -> List[int]: + dim_sum_per_node = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: + dim_sum = 0 + for grouped_config in grouped_embedding_configs: + dim_sum += grouped_config.dim_sum() + dim_sum_per_node.append(dim_sum) + return dim_sum_per_node + + def _emb_dim_per_node_per_feature(self) -> List[List[int]]: + 
emb_dim_per_node_per_feature = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: + emb_dim_per_feature = [] + for grouped_config in grouped_embedding_configs: + emb_dim_per_feature += grouped_config.embedding_dims() + emb_dim_per_node_per_feature.append(emb_dim_per_feature) + return emb_dim_per_node_per_feature + + def _features_per_rank( + self, group: List[List[GroupedEmbeddingConfig]] + ) -> List[int]: + features_per_rank = [] + for grouped_embedding_configs in group: + num_features = 0 + for grouped_config in grouped_embedding_configs: + num_features += grouped_config.num_features() + features_per_rank.append(num_features) + return features_per_rank + + +class GridPooledEmbeddingDist( + BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor] +): + def __init__( + self, + rank: int, + cross_pg: dist.ProcessGroup, + intra_pg: dist.ProcessGroup, + dim_sum_per_node: List[int], + emb_dim_per_node_per_feature: List[List[int]], + device: Optional[torch.device] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + callbacks: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None, + ) -> None: + super().__init__() + self._rank = rank + self._intra_pg: dist.ProcessGroup = intra_pg + self._cross_pg: dist.ProcessGroup = cross_pg + self._dim_sum_per_node = dim_sum_per_node + self._emb_dim_per_node_per_feature = emb_dim_per_node_per_feature + self._device = device + self._intra_codecs: Optional[QuantizedCommCodecs] = ( + qcomm_codecs_registry.get( + CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER.name, None + ) + if qcomm_codecs_registry + else None + ) + self._cross_codecs: Optional[QuantizedCommCodecs] = ( + qcomm_codecs_registry.get(CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None) + if qcomm_codecs_registry + else None + ) + self._intra_dist: Optional[ + Union[ + PooledEmbeddingsReduceScatter, + VariableBatchPooledEmbeddingsReduceScatter, + ] + ] = None + self._cross_dist: Optional[ + Union[ + PooledEmbeddingsAllToAll, + VariableBatchPooledEmbeddingsAllToAll, + ] + ] = None + self._callbacks = callbacks + + def forward( + self, + local_embs: torch.Tensor, + sharding_ctx: Optional[EmbeddingShardingContext] = None, + ) -> Awaitable[torch.Tensor]: + """ + Performs reduce-scatter pooled operation on pooled embeddings tensor followed by + AlltoAll pooled operation. + + Args: + local_embs (torch.Tensor): pooled embeddings tensor to distribute. + + Returns: + Awaitable[torch.Tensor]: awaitable of pooled embeddings tensor. 
+ """ + if self._intra_dist is None or self._cross_dist is None: + self._create_output_dist_modules(sharding_ctx) + local_rank = self._rank % self._intra_pg.size() + if sharding_ctx is not None and len(set(sharding_ctx.batch_size_per_rank)) > 1: + # preprocess batch_size_per_rank + ( + batch_size_per_rank_by_cross_group, + batch_size_sum_by_cross_group, + ) = self._preprocess_batch_size_per_rank( + self._intra_pg.size(), + self._cross_pg.size(), + sharding_ctx.batch_size_per_rank, + ) + # Perform ReduceScatterV within one host + rs_result = cast(PooledEmbeddingsReduceScatter, self._intra_dist)( + local_embs, input_splits=batch_size_sum_by_cross_group + ).wait() + return cast(PooledEmbeddingsAllToAll, self._cross_dist)( + rs_result, + batch_size_per_rank=batch_size_per_rank_by_cross_group[local_rank], + ) + else: + return cast(PooledEmbeddingsAllToAll, self._cross_dist)( + cast(PooledEmbeddingsReduceScatter, self._intra_dist)(local_embs).wait() + ) + + def _preprocess_batch_size_per_rank( + self, local_size: int, nodes: int, batch_size_per_rank: List[int] + ) -> Tuple[List[List[int]], List[int]]: + """ + Reorders `batch_size_per_rank` so it's aligned with reordered features after + AlltoAll. + """ + batch_size_per_rank_by_cross_group: List[List[int]] = [] + batch_size_sum_by_cross_group: List[int] = [] + for local_rank in range(local_size): + batch_size_per_rank_: List[int] = [] + batch_size_sum = 0 + for node in range(nodes): + batch_size_per_rank_.append( + batch_size_per_rank[local_rank + node * local_size] + ) + batch_size_sum += batch_size_per_rank[local_rank + node * local_size] + batch_size_per_rank_by_cross_group.append(batch_size_per_rank_) + batch_size_sum_by_cross_group.append(batch_size_sum) + + return batch_size_per_rank_by_cross_group, batch_size_sum_by_cross_group + + def _create_output_dist_modules( + self, sharding_ctx: Optional[EmbeddingShardingContext] = None + ) -> None: + self._intra_dist = PooledEmbeddingsReduceScatter( + pg=self._intra_pg, + codecs=self._intra_codecs, + ) + self._cross_dist = PooledEmbeddingsAllToAll( + pg=self._cross_pg, + dim_sum_per_rank=self._dim_sum_per_node, + device=self._device, + codecs=self._cross_codecs, + callbacks=self._callbacks, + ) + + +class GridPooledEmbeddingSharding( + BaseGridEmbeddingSharding[ + EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor + ] +): + """ + Shards embedding bags into column wise shards and shards each CW shard table wise row wise within a node + """ + + def create_input_dist( + self, device: Optional[torch.device] = None + ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: + features_per_rank = self._features_per_rank( + self._grouped_embedding_configs_per_rank + ) + feature_hash_sizes = self._get_feature_hash_sizes() + assert self._pg is not None + assert self._intra_pg is not None + return TwRwSparseFeaturesDist( + pg=self._pg, + local_size=self._intra_pg.size(), + features_per_rank=features_per_rank, + feature_hash_sizes=feature_hash_sizes, + device=device if device is not None else self._device, + has_feature_processor=self._has_feature_processor, + need_pos=self._need_pos, + ) + + def create_lookup( + self, + device: Optional[torch.device] = None, + fused_params: Optional[Dict[str, Any]] = None, + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup: + return GroupedPooledEmbeddingsLookup( + grouped_configs=self._grouped_embedding_configs_per_rank[self._rank], + pg=self._pg, + device=device if device is not None else self._device, + 
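When ranks carry different batch sizes, the grid output dist above first regroups `batch_size_per_rank` so that the intra-node ReduceScatterV and the cross-node AlltoAll each see splits in cross-group order. A minimal pure-Python sketch of that regrouping, mirroring `_preprocess_batch_size_per_rank` with a hypothetical topology of 2 ranks per node and 2 nodes:

    from typing import List, Tuple


    def preprocess_batch_size_per_rank(
        local_size: int, nodes: int, batch_size_per_rank: List[int]
    ) -> Tuple[List[List[int]], List[int]]:
        # Entry [local_rank][node] is the batch size of global rank
        # (local_rank + node * local_size); the sums become the RS-V input splits.
        by_cross_group: List[List[int]] = []
        sum_by_cross_group: List[int] = []
        for local_rank in range(local_size):
            sizes = [
                batch_size_per_rank[local_rank + node * local_size]
                for node in range(nodes)
            ]
            by_cross_group.append(sizes)
            sum_by_cross_group.append(sum(sizes))
        return by_cross_group, sum_by_cross_group


    # 2 ranks per node, 2 nodes, global ranks 0..3 with batch sizes [2, 3, 4, 5]:
    # local rank 0 pairs with ranks 0 and 2 -> [2, 4] (sum 6),
    # local rank 1 pairs with ranks 1 and 3 -> [3, 5] (sum 8).
    assert preprocess_batch_size_per_rank(2, 2, [2, 3, 4, 5]) == ([[2, 4], [3, 5]], [6, 8])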
feature_processor=feature_processor, + sharding_type=ShardingType.TABLE_ROW_WISE, + ) + + def create_output_dist( + self, + device: Optional[torch.device] = None, + ) -> BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor]: + embedding_permute_op: Optional[PermutePooledEmbeddingsSplit] = None + callbacks: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None + if self._embedding_order != list(range(len(self._embedding_order))): + assert len(self._embedding_order) == len(self._embedding_dims) + embedding_permute_op = PermutePooledEmbeddingsSplit( + self._embedding_dims, self._embedding_order, device=self._device + ) + callbacks = [embedding_permute_op] + return GridPooledEmbeddingDist( + rank=self._rank, + cross_pg=cast(dist.ProcessGroup, self._cross_pg), + intra_pg=cast(dist.ProcessGroup, self._intra_pg), + dim_sum_per_node=self._dim_sum_per_node(), + emb_dim_per_node_per_feature=self._emb_dim_per_node_per_feature(), + device=device if device is not None else self._device, + qcomm_codecs_registry=self.qcomm_codecs_registry, + callbacks=callbacks, + ) diff --git a/torchrec/distributed/sharding/rw_kjt_pool_sharding.py b/torchrec/distributed/sharding/rw_kjt_pool_sharding.py new file mode 100644 index 000000000..1d5abe9c3 --- /dev/null +++ b/torchrec/distributed/sharding/rw_kjt_pool_sharding.py @@ -0,0 +1,543 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Iterable, List, Tuple + +import torch +import torch.distributed as dist + +from torch.distributed._shard.sharded_tensor import Shard, ShardMetadata + +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torchrec.distributed.comm import ( + get_local_rank, + get_local_size, + intra_and_cross_node_pg, +) +from torchrec.distributed.dist_data import JaggedTensorAllToAll +from torchrec.distributed.sharding.rw_pool_sharding import ( + InferRwObjectPoolInputDist, + RwObjectPoolIDsDist, +) +from torchrec.distributed.tensor_sharding import ( + InferObjectPoolSharding, + ObjectPoolReplicatedRwShardingContext, + ObjectPoolRwShardingContext, + ObjectPoolSharding, +) +from torchrec.distributed.types import Awaitable, ShardingEnv +from torchrec.modules.object_pool_lookups import KeyedJaggedTensorPoolLookup +from torchrec.modules.utils import jagged_index_select_with_empty +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor + + +class RwKeyedJaggedTensorPoolLookupValuesDist(torch.nn.Module): + """ + Module to distribute KeyedJaggedTensor to all ranks after local pool lookup + + Args: + num_features (int): number of features in KeyedJaggedTensor to be distributed + env (ShardingEnv): Sharding environment with info such as rank and world size + + Example: + dist = RwKeyedJaggedTensorPoolLookupValuesDist(num_features=2, env=env) + + # on rank 0, sends 1 and receives 2 batches + ctx = ObjectPoolRwShardingContext(num_ids_each_rank_to_send=1, num_ids_each_rank_to_receive=2) + jt = JaggedTensor(values=[2,3,2], lengths=[2,1]) + + # on rank 1, sends 2 and receives 1 batches + ctx = ObjectPoolRwShardingContext(num_ids_each_rank_to_send=2, num_ids_each_rank_to_receive=1) + jt = JaggedTensor(values=[1,1,5,2,8], lengths=[2,2,1]) + + rank0_out = dist(ctx, jt).wait() + + # rank0_out is + # JaggedTensor(values=[2,3,1,1,5,2], lengths=[2,2,2]) + + rank1_out = dist(ctx, 
jt).wait() + # rank1_out is + # JaggedTensor(values=[2,8], lengths=[1,1]) + + """ + + def __init__( + self, + num_features: int, + env: ShardingEnv, + ) -> None: + super().__init__() + self._sharding_env = env + self._num_features = num_features + + def forward( + self, + ctx: ObjectPoolRwShardingContext, + jagged_tensor: JaggedTensor, + ) -> Awaitable[JaggedTensor]: + """ + Sends JaggedTensor to relevant `ProcessGroup` ranks. + + Args: + ctx (ObjectPoolRwShardingContext): Context for RW sharding, containing + number of items to send and receive from each rank. + jagged_tensor (JaggedTensor): JaggedTensor to distribute. This JT is + constructed from flattening a KeyedJaggedTensor. + + Returns: + Awaitable[JaggedTensor]: awaitable of `JaggedTensor` + """ + return JaggedTensorAllToAll( + jt=jagged_tensor, + # pyre-ignore + num_items_to_send=ctx.num_ids_each_rank_to_send * self._num_features, + # pyre-ignore + num_items_to_receive=ctx.num_ids_each_rank_to_receive * self._num_features, + # pyre-ignore + pg=self._sharding_env.process_group, + ) + + +class RwKeyedJaggedTensorPoolUpdateValuesDist(torch.nn.Module): + """ + Module to distribute updated KeyedJaggedTensor to all ranks after local pool update + + Args: + num_features (int): number of features in KeyedJaggedTensor to be distributed + env (ShardingEnv): Sharding environment with info such as rank and world size + device (torch.device): Device on which to allocate tensors + num_replicas (int): number of times KJT should be replicated across ranks in case + of replicated row-wise sharding. Defaults to 1 (no replica). + + Example: + keys=['A','B'] + dist = RwKeyedJaggedTensorPoolUpdateValuesDist(num_features=len(keys), env=env) + ctx = ObjectPoolRwShardingContext( + num_ids_each_rank_to_send=1, + num_ids_each_rank_to_receive=1, + ) + awaitable = dist(rank0_input, ctx) + + # where: + # rank0_input is KeyedJaggedTensor holding + + # 0 1 + # 'A' [A.V0] None + # 'B' None [B.V0] + + # rank1_input is KeyedJaggedTensor holding + + # 0 1 + # 'A' [A.V3] [A.V4] + # 'B' None [B.V2] + + rank0_output = awaitable.wait() + + # where: + # rank0_output is JaggedTensor holding + + values = [A.V0, A.V3] + lengths = [1,0,1,0] + + # rank1_output is JaggedTensor holding + + values = [B.V0, A.V4, B.V2] + lengths = [0,1,1,1] + """ + + def __init__( + self, + num_features: int, + env: ShardingEnv, + device: torch.device, + num_replicas: int = 1, + ) -> None: + super().__init__() + self._env = env + self._num_features = num_features + self._num_replicas = num_replicas + self._device = device + + def forward( + self, + values: KeyedJaggedTensor, + ctx: ObjectPoolRwShardingContext, + ) -> Awaitable[JaggedTensor]: + """ + Sends tensor to relevant `ProcessGroup` ranks. + + Args: + values (KeyedJaggedTensor): KJT to distribute + ctx (ObjectPoolRwShardingContext): Context for RW sharding, containing + indices along batch dimension to permute KJT before A2A, as well as + number of items to send and receive from each rank. + + Returns: + Awaitable[JaggedTensor]: awaitable of `JaggedTensor` from which KJT can be + reconstructed. + + """ + + kjt = values + permute_idx = ctx.unbucketize_permute + + # Below code lets us select values out from a KJT in a row manner format for example + # KJT + # f1 [0,1] [2,3] + # f2 [3,4,5] [6] + # the values come in as 0,1,2,3,4,5,6, however, we need it in feature order e.g. + # 0,1,3,4,5,2,3,6 to more efficiently to the all to alls + # we can use jagged index select to these, e.g. we need the indices to come in order of + # 0,2,1,3. 
+ # this can be done by first chunking viewing as [[0,1][2,3]], + # then taking a transpose and flatten => [0,2,1,3] + + arange_idx = torch.arange( + kjt.stride() * self._num_features, device=self._device + ) + jagged_idx = arange_idx.view(self._num_features, -1).t() + jt_lengths_in_order_for_a2a = jagged_idx[permute_idx].flatten() + + lengths_to_send = kjt.lengths()[jt_lengths_in_order_for_a2a] + kjt_values_to_send_offsets = torch.ops.fbgemm.asynchronous_inclusive_cumsum( + lengths_to_send + ) + kjt_values_to_send = jagged_index_select_with_empty( + kjt.values().unsqueeze(-1), + jt_lengths_in_order_for_a2a, + kjt.offsets()[1:], + kjt_values_to_send_offsets, + ) + kjt_values_to_send = kjt_values_to_send.flatten() + + kjt_weights_to_send = None + if kjt.weights_or_none() is not None: + kjt_weights_to_send = jagged_index_select_with_empty( + kjt.weights().unsqueeze(-1), + jt_lengths_in_order_for_a2a, + kjt.offsets()[1:], + kjt_values_to_send_offsets, + ) + + if self._num_replicas > 1: + kjt_values_to_send = kjt_values_to_send.repeat(self._num_replicas) + lengths_to_send = lengths_to_send.flatten().repeat(self._num_replicas) + if kjt_weights_to_send is not None: + kjt_weights_to_send = kjt_weights_to_send.repeat(self._num_replicas) + + jt_all_to_all = JaggedTensorAllToAll( + JaggedTensor( + values=kjt_values_to_send, + lengths=lengths_to_send, + weights=kjt_weights_to_send, + ), + # pyre-ignore + num_items_to_send=ctx.num_ids_each_rank_to_send * self._num_features, + # pyre-ignore + num_items_to_receive=ctx.num_ids_each_rank_to_receive * self._num_features, + # pyre-ignore + pg=self._env.process_group, + ) + + return jt_all_to_all + + +class KeyedJaggedTensorPoolRwSharding(ObjectPoolSharding): + def __init__( + self, + env: ShardingEnv, + device: torch.device, + pool_size: int, + num_features: int, + ) -> None: + self._env = env + # pyre-ignore + self._pg: dist.ProcessGroup = self._env.process_group + self._world_size: int = self._env.world_size + self._rank: int = self._env.rank + self._device = device + self._pool_size = pool_size + + self._block_size: int = ( + pool_size + self._env.world_size - 1 + ) // self._env.world_size + + self.local_pool_size: int = ( + self._block_size + if self._env.rank != self._env.world_size - 1 + else pool_size - self._block_size * (self._env.world_size - 1) + ) + + self._block_size_t: torch.Tensor = torch.tensor( + [ + self._block_size, + ], + dtype=torch.long, + device=self._device, + ) + self._num_features = num_features + + def create_update_ids_dist( + self, + ) -> RwObjectPoolIDsDist: + return RwObjectPoolIDsDist(self._pg, is_update=True) + + def create_update_values_dist( + self, + ) -> RwKeyedJaggedTensorPoolUpdateValuesDist: + return RwKeyedJaggedTensorPoolUpdateValuesDist( + num_features=self._num_features, + env=self._env, + device=self._device, + ) + + def create_lookup_ids_dist( + self, + ) -> RwObjectPoolIDsDist: + return RwObjectPoolIDsDist(self._pg, is_update=False) + + def create_lookup_values_dist(self) -> RwKeyedJaggedTensorPoolLookupValuesDist: + return RwKeyedJaggedTensorPoolLookupValuesDist( + num_features=self._num_features, env=self._env + ) + + def get_sharded_states_to_register( + self, lookup: KeyedJaggedTensorPoolLookup + ) -> Iterable[Tuple[str, torch.Tensor]]: + for fqn, tensor in lookup.states_to_register(): + yield fqn, ShardedTensor._init_from_local_shards( + [ + Shard( + tensor=tensor, + metadata=ShardMetadata( + shard_offsets=[ + self._env.rank * self._block_size, + 0, + ], + shard_sizes=[ + tensor.shape[0], + 
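The index construction described in the comment above can be checked in isolation: an arange over `stride * num_features` is viewed as `(num_features, batch)` and transposed, so that after flattening, all features of sample 0 come first, then all features of sample 1, and so on. A small PyTorch sketch with hypothetical sizes (2 features, stride 2, identity sample permutation):

    import torch

    # KJT lengths/values are laid out feature-major: [f1_s0, f1_s1, f2_s0, f2_s1].
    # For the all-to-all we want sample-major order: [f1_s0, f2_s0, f1_s1, f2_s1].
    num_features, stride = 2, 2
    permute_idx = torch.arange(stride)  # identity sample permutation for this sketch

    arange_idx = torch.arange(stride * num_features)    # [0, 1, 2, 3]
    jagged_idx = arange_idx.view(num_features, -1).t()  # [[0, 2], [1, 3]]
    order_for_a2a = jagged_idx[permute_idx].flatten()   # [0, 2, 1, 3]

    assert order_for_a2a.tolist() == [0, 2, 1, 3]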
tensor.shape[1],
+ ],
+ placement=f"rank:{self._env.rank}/{str(tensor.device)}",
+ ),
+ )
+ ],
+ torch.Size([self._pool_size, tensor.shape[1]]),
+ process_group=self._env.process_group,
+ )
+
+ def create_context(self) -> ObjectPoolRwShardingContext:
+ return ObjectPoolRwShardingContext(block_size=self._block_size_t)
+
+
+@torch.fx.wrap
+def _cat_if_multiple(tensor_list: List[torch.Tensor]) -> torch.Tensor:
+ if len(tensor_list) == 1:
+ return tensor_list[0].flatten()
+ else:
+ return torch.cat([x.flatten() for x in tensor_list])
+
+
+class InferRwKeyedJaggedTensorPoolOutputDist(torch.nn.Module):
+ """
+ Redistributes jagged tensors in RW fashion with an AlltoOne operation.
+
+ Inference assumes that this is called on a single rank, but jagged tensors are placed
+ on different devices.
+
+ Args:
+ env (ShardingEnv): Sharding environment with info such as rank and world size
+ device (torch.device): device on which the tensors will be communicated to.
+
+ Example:
+ device_cpu = torch.device("cpu")
+ dist = InferRwKeyedJaggedTensorPoolOutputDist(env, device_cpu)
+ jagged_tensors = [
+ JaggedTensor(values=torch.tensor([1,2,3]), lengths=torch.tensor([1,1,1]), device=torch.device("rank:0/cuda:0")),
+ JaggedTensor(values=torch.tensor([5,5,5]), lengths=torch.tensor([2,1]), device=torch.device("rank:1/cuda:0")),
+ ]
+ jt = dist(jagged_tensors)
+
+ # jt has values [1,2,3,5,5,5] and lengths [1,1,1,2,1]
+ """
+
+ def __init__(
+ self,
+ env: ShardingEnv,
+ device: torch.device,
+ ) -> None:
+ super().__init__()
+ self._sharding_env = env
+ self._device = device
+
+ def forward(
+ self,
+ jagged_tensors: List[JaggedTensor],
+ ) -> JaggedTensor:
+ """
+ Performs AlltoOne operation on a list of jagged tensors placed on different
+ devices and returns the merged jagged tensor.
+ + Args: + jagged_tensors (List[JaggedTensor]): List of jagged tensors placed on + different ranks + + Returns: JaggedTensor + """ + lengths = [jt.lengths() for jt in jagged_tensors] + values = [jt.values() for jt in jagged_tensors] + values = _cat_if_multiple( + torch.ops.fbgemm.all_to_one_device( + [v.reshape(-1, v.shape[0]) for v in values], self._device + ) + ) + lengths = _cat_if_multiple( + torch.ops.fbgemm.all_to_one_device( + [x.reshape(-1, x.shape[0]) for x in lengths], + self._device, + ) + ) + + return JaggedTensor(values=values, lengths=lengths) + + +class InferRwKeyedJaggedTensorPoolSharding(InferObjectPoolSharding): + def __init__( + self, + pool_size: int, + env: ShardingEnv, + device: torch.device, + ) -> None: + super().__init__(pool_size, env, device) + + def create_lookup_ids_dist(self) -> InferRwObjectPoolInputDist: + return InferRwObjectPoolInputDist( + self._env, device=self._device, block_size=self._block_size_t + ) + + def create_lookup_values_dist(self) -> InferRwKeyedJaggedTensorPoolOutputDist: + return InferRwKeyedJaggedTensorPoolOutputDist( + env=self._env, device=self._device + ) + + +class KeyedJaggedTensorPoolRwReplicatedSharding(ObjectPoolSharding): + def __init__( + self, + env: ShardingEnv, + device: torch.device, + pool_size: int, + num_features: int, + ) -> None: + self._env = env + # pyre-ignore + self._pg: dist.ProcessGroup = self._env.process_group + self._world_size: int = self._env.world_size + self._rank: int = self._env.rank + self._device = device + self._local_world_size: int = get_local_size(self._world_size) + + self._pool_size = pool_size + + self._num_features = num_features + + intra_pg, _cross_pg = intra_and_cross_node_pg( + device, backend=dist.get_backend(self._pg) + ) + + # pyre-ignore + self._intra_pg: dist.ProcessGroup = intra_pg + + self._local_rank: int = get_local_rank(self._world_size) + + self._block_size: int = ( + pool_size + self._local_world_size - 1 + ) // self._local_world_size + + self.local_pool_size: int = ( + self._block_size + if self._local_rank != self._local_world_size - 1 + else pool_size - self._block_size * (self._local_world_size - 1) + ) + + self._block_size_t: torch.Tensor = torch.tensor( + [ + self._block_size, + ], + dtype=torch.long, + device=self._device, + ) + + self._local_env = ShardingEnv( + world_size=dist.get_world_size(self._intra_pg), + rank=dist.get_rank(self._intra_pg), + pg=self._intra_pg, + ) + + self._num_replicas: int = self._world_size // self._local_world_size + + def create_update_ids_dist( + self, + ) -> RwObjectPoolIDsDist: + return RwObjectPoolIDsDist( + self._pg, + is_update=True, + bucketize_world_size=self._intra_pg.size(), + num_replicas=self._num_replicas, + ) + + def create_update_values_dist( + self, + ) -> RwKeyedJaggedTensorPoolUpdateValuesDist: + return RwKeyedJaggedTensorPoolUpdateValuesDist( + num_features=self._num_features, + env=self._env, + num_replicas=self._num_replicas, + device=self._device, + ) + + def create_lookup_ids_dist( + self, + ) -> RwObjectPoolIDsDist: + return RwObjectPoolIDsDist(self._intra_pg, is_update=False) + + def create_lookup_values_dist( + self, + ) -> RwKeyedJaggedTensorPoolLookupValuesDist: + return RwKeyedJaggedTensorPoolLookupValuesDist( + num_features=self._num_features, env=self._local_env + ) + + def get_sharded_states_to_register( + self, lookup: KeyedJaggedTensorPoolLookup + ) -> Iterable[Tuple[str, torch.Tensor]]: + for fqn, tensor in lookup.states_to_register(): + yield fqn, ShardedTensor._init_from_local_shards( + [ + Shard( + 
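In the replicated row-wise variant above, each host holds a full copy of the pool: rows are blocked over the ranks of a single host (the last local rank absorbing the remainder, as in the non-replicated sharding), and the number of replicas equals the number of hosts. A toy sketch of the resulting layout, assuming 8 ranks, 4 per host, and a 100-row pool:

    world_size, local_world_size, pool_size = 8, 4, 100
    num_replicas = world_size // local_world_size  # 2 hosts -> 2 replicas
    block_size = (pool_size + local_world_size - 1) // local_world_size  # 25 rows

    layout = {}
    for rank in range(world_size):
        # Local rank within the host (the real code derives this from topology helpers).
        local_rank = rank % local_world_size
        local_rows = (
            block_size
            if local_rank != local_world_size - 1
            else pool_size - block_size * (local_world_size - 1)
        )
        layout[rank] = (local_rank * block_size, local_rows)

    # Ranks 0-3 and 4-7 hold identical shards: offsets 0/25/50/75, 25 rows each.
    assert layout[0] == layout[4] == (0, 25)
    assert layout[3] == layout[7] == (75, 25)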
tensor=tensor, + metadata=ShardMetadata( + shard_offsets=[ + self._local_env.rank * self._block_size, + 0, + ], + shard_sizes=[ + tensor.shape[0], + tensor.shape[1], + ], + placement=f"rank:{self._local_env.rank}/{str(tensor.device)}", + ), + ) + ], + torch.Size([self._pool_size, tensor.shape[1]]), + process_group=self._local_env.process_group, + ) + + def create_context(self) -> ObjectPoolReplicatedRwShardingContext: + return ObjectPoolReplicatedRwShardingContext(block_size=self._block_size_t) diff --git a/torchrec/distributed/sharding/rw_pool_sharding.py b/torchrec/distributed/sharding/rw_pool_sharding.py new file mode 100644 index 000000000..2e0823dd1 --- /dev/null +++ b/torchrec/distributed/sharding/rw_pool_sharding.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import List, Optional, Tuple + +import torch +import torch.distributed as dist +from torchrec.distributed.dist_data import TensorAllToAll +from torchrec.distributed.tensor_sharding import ObjectPoolRwShardingContext +from torchrec.distributed.types import Awaitable, ShardingEnv + +NUM_THREADS_BUCKETIZE = 32 + + +class RwObjectPoolIDsDist(torch.nn.Module): + """ + Redistribute torch.Tensor values containing IDs for sharded object pools + + Args: + pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. + is_update (bool): Boolean indicating whether this is an update or not. Defaults + to False. + + During update, number of values to send to each rank is determined by the + number of IDs in each bucket, while no. of values to receive from each rank is + determined by the no. of IDs for the current rank to be sent by all + other ranks. + + During lookup, we first receive IDs to lookup from other ranks before + distributing the looked up values, so the no. of values to send to each rank + is collected from other ranks after IDs All2All. Conversely, no. of values + to receive from each rank is determined by the number of IDs in each bucket. + This is opposite of what happens during an update, so the code is shared. + + bucketize_world_size (Optional[int]): Number of buckets to bucketize IDs into. + Defaults to `None` in which case the world size of the ProcessGroup is used. + + num_replicas (int): number of replicas of objects (tensor/KJT) to keep across + ranks in case of replicated RW sharding. Defaults to 1. + + Example: + dist = RwObjectPoolIDsDist(pg=pg, is_update=True, bucketize_world_size=2) + ids = torch.Tensor([0,2,1,4,5]) + out = dist(ctx,ids).wait().wait() + + # values 2 and 1 need to be swapped + ctx.unbucketize_permute == torch.tensor([0,2,1,3,4]) + ctx.num_ids_each_rank_to_send = torch.tensor([2,3]) + """ + + def __init__( + self, + pg: dist.ProcessGroup, + is_update: bool = True, + bucketize_world_size: Optional[int] = None, + num_replicas: int = 1, + ) -> None: + super().__init__() + self._world_size: int = pg.size() + self._dist = TensorAllToAll(pg=pg) + self._is_update: bool = is_update + + self._num_replicas = num_replicas + self._bucketize_world_size: int = bucketize_world_size or pg.size() + + def forward( + self, + ctx: ObjectPoolRwShardingContext, + ids: torch.Tensor, + ) -> Awaitable[Awaitable[torch.Tensor]]: + """ + Bucketizes IDs into `world_size` buckets and distributes them to other ranks. 
+ + Args: + ctx (Optional[EmbeddingShardingContext]): shared context from + RW sharding operation. Number of ids to receive and send per rank + is stored in this context + ids (torch.Tensor): 1D tensor containing ids to be distributed + + Returns: + Awaitable[Awaitable[torch.Tensor]]: awaitable of tensor awaitable. + + """ + + num_ids = ids.shape[0] + num_threads = NUM_THREADS_BUCKETIZE + quot, rem = divmod(num_ids, num_threads) + lengths = [quot] * num_threads + for i in range(rem): + lengths[i] += 1 + lengths = torch.tensor(lengths, device=ids.device, dtype=torch.int) + + ( + bucketized_lengths, + bucketized_indices, + _bucketized_weights, + _bucketize_permute, + unbucketize_permute, + ) = torch.ops.fbgemm.block_bucketize_sparse_features( + lengths=lengths, + indices=ids, + bucketize_pos=False, + sequence=True, + # pyre-ignore + block_sizes=ctx.block_size.to(ids.dtype), + my_size=self._bucketize_world_size, + weights=None, + ) + + bucketized_lengths = ( + bucketized_lengths.reshape(self._bucketize_world_size, -1).sum(dim=1).int() + ) + + ctx.ids_before_input_dist = ids + ctx.unbucketize_permute = unbucketize_permute + # not needed, see if we can remove + ctx.bucketize_permute = None + + if self._num_replicas > 1: + bucketized_indices = bucketized_indices.repeat(self._num_replicas) + bucketized_lengths = bucketized_lengths.repeat(self._num_replicas) + + await_dist_ids = self._dist( + input=bucketized_indices, + splits=bucketized_lengths, + ) + + if self._is_update: + ctx.num_ids_each_rank_to_send = bucketized_lengths + ctx.num_ids_each_rank_to_receive = await_dist_ids._output_splits + else: + ctx.num_ids_each_rank_to_send = await_dist_ids._output_splits + ctx.num_ids_each_rank_to_receive = bucketized_lengths + + return await_dist_ids + + +@torch.fx.wrap +def _get_bucketize_shape(ids: torch.Tensor, device: torch.device) -> torch.Tensor: + return torch.tensor([ids.size(dim=0)], device=device, dtype=torch.long) + + +@torch.fx.wrap +def _get_unbucketize_permute_index( + unbucketize_permute: Optional[torch.Tensor], +) -> torch.Tensor: + assert unbucketize_permute is not None, "unbucketize permute must not be None" + _, index = unbucketize_permute.sort() + return index + + +class InferRwObjectPoolInputDist(torch.nn.Module): + """ + Redistribute torch.Tensor values containing IDs for sharded object pools for inference + + Args: + env (ShardingEnv): Sharding environment containing rank, world size, etc + device (torch.device): device on which the tensors will be communicated to during + lookup and update + block_size (torch.Tensor): tensor containing block sizes for each rank. + e.g. if block_size=torch.tensor(100), then IDs 0-99 will be assigned to rank + 0, 100-199 to rank 1, and so on. + + Example: + device = torch.device("cpu") + dist = InferRwObjectPoolInputDist(env=env, device=device, block_size=torch.tensor(100)) + ids = torch.Tensor([0,99,100,111]) + list_ids, permute = dist.lookup(ids) + + # list_ids == [torch.Tensor([0,99], device="cpu"), torch.Tensor([100,111], device="cpu)])] + """ + + _world_size: int + _device: torch.device + _block_size: torch.Tensor + + def __init__( + self, + env: ShardingEnv, + device: torch.device, + block_size: torch.Tensor, + ) -> None: + super().__init__() + self._world_size = env.world_size + self._device = device + self._block_size = block_size + + def forward( + self, + ids: torch.Tensor, + ) -> Tuple[List[torch.Tensor], torch.Tensor]: + """ + Bucketizes ids tensor into a list of tensors each containing ids + for the corresponding rank. 
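Conceptually, the ids dist above assigns each id to the rank `id // block_size`, counts how many ids go to each rank (the all-to-all splits), and records a permutation that restores the original order after the round trip. The real module relies on `torch.ops.fbgemm.block_bucketize_sparse_features`; the sketch below reproduces only that high-level bookkeeping with plain PyTorch, using a hypothetical block size of 4 over 2 ranks:

    import torch

    ids = torch.tensor([0, 2, 5, 1, 4])
    block_size, world_size = 4, 2

    buckets = ids // block_size                        # owning rank per id: [0, 0, 1, 0, 1]
    order = torch.sort(buckets, stable=True).indices   # group ids by destination rank
    bucketized_ids = ids[order]                        # [0, 2, 1, 5, 4]
    splits = torch.bincount(buckets, minlength=world_size)  # ids per rank: [3, 2]
    unbucketize_permute = torch.argsort(order)         # undoes the grouping

    assert bucketized_ids.tolist() == [0, 2, 1, 5, 4]
    assert splits.tolist() == [3, 2]
    assert torch.equal(bucketized_ids[unbucketize_permute], ids)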
Places each tensor on the appropriate device. + + Args: + ids (torch.Tensor): Tensor with ids + + Returns: + Tuple[List[torch.Tensor], torch.Tensor]: Tuple containing list of ids tensors + for each rank given the bucket sizes, and the tensor containing indices + to permute the ids to get the original order before bucketization. + """ + ( + bucketized_lengths, + bucketized_indices, + _bucketized_weights, + _bucketize_permute, + unbucketize_permute, + ) = torch.ops.fbgemm.block_bucketize_sparse_features( + _get_bucketize_shape(ids, ids.device), + ids.long(), + bucketize_pos=False, + sequence=True, + block_sizes=self._block_size.long(), + my_size=self._world_size, + weights=None, + ) + + id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(bucketized_lengths) + dist_ids = [] + for rank in range(self._world_size): + offset = id_offsets[rank] + next_offset = id_offsets[rank + 1] + ids_for_rank = bucketized_indices[offset:next_offset] + dist_ids.append( + ids_for_rank + if self._device == torch.device("cpu") + else ids_for_rank.to(torch.device(f"cuda:{rank}"), non_blocking=True) + ) + + assert unbucketize_permute is not None, "unbucketize permute must not be None" + return dist_ids, unbucketize_permute + + def update( + self, + ids: torch.Tensor, + values: torch.Tensor, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]: + """ + Split the values into same buckets are IDs and place on the appropriate device + for inference. + + Args: + ids (torch.Tensor): Tensor with ids + values (torch.Tensor): Tensor with values + + Returns: + Tuple[List[torch.Tensor], List[torch.Tensor] torch.Tensor]: Tuple containing + list of ids tensors, list of values tensors, and a tensor containing indices + to permute the ids to get the original order before bucketization. + """ + ( + bucketized_lengths, + bucketized_indices, + _bucketized_weights, + _bucketize_permute, + unbucketize_permute, + ) = torch.ops.fbgemm.block_bucketize_sparse_features( + _get_bucketize_shape(ids, ids.device), + ids.long(), + bucketize_pos=False, + sequence=True, + block_sizes=self._block_size.long(), + my_size=self._world_size, + weights=None, + ) + + id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(bucketized_lengths) + + index = _get_unbucketize_permute_index(unbucketize_permute) + unbucketize_values = values[index] + dist_ids = [] + dist_values = [] + for rank in range(self._world_size): + offset = id_offsets[rank] + next_offset = id_offsets[rank + 1] + ids_for_rank = bucketized_indices[offset:next_offset] + values_for_rank = unbucketize_values[offset:next_offset] + dist_ids.append( + ids_for_rank + if self._device == torch.device("cpu") + else ids_for_rank.to(torch.device(f"cuda:{rank}"), non_blocking=True) + ) + dist_values.append( + values_for_rank + if self._device == torch.device("cpu") + else values_for_rank.to(torch.device(f"cuda:{rank}"), non_blocking=True) + ) + + return dist_ids, dist_values, unbucketize_permute diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py index 243750b95..4029d9aa6 100644 --- a/torchrec/distributed/sharding/rw_sequence_sharding.py +++ b/torchrec/distributed/sharding/rw_sequence_sharding.py @@ -5,12 +5,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
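The inference input dist above carves the bucketized id tensor into per-rank slices using a complete cumulative sum (a cumsum with a leading zero) of the per-rank lengths. A minimal sketch of that slicing in plain PyTorch, continuing the toy example of 2 ranks:

    import torch

    bucketized_ids = torch.tensor([0, 2, 1, 5, 4])  # already grouped by destination rank
    lengths = torch.tensor([3, 2])                  # ids per rank

    # "Complete" cumsum: [0, 3, 5] -> rank r owns slice [offsets[r], offsets[r + 1]).
    offsets = torch.cat([torch.zeros(1, dtype=lengths.dtype), lengths.cumsum(0)])

    per_rank = [
        bucketized_ids[int(offsets[r]) : int(offsets[r + 1])]
        for r in range(lengths.numel())
    ]
    assert [t.tolist() for t in per_rank] == [[0, 2, 1], [5, 4]]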
-from typing import Any, Dict, Optional +# pyre-strict + +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist -from torchrec.distributed.dist_data import SequenceEmbeddingsAllToAll -from torchrec.distributed.embedding_lookup import GroupedEmbeddingsLookup +from torchrec.distributed.dist_data import ( + SeqEmbeddingsAllToOne, + SequenceEmbeddingsAllToAll, +) +from torchrec.distributed.embedding_lookup import ( + GroupedEmbeddingsLookup, + InferGroupedEmbeddingsLookup, +) from torchrec.distributed.embedding_sharding import ( BaseEmbeddingDist, BaseEmbeddingLookup, @@ -18,14 +26,27 @@ ) from torchrec.distributed.embedding_types import ( BaseGroupedFeatureProcessor, - SparseFeatures, + InputDistOutputs, ) from torchrec.distributed.sharding.rw_sharding import ( BaseRwEmbeddingSharding, + get_embedding_shard_metadata, + InferRwSparseFeaturesDist, RwSparseFeaturesDist, ) -from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext +from torchrec.distributed.sharding.sequence_sharding import ( + InferSequenceShardingContext, + SequenceShardingContext, +) from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs +from torchrec.modules.utils import ( + _fx_trec_get_feature_length, + _get_batching_hinted_output, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + +torch.fx.wrap("_get_batching_hinted_output") +torch.fx.wrap("_fx_trec_get_feature_length") class RwSequenceEmbeddingDist( @@ -52,11 +73,13 @@ def __init__( pg, [num_features] * pg.size(), device, - codecs=qcomm_codecs_registry.get( - CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name, None - ) - if qcomm_codecs_registry - else None, + codecs=( + qcomm_codecs_registry.get( + CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name, None + ) + if qcomm_codecs_registry + else None + ), ) def forward( @@ -89,7 +112,7 @@ def forward( class RwSequenceEmbeddingSharding( BaseRwEmbeddingSharding[ - SequenceShardingContext, SparseFeatures, torch.Tensor, torch.Tensor + SequenceShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor ] ): """ @@ -100,24 +123,19 @@ class RwSequenceEmbeddingSharding( def create_input_dist( self, device: Optional[torch.device] = None, - ) -> BaseSparseFeaturesDist[SparseFeatures]: - num_id_list_features = self._get_id_list_features_num() - num_id_score_list_features = self._get_id_score_list_features_num() - id_list_feature_hash_sizes = self._get_id_list_features_hash_sizes() - id_score_list_feature_hash_sizes = self._get_id_score_list_features_hash_sizes() + ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: + num_features = self._get_num_features() + feature_hash_sizes = self._get_feature_hash_sizes() return RwSparseFeaturesDist( # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got # `Optional[ProcessGroup]`. pg=self._pg, - num_id_list_features=num_id_list_features, - num_id_score_list_features=num_id_score_list_features, - id_list_feature_hash_sizes=id_list_feature_hash_sizes, - id_score_list_feature_hash_sizes=id_score_list_feature_hash_sizes, + num_features=num_features, + feature_hash_sizes=feature_hash_sizes, device=device if device is not None else self._device, is_sequence=True, has_feature_processor=self._has_feature_processor, need_pos=False, - variable_batch_size=self._variable_batch_size, ) def create_lookup( @@ -140,7 +158,147 @@ def create_output_dist( # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got # `Optional[ProcessGroup]`. 
self._pg, - self._get_id_list_features_num(), + self._get_num_features(), device if device is not None else self._device, qcomm_codecs_registry=self.qcomm_codecs_registry, ) + + +class InferRwSequenceEmbeddingDist( + BaseEmbeddingDist[ + InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor] + ] +): + def __init__( + self, + device: torch.device, + world_size: int, + device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None, + ) -> None: + super().__init__() + self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = ( + device_type_from_sharding_infos + ) + num_cpu_ranks = 0 + if self._device_type_from_sharding_infos and isinstance( + self._device_type_from_sharding_infos, tuple + ): + for device_type in self._device_type_from_sharding_infos: + if device_type == "cpu": + num_cpu_ranks += 1 + elif self._device_type_from_sharding_infos == "cpu": + num_cpu_ranks = world_size + + self._device_dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne( + device, world_size - num_cpu_ranks + ) + + def forward( + self, + local_embs: List[torch.Tensor], + sharding_ctx: Optional[InferSequenceShardingContext] = None, + ) -> List[torch.Tensor]: + assert ( + self._device_type_from_sharding_infos is not None + ), "_device_type_from_sharding_infos should always be set for InferRwSequenceEmbeddingDist" + if isinstance(self._device_type_from_sharding_infos, tuple): + assert sharding_ctx is not None + assert sharding_ctx.embedding_names_per_rank is not None + assert len(self._device_type_from_sharding_infos) == len( + local_embs + ), "For heterogeneous sharding, the number of local_embs should be equal to the number of device types" + non_cpu_local_embs = [] + # Here looping through local_embs is also compatible with tracing + # given the number of looks up / shards withing ShardedQuantEmbeddingCollection + # are fixed and local_embs is the output of those looks ups. However, still + # using _device_type_from_sharding_infos to iterate on local_embs list as + # that's a better practice. + for i, device_type in enumerate(self._device_type_from_sharding_infos): + if device_type != "cpu": + non_cpu_local_embs.append( + _get_batching_hinted_output( + _fx_trec_get_feature_length( + sharding_ctx.features[i], + # pyre-fixme [16] + sharding_ctx.embedding_names_per_rank[i], + ), + local_embs[i], + ) + ) + non_cpu_local_embs_dist = self._device_dist(non_cpu_local_embs) + index = 0 + result = [] + for i, device_type in enumerate(self._device_type_from_sharding_infos): + if device_type == "cpu": + result.append(local_embs[i]) + else: + result.append(non_cpu_local_embs_dist[index]) + index += 1 + return result + elif self._device_type_from_sharding_infos == "cpu": + # for cpu sharder, output dist should be a no-op + return local_embs + else: + return self._device_dist(local_embs) + + +class InferRwSequenceEmbeddingSharding( + BaseRwEmbeddingSharding[ + InferSequenceShardingContext, + InputDistOutputs, + List[torch.Tensor], + List[torch.Tensor], + ] +): + """ + Shards sequence (unpooled) row-wise, i.e.. a given embedding table is evenly + distributed by rows and table slices are placed on all ranks for inference. 
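For heterogeneous sharding, the sequence dist above forwards CPU shard outputs untouched, routes only the non-CPU shard outputs through the all-to-one kernel, and then re-interleaves the two streams in the original shard order. A sketch of just that interleaving logic, with the device copy stubbed out (illustrative only, not the TorchRec API):

    from typing import List, Sequence


    def merge_hetero_outputs(
        device_types: Sequence[str], local_embs: List[str]
    ) -> List[str]:
        # Stub for the all-to-one copy: pretend moving to the target device tags the value.
        def all_to_one(embs: List[str]) -> List[str]:
            return [f"{e}@cuda:0" for e in embs]

        non_cpu = [e for d, e in zip(device_types, local_embs) if d != "cpu"]
        moved = all_to_one(non_cpu)

        result, idx = [], 0
        for d, e in zip(device_types, local_embs):
            if d == "cpu":
                result.append(e)           # CPU shard output is passed through as-is
            else:
                result.append(moved[idx])  # non-CPU shards come back from the all-to-one
                idx += 1
        return result


    assert merge_hetero_outputs(
        ("cpu", "cuda", "cuda"), ["emb0", "emb1", "emb2"]
    ) == ["emb0", "emb1@cuda:0", "emb2@cuda:0"]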
+ """ + + def create_input_dist( + self, + device: Optional[torch.device] = None, + ) -> BaseSparseFeaturesDist[InputDistOutputs]: + num_features = self._get_num_features() + feature_hash_sizes = self._get_feature_hash_sizes() + + (emb_sharding, is_even_sharding) = get_embedding_shard_metadata( + self._grouped_embedding_configs_per_rank + ) + + return InferRwSparseFeaturesDist( + world_size=self._world_size, + num_features=num_features, + feature_hash_sizes=feature_hash_sizes, + device=device if device is not None else self._device, + is_sequence=True, + has_feature_processor=self._has_feature_processor, + need_pos=False, + embedding_shard_metadata=emb_sharding if not is_even_sharding else None, + ) + + def create_lookup( + self, + device: Optional[torch.device] = None, + fused_params: Optional[Dict[str, Any]] = None, + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]]: + return InferGroupedEmbeddingsLookup( + grouped_configs_per_rank=self._grouped_embedding_configs_per_rank, + world_size=self._world_size, + fused_params=fused_params, + device=device if device is not None else self._device, + device_type_from_sharding_infos=self._device_type_from_sharding_infos, + ) + + def create_output_dist( + self, + device: Optional[torch.device] = None, + ) -> BaseEmbeddingDist[ + InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor] + ]: + return InferRwSequenceEmbeddingDist( + device if device is not None else self._device, + self._world_size, + self._device_type_from_sharding_infos, + ) diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index a21d8ffa4..b62609da1 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -5,40 +5,61 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Any, Dict, List, Optional, TypeVar +# pyre-strict + +import logging +import math + +from typing import Any, cast, Dict, List, Optional, overload, Tuple, TypeVar, Union import torch import torch.distributed as dist -from torchrec.distributed.dist_data import PooledEmbeddingsReduceScatter -from torchrec.distributed.embedding_lookup import GroupedPooledEmbeddingsLookup +from torch.distributed._tensor.placement_types import Replicate, Shard +from torchrec.distributed.dist_data import ( + EmbeddingsAllToOneReduce, + KJTAllToAll, + KJTOneToAll, + PooledEmbeddingsReduceScatter, + VariableBatchPooledEmbeddingsReduceScatter, +) +from torchrec.distributed.embedding_lookup import ( + GroupedPooledEmbeddingsLookup, + InferGroupedPooledEmbeddingsLookup, +) from torchrec.distributed.embedding_sharding import ( BaseEmbeddingDist, BaseEmbeddingLookup, BaseSparseFeaturesDist, bucketize_kjt_before_all2all, + bucketize_kjt_inference, EmbeddingSharding, EmbeddingShardingContext, EmbeddingShardingInfo, group_tables, - SparseFeaturesAllToAll, ) from torchrec.distributed.embedding_types import ( BaseGroupedFeatureProcessor, + DTensorMetadata, EmbeddingComputeKernel, GroupedEmbeddingConfig, + InputDistOutputs, ShardedEmbeddingTable, - SparseFeatures, ) from torchrec.distributed.types import ( Awaitable, CommOp, + NullShardingContext, QuantizedCommCodecs, ShardedTensorMetadata, ShardingEnv, + ShardingEnv2D, + ShardingType, ShardMetadata, ) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.streamable import Multistreamable +logger: logging.Logger = logging.getLogger(__name__) C = TypeVar("C", bound=Multistreamable) F = TypeVar("F", bound=Multistreamable) @@ -46,6 +67,40 @@ W = TypeVar("W") +def get_embedding_shard_metadata( + grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]], +) -> Tuple[List[List[int]], bool]: + is_even_sharding: bool = True + world_size = len(grouped_embedding_configs_per_rank) + + def get_even_shard_sizes(hash_size: int, world_size: int) -> List[int]: + block_size: int = math.ceil(hash_size / world_size) + last_rank: int = hash_size // block_size + + expected_even_shard_sizes = [block_size] * last_rank + if hash_size % world_size != 0: + expected_even_shard_sizes.append(hash_size - sum(expected_even_shard_sizes)) + return expected_even_shard_sizes + + embed_sharding = [] + for table in grouped_embedding_configs_per_rank[0][0].embedding_tables: + embed_sharding_per_feature = [] + total_rows = 0 + sizes = [] + # pyre-ignore [16]: `Optional` has no attribute `shards_metadata` + for metadata in table.global_metadata.shards_metadata: + embed_sharding_per_feature.append(metadata.shard_offsets[0]) + total_rows += metadata.shard_sizes[0] + sizes.append(metadata.shard_sizes[0]) + embed_sharding_per_feature.append(total_rows) + embed_sharding.extend([embed_sharding_per_feature] * len(table.embedding_names)) + expected_even_sizes = get_even_shard_sizes(total_rows, world_size) + if sizes != expected_even_sizes: + is_even_sharding = False + + return (embed_sharding, is_even_sharding) + + class BaseRwEmbeddingSharding(EmbeddingSharding[C, F, T, W]): """ Base class for row-wise sharding. 
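`get_embedding_shard_metadata` above compares the actual shard row counts against the sizes an even row-wise split would produce; when they match, the inference dist can skip passing per-feature row positions to the bucketizer. A sketch of the expected-even-sizes helper and the comparison, using hypothetical shard plans:

    import math
    from typing import List


    def get_even_shard_sizes(hash_size: int, world_size: int) -> List[int]:
        block_size = math.ceil(hash_size / world_size)
        last_rank = hash_size // block_size
        sizes = [block_size] * last_rank
        if hash_size % world_size != 0:
            sizes.append(hash_size - sum(sizes))
        return sizes


    # 10 rows over 4 ranks -> [3, 3, 3, 1]; 8 rows over 4 ranks -> [2, 2, 2, 2].
    assert get_even_shard_sizes(10, 4) == [3, 3, 3, 1]
    assert get_even_shard_sizes(8, 4) == [2, 2, 2, 2]

    # A hypothetical uneven plan ([4, 3, 2, 1] for 10 rows) differs from the even
    # sizes, so the shard offsets must be handed to the inference bucketizer.
    is_even_sharding = [4, 3, 2, 1] == get_even_shard_sizes(10, 4)
    assert is_even_sharding is False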
@@ -58,51 +113,45 @@ def __init__( device: Optional[torch.device] = None, need_pos: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, - variable_batch_size: bool = False, + device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None, ) -> None: - super().__init__( - qcomm_codecs_registry=qcomm_codecs_registry, - ) - + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) self._env = env - self._pg: Optional[dist.ProcessGroup] = self._env.process_group + self._is_2D_parallel: bool = isinstance(env, ShardingEnv2D) + self._pg: Optional[dist.ProcessGroup] = ( + self._env.sharding_pg # pyre-ignore[16] + if self._is_2D_parallel + else self._env.process_group + ) self._world_size: int = self._env.world_size self._rank: int = self._env.rank if device is None: device = torch.device("cpu") - self._device = device + self._device: torch.device = device + self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = ( + device_type_from_sharding_infos + ) sharded_tables_per_rank = self._shard(sharding_infos) self._need_pos = need_pos - self._grouped_embedding_configs_per_rank: List[ - List[GroupedEmbeddingConfig] - ] = [] - self._score_grouped_embedding_configs_per_rank: List[ - List[GroupedEmbeddingConfig] - ] = [] - ( - self._grouped_embedding_configs_per_rank, - self._score_grouped_embedding_configs_per_rank, - ) = group_tables(sharded_tables_per_rank) - self._grouped_embedding_configs: List[ - GroupedEmbeddingConfig - ] = self._grouped_embedding_configs_per_rank[self._rank] - self._score_grouped_embedding_configs: List[ - GroupedEmbeddingConfig - ] = self._score_grouped_embedding_configs_per_rank[self._rank] + self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = ( + [] + ) + self._grouped_embedding_configs_per_rank = group_tables(sharded_tables_per_rank) + self._grouped_embedding_configs: List[GroupedEmbeddingConfig] = ( + self._grouped_embedding_configs_per_rank[self._rank] + ) self._has_feature_processor: bool = False for group_config in self._grouped_embedding_configs: if group_config.has_feature_processor: self._has_feature_processor = True - self._variable_batch_size = variable_batch_size - def _shard( self, sharding_infos: List[EmbeddingShardingInfo], ) -> List[List[ShardedEmbeddingTable]]: tables_per_rank: List[List[ShardedEmbeddingTable]] = [ - [] for i in range(self._world_size) + [] for _ in range(self._world_size) ] for info in sharding_infos: # pyre-fixme [16] @@ -113,12 +162,37 @@ def _shard( shards_metadata=shards, size=torch.Size( [ - info.embedding_config.num_embeddings, + ( + info.embedding_config.num_embeddings_post_pruning + if info.embedding_config.num_embeddings_post_pruning + is not None + else info.embedding_config.num_embeddings + ), info.embedding_config.embedding_dim, ] ), ) + dtensor_metadata = None + if self._env.output_dtensor: + placements = ( + (Replicate(), Shard(0)) if self._is_2D_parallel else (Shard(0),) + ) + dtensor_metadata = DTensorMetadata( + mesh=self._env.device_mesh, + placements=placements, + size=( + ( + info.embedding_config.num_embeddings_post_pruning + if info.embedding_config.num_embeddings_post_pruning + is not None + else info.embedding_config.num_embeddings + ), + info.embedding_config.embedding_dim, + ), + stride=info.param.stride(), + ) + for rank in range(self._world_size): tables_per_rank[rank].append( ShardedEmbeddingTable( @@ -138,9 +212,11 @@ def _shard( ), local_metadata=shards[rank], global_metadata=global_metadata, + 
dtensor_metadata=dtensor_metadata, weight_init_max=info.embedding_config.weight_init_max, weight_init_min=info.embedding_config.weight_init_min, fused_params=info.fused_params, + num_embeddings_post_pruning=info.embedding_config.num_embeddings_post_pruning, ) ) return tables_per_rank @@ -149,67 +225,55 @@ def embedding_dims(self) -> List[int]: embedding_dims = [] for grouped_config in self._grouped_embedding_configs: embedding_dims.extend(grouped_config.embedding_dims()) - for grouped_config in self._score_grouped_embedding_configs: - embedding_dims.extend(grouped_config.embedding_dims()) return embedding_dims def embedding_names(self) -> List[str]: embedding_names = [] for grouped_config in self._grouped_embedding_configs: embedding_names.extend(grouped_config.embedding_names()) - for grouped_config in self._score_grouped_embedding_configs: - embedding_names.extend(grouped_config.embedding_names()) return embedding_names def embedding_names_per_rank(self) -> List[List[str]]: - raise NotImplementedError + embedding_names = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_rank: + embedding_names_per_rank = [] + for grouped_config in grouped_embedding_configs: + embedding_names_per_rank.extend(grouped_config.embedding_names()) + embedding_names.append(embedding_names_per_rank) + return embedding_names def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: embedding_shard_metadata = [] for grouped_config in self._grouped_embedding_configs: embedding_shard_metadata.extend(grouped_config.embedding_shard_metadata()) - for grouped_config in self._score_grouped_embedding_configs: - embedding_shard_metadata.extend(grouped_config.embedding_shard_metadata()) return embedding_shard_metadata - def id_list_feature_names(self) -> List[str]: - id_list_feature_names = [] + def feature_names(self) -> List[str]: + feature_names = [] for grouped_config in self._grouped_embedding_configs: - id_list_feature_names.extend(grouped_config.feature_names()) - return id_list_feature_names + feature_names.extend(grouped_config.feature_names()) + return feature_names - def id_score_list_feature_names(self) -> List[str]: - id_score_list_feature_names = [] - for grouped_config in self._score_grouped_embedding_configs: - id_score_list_feature_names.extend(grouped_config.feature_names()) - return id_score_list_feature_names + def embedding_tables(self) -> List[ShardedEmbeddingTable]: + embedding_tables = [] + for grouped_config in self._grouped_embedding_configs: + embedding_tables.extend(grouped_config.embedding_tables) + return embedding_tables - def _get_id_list_features_num(self) -> int: + def _get_num_features(self) -> int: return sum( group_config.num_features() for group_config in self._grouped_embedding_configs ) - def _get_id_score_list_features_num(self) -> int: - return sum( - group_config.num_features() - for group_config in self._score_grouped_embedding_configs - ) - - def _get_id_list_features_hash_sizes(self) -> List[int]: - id_list_feature_hash_sizes: List[int] = [] + def _get_feature_hash_sizes(self) -> List[int]: + feature_hash_sizes: List[int] = [] for group_config in self._grouped_embedding_configs: - id_list_feature_hash_sizes.extend(group_config.feature_hash_sizes()) - return id_list_feature_hash_sizes + feature_hash_sizes.extend(group_config.feature_hash_sizes()) + return feature_hash_sizes - def _get_id_score_list_features_hash_sizes(self) -> List[int]: - id_score_list_feature_hash_sizes: List[int] = [] - for group_config in 
self._score_grouped_embedding_configs: - id_score_list_feature_hash_sizes.extend(group_config.feature_hash_sizes()) - return id_score_list_feature_hash_sizes - -class RwSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeatures]): +class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]): """ Bucketizes sparse features in RW fashion and then redistributes with an AlltoAll collective operation. @@ -218,11 +282,9 @@ class RwSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeatures]): pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. intra_pg (dist.ProcessGroup): ProcessGroup within single host group for AlltoAll communication. - num_id_list_features (int): total number of id list features. - num_id_score_list_features (int): total number of id score list features - id_list_feature_hash_sizes (List[int]): hash sizes of id list features. - id_score_list_feature_hash_sizes (List[int]): hash sizes of id score list - features. + num_features (int): total number of features. + feature_hash_sizes (List[int]): hash sizes of features. + feature_total_num_buckets (Optional[List[int]]): total number of buckets, if provided will be >= world size. device (Optional[torch.device]): device on which buffers will be allocated. is_sequence (bool): if this is for a sequence embedding. has_feature_processor (bool): existence of feature processor (ie. position @@ -233,105 +295,99 @@ class RwSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeatures]): def __init__( self, pg: dist.ProcessGroup, - num_id_list_features: int, - num_id_score_list_features: int, - id_list_feature_hash_sizes: List[int], - id_score_list_feature_hash_sizes: List[int], + num_features: int, + feature_hash_sizes: List[int], + feature_total_num_buckets: Optional[List[int]] = None, device: Optional[torch.device] = None, is_sequence: bool = False, has_feature_processor: bool = False, need_pos: bool = False, - variable_batch_size: bool = False, + keep_original_indices: bool = False, ) -> None: super().__init__() self._world_size: int = pg.size() - self._num_id_list_features = num_id_list_features - self._num_id_score_list_features = num_id_score_list_features - id_list_feature_block_sizes = [ - (hash_size + self._world_size - 1) // self._world_size - for hash_size in id_list_feature_hash_sizes - ] - id_score_list_feature_block_sizes = [ - (hash_size + self._world_size - 1) // self._world_size - for hash_size in id_score_list_feature_hash_sizes - ] + self._num_features = num_features + + feature_block_sizes: List[int] = [] + + for i, hash_size in enumerate(feature_hash_sizes): + block_divisor = self._world_size + if feature_total_num_buckets is not None: + assert feature_total_num_buckets[i] % self._world_size == 0 + block_divisor = feature_total_num_buckets[i] + feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor) + self.register_buffer( - "_id_list_feature_block_sizes_tensor", + "_feature_block_sizes_tensor", torch.tensor( - id_list_feature_block_sizes, + feature_block_sizes, device=device, - dtype=torch.int32, + dtype=torch.int64, ), + persistent=False, ) - self.register_buffer( - "_id_score_list_feature_block_sizes_tensor", - torch.tensor( - id_score_list_feature_block_sizes, - device=device, - dtype=torch.int32, - ), + self._has_multiple_blocks_per_shard: bool = ( + feature_total_num_buckets is not None ) - self._dist = SparseFeaturesAllToAll( + if self._has_multiple_blocks_per_shard: + self.register_buffer( + "_feature_total_num_blocks_tensor", + torch.tensor( + [feature_total_num_buckets], + 
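The per-feature block sizes above default to an even split over the world size, but a feature may declare a total bucket count that is a multiple of the world size, which shrinks its block size accordingly. A sketch of that computation with hypothetical hash sizes and bucket counts:

    from typing import List, Optional


    def feature_block_sizes(
        feature_hash_sizes: List[int],
        world_size: int,
        feature_total_num_buckets: Optional[List[int]] = None,
    ) -> List[int]:
        sizes = []
        for i, hash_size in enumerate(feature_hash_sizes):
            block_divisor = world_size
            if feature_total_num_buckets is not None:
                # More buckets than ranks is allowed, but only in whole multiples.
                assert feature_total_num_buckets[i] % world_size == 0
                block_divisor = feature_total_num_buckets[i]
            sizes.append((hash_size + block_divisor - 1) // block_divisor)  # ceil div
        return sizes


    # world_size=2: hash sizes [10, 16] -> blocks [5, 8];
    # with 4 total buckets for the second feature its block shrinks to 4.
    assert feature_block_sizes([10, 16], world_size=2) == [5, 8]
    assert feature_block_sizes([10, 16], world_size=2, feature_total_num_buckets=[2, 4]) == [5, 4]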
device=device, + dtype=torch.int64, + ), + persistent=False, + ) + + self._dist = KJTAllToAll( pg=pg, - id_list_features_per_rank=self._world_size * [self._num_id_list_features], - id_score_list_features_per_rank=self._world_size - * [self._num_id_score_list_features], - device=device, - variable_batch_size=variable_batch_size, + splits=[self._num_features] * self._world_size, ) self._is_sequence = is_sequence self._has_feature_processor = has_feature_processor self._need_pos = need_pos self.unbucketize_permute_tensor: Optional[torch.Tensor] = None + self._keep_original_indices = keep_original_indices def forward( self, - sparse_features: SparseFeatures, - ) -> Awaitable[Awaitable[SparseFeatures]]: + sparse_features: KeyedJaggedTensor, + ) -> Awaitable[Awaitable[KeyedJaggedTensor]]: """ Bucketizes sparse feature values into world size number of buckets and then performs AlltoAll operation. Args: - sparse_features (SparseFeatures): sparse features to bucketize and + sparse_features (KeyedJaggedTensor): sparse features to bucketize and redistribute. Returns: - Awaitable[SparseFeatures]: awaitable of SparseFeatures. + Awaitable[Awaitable[KeyedJaggedTensor]]: awaitable of awaitable of KeyedJaggedTensor. """ - if self._num_id_list_features > 0: - assert sparse_features.id_list_features is not None - ( - id_list_features, - self.unbucketize_permute_tensor, - ) = bucketize_kjt_before_all2all( - sparse_features.id_list_features, - num_buckets=self._world_size, - block_sizes=self._id_list_feature_block_sizes_tensor, - output_permute=self._is_sequence, - bucketize_pos=self._has_feature_processor, - ) - else: - id_list_features = None - - if self._num_id_score_list_features > 0: - assert sparse_features.id_score_list_features is not None - id_score_list_features, _ = bucketize_kjt_before_all2all( - sparse_features.id_score_list_features, - num_buckets=self._world_size, - block_sizes=self._id_score_list_feature_block_sizes_tensor, - output_permute=False, - bucketize_pos=self._need_pos, - ) - else: - id_score_list_features = None - - bucketized_sparse_features = SparseFeatures( - id_list_features=id_list_features, - id_score_list_features=id_score_list_features, + ( + bucketized_features, + self.unbucketize_permute_tensor, + ) = bucketize_kjt_before_all2all( + sparse_features, + num_buckets=self._world_size, + block_sizes=self._feature_block_sizes_tensor, + total_num_blocks=( + self._feature_total_num_blocks_tensor + if self._has_multiple_blocks_per_shard + else None + ), + output_permute=self._is_sequence, + bucketize_pos=( + self._has_feature_processor + if sparse_features.weights_or_none() is None + else self._need_pos + ), + keep_original_indices=self._keep_original_indices, ) - return self._dist(bucketized_sparse_features) + + return self._dist(bucketized_features) class RwPooledEmbeddingDist( @@ -348,18 +404,27 @@ class RwPooledEmbeddingDist( def __init__( self, pg: dist.ProcessGroup, + embedding_dims: List[int], qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, ) -> None: super().__init__() - self._dist = PooledEmbeddingsReduceScatter( - pg, - codecs=qcomm_codecs_registry.get( + self._dist: Optional[ + Union[ + PooledEmbeddingsReduceScatter, + VariableBatchPooledEmbeddingsReduceScatter, + ] + ] = None + self._pg = pg + self._qcomm_codecs_registry = qcomm_codecs_registry + self._codecs: Optional[QuantizedCommCodecs] = ( + qcomm_codecs_registry.get( CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER.name, None ) if qcomm_codecs_registry - else None, + else None ) + self._embedding_dims 
= embedding_dims def forward( self, @@ -371,20 +436,87 @@ def forward( Args: local_embs (torch.Tensor): pooled embeddings tensor to distribute. + sharding_ctx (Optional[EmbeddingShardingContext]): shared context from + KJTAllToAll operation. Returns: Awaitable[torch.Tensor]: awaitable of pooled embeddings tensor. """ + if self._dist is None: + self._create_output_dist_module(sharding_ctx) if sharding_ctx is None: - return self._dist(local_embs) + return cast(PooledEmbeddingsReduceScatter, self._dist)(local_embs) + elif sharding_ctx.variable_batch_per_feature: + return cast(VariableBatchPooledEmbeddingsReduceScatter, self._dist)( + local_embs, + batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature, + embedding_dims=self._embedding_dims, + ) + else: + return cast(PooledEmbeddingsReduceScatter, self._dist)( + local_embs, + input_splits=sharding_ctx.batch_size_per_rank, + ) + + def _create_output_dist_module( + self, sharding_ctx: Optional[EmbeddingShardingContext] = None + ) -> None: + if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature: + self._dist = VariableBatchPooledEmbeddingsReduceScatter( + pg=self._pg, + codecs=self._codecs, + ) else: - return self._dist(local_embs, input_splits=sharding_ctx.batch_size_per_rank) + self._dist = PooledEmbeddingsReduceScatter( + pg=self._pg, + codecs=self._codecs, + ) + + +class InferRwPooledEmbeddingDist( + BaseEmbeddingDist[NullShardingContext, List[torch.Tensor], torch.Tensor] +): + """ + Redistributes pooled embedding tensor in RW fashion with an AlltoOne operation. + + Args: + device (torch.device): device on which the tensors will be communicated to. + world_size (int): number of devices in the topology. + """ + + def __init__( + self, + device: torch.device, + world_size: int, + ) -> None: + super().__init__() + self._dist: EmbeddingsAllToOneReduce = EmbeddingsAllToOneReduce( + device=device, + world_size=world_size, + ) + + def forward( + self, + local_embs: List[torch.Tensor], + sharding_ctx: Optional[NullShardingContext] = None, + ) -> torch.Tensor: + """ + Performs AlltoOne operation on sequence embeddings tensor. + + Args: + local_embs (torch.Tensor): tensor of values to distribute. + + Returns: + Awaitable[torch.Tensor]: awaitable of sequence embeddings. + """ + + return self._dist(local_embs) class RwPooledEmbeddingSharding( BaseRwEmbeddingSharding[ - EmbeddingShardingContext, SparseFeatures, torch.Tensor, torch.Tensor + EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor ] ): """ @@ -395,24 +527,19 @@ class RwPooledEmbeddingSharding( def create_input_dist( self, device: Optional[torch.device] = None, - ) -> BaseSparseFeaturesDist[SparseFeatures]: - num_id_list_features = self._get_id_list_features_num() - num_id_score_list_features = self._get_id_score_list_features_num() - id_list_feature_hash_sizes = self._get_id_list_features_hash_sizes() - id_score_list_feature_hash_sizes = self._get_id_score_list_features_hash_sizes() + ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: + num_features = self._get_num_features() + feature_hash_sizes = self._get_feature_hash_sizes() return RwSparseFeaturesDist( # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got # `Optional[ProcessGroup]`. 
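The pooled output dist above now builds its collective lazily on the first forward, choosing the variable-batch reduce-scatter when the sharding context reports per-feature batch sizes and the plain reduce-scatter otherwise. A stripped-down sketch of that lazy dispatch, using stub classes rather than the real collectives:

    class PlainRS:          # stand-in for PooledEmbeddingsReduceScatter
        kind = "plain"


    class VariableBatchRS:  # stand-in for VariableBatchPooledEmbeddingsReduceScatter
        kind = "variable_batch"


    class LazyOutputDist:
        def __init__(self) -> None:
            self._dist = None

        def forward(self, variable_batch_per_feature: bool) -> str:
            if self._dist is None:
                # Built once, on first use, based on what the sharding context reports.
                self._dist = VariableBatchRS() if variable_batch_per_feature else PlainRS()
            return self._dist.kind


    dist = LazyOutputDist()
    assert dist.forward(variable_batch_per_feature=True) == "variable_batch"
    # Subsequent calls reuse the module created on the first pass.
    assert dist.forward(variable_batch_per_feature=False) == "variable_batch"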
pg=self._pg, - num_id_list_features=num_id_list_features, - num_id_score_list_features=num_id_score_list_features, - id_list_feature_hash_sizes=id_list_feature_hash_sizes, - id_score_list_feature_hash_sizes=id_score_list_feature_hash_sizes, + num_features=num_features, + feature_hash_sizes=feature_hash_sizes, device=device if device is not None else self._device, is_sequence=False, has_feature_processor=self._has_feature_processor, need_pos=self._need_pos, - variable_batch_size=self._variable_batch_size, ) def create_lookup( @@ -423,10 +550,10 @@ def create_lookup( ) -> BaseEmbeddingLookup: return GroupedPooledEmbeddingsLookup( grouped_configs=self._grouped_embedding_configs, - grouped_score_configs=self._score_grouped_embedding_configs, pg=self._pg, device=device if device is not None else self._device, feature_processor=feature_processor, + sharding_type=ShardingType.ROW_WISE, ) def create_output_dist( @@ -438,4 +565,202 @@ def create_output_dist( # `Optional[ProcessGroup]`. self._pg, qcomm_codecs_registry=self.qcomm_codecs_registry, + embedding_dims=self.embedding_dims(), + ) + + +@overload +def convert_tensor(t: torch.Tensor, feature: KeyedJaggedTensor) -> torch.Tensor: ... +@overload +def convert_tensor(t: None, feature: KeyedJaggedTensor) -> None: ... + + +def convert_tensor( + t: Union[torch.Tensor, None], + feature: KeyedJaggedTensor, +) -> Union[torch.Tensor, None]: + # comparing to Optional[Tensor], this solution will keep output as Tensor when input is not None + if t is None: + return None + else: + return t.to( + device=feature.device(), + dtype=feature.values().dtype, + ) + + +class InferRwSparseFeaturesDist(BaseSparseFeaturesDist[InputDistOutputs]): + def __init__( + self, + world_size: int, + num_features: int, + feature_hash_sizes: List[int], + feature_total_num_buckets: Optional[List[int]] = None, + device: Optional[torch.device] = None, + is_sequence: bool = False, + has_feature_processor: bool = False, + need_pos: bool = False, + embedding_shard_metadata: Optional[List[List[int]]] = None, + keep_original_indices: bool = False, + ) -> None: + super().__init__() + logger.info( + f"InferRwSparseFeaturesDist: {world_size=}, {num_features=}, {feature_hash_sizes=}, {feature_total_num_buckets=}, {device=}, {is_sequence=}, {has_feature_processor=}, {need_pos=}, {embedding_shard_metadata=}" + f", keep_original_indices={keep_original_indices}" + ) + self._world_size: int = world_size + self._num_features = num_features + self._feature_total_num_buckets: Optional[List[int]] = feature_total_num_buckets + feature_block_sizes: List[int] = [] + for i, hash_size in enumerate(feature_hash_sizes): + block_divisor = self._world_size + if ( + feature_total_num_buckets is not None + and embedding_shard_metadata is None + ): + assert feature_total_num_buckets[i] % self._world_size == 0 + block_divisor = feature_total_num_buckets[i] + feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor) + self.register_buffer( + "feature_block_sizes", + torch.tensor(feature_block_sizes), + ) + + self._dist = KJTOneToAll( + splits=self._world_size * [self._num_features], + world_size=world_size, + device=device, + ) + self._is_sequence = is_sequence + self._has_feature_processor = has_feature_processor + self._need_pos = need_pos + embedding_shard_metadata = embedding_shard_metadata or [] + for i, row_pos in enumerate(embedding_shard_metadata): + self.register_buffer(f"row_pos_{i}", torch.tensor(row_pos)) + self.embedding_shard_metadata_len: int = len(embedding_shard_metadata) + 
self._keep_original_indices = keep_original_indices + # pyre-ignore[8] + self.register_buffer( + "feature_total_num_buckets", + ( + torch.tensor(feature_total_num_buckets) + if feature_total_num_buckets + else None + ), + ) + self.forwarded: bool = False + + def get_block_bucketize_row_pos(self) -> Optional[List[torch.Tensor]]: + return [ + getattr(self, f"row_pos_{i}") + for i in range(self.embedding_shard_metadata_len) + ] or None + + def move_buffer(self, sparse_features: KeyedJaggedTensor) -> None: + # buffer should only be moved once, even if this method being executed multiple times. as later 'to' should return same tensor after first convert + self.feature_block_sizes = convert_tensor( + t=self.feature_block_sizes, feature=sparse_features + ) + self.feature_total_num_buckets = convert_tensor( + t=self.feature_total_num_buckets, feature=sparse_features + ) + for i in range(self.embedding_shard_metadata_len): + setattr( + self, + f"row_pos_{i}", + convert_tensor( + t=getattr(self, f"row_pos_{i}"), + feature=sparse_features, + ), + ) + + def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs: + if not self.forwarded: + # after fx tracing, 'if' will be removed, and below line will actually be called multiple times. but it's ok as 'to' will return same tensor after first convert. + self.move_buffer(sparse_features) + self.forwarded = True + ( + bucketized_features, + unbucketize_permute_tensor, + bucket_mapping_tensor_opt, + ) = bucketize_kjt_inference( + sparse_features, + num_buckets=self._world_size, + block_sizes=self.feature_block_sizes, + total_num_buckets=self.feature_total_num_buckets, + bucketize_pos=( + self._has_feature_processor + if sparse_features.weights_or_none() is None + else self._need_pos + ), + block_bucketize_row_pos=self.get_block_bucketize_row_pos(), + is_sequence=self._is_sequence, + keep_original_indices=self._keep_original_indices, + ) + # KJTOneToAll + dist_kjt = self._dist.forward(bucketized_features) + return InputDistOutputs( + features=dist_kjt, + unbucketize_permute_tensor=( + unbucketize_permute_tensor if self._is_sequence else None + ), + bucket_mapping_tensor=( + bucket_mapping_tensor_opt if self._is_sequence else None + ), + bucketized_length=( + bucketized_features.lengths().view( + self._world_size * self._num_features, -1 + ) + if self._is_sequence + else None + ), + ) + + +class InferRwPooledEmbeddingSharding( + BaseRwEmbeddingSharding[ + NullShardingContext, InputDistOutputs, List[torch.Tensor], torch.Tensor + ] +): + def create_input_dist( + self, + device: Optional[torch.device] = None, + ) -> BaseSparseFeaturesDist[InputDistOutputs]: + num_features = self._get_num_features() + feature_hash_sizes = self._get_feature_hash_sizes() + + (embed_sharding, is_even_sharding) = get_embedding_shard_metadata( + self._grouped_embedding_configs_per_rank + ) + + return InferRwSparseFeaturesDist( + world_size=self._world_size, + num_features=num_features, + feature_hash_sizes=feature_hash_sizes, + device=device if device is not None else self._device, + embedding_shard_metadata=embed_sharding if not is_even_sharding else None, + ) + + def create_lookup( + self, + device: Optional[torch.device] = None, + fused_params: Optional[Dict[str, Any]] = None, + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]]: + return InferGroupedPooledEmbeddingsLookup( + grouped_configs_per_rank=self._grouped_embedding_configs_per_rank, + world_size=self._world_size, + 
fused_params=fused_params, + device=device if device is not None else self._device, + device_type_from_sharding_infos=self._device_type_from_sharding_infos, + ) + + def create_output_dist( + self, + device: Optional[torch.device] = None, + ) -> BaseEmbeddingDist[NullShardingContext, List[torch.Tensor], torch.Tensor]: + assert device is not None + return InferRwPooledEmbeddingDist( + device=device, + world_size=self._world_size, ) diff --git a/torchrec/distributed/sharding/rw_tensor_pool_sharding.py b/torchrec/distributed/sharding/rw_tensor_pool_sharding.py new file mode 100644 index 000000000..63ea13fc4 --- /dev/null +++ b/torchrec/distributed/sharding/rw_tensor_pool_sharding.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Iterable, List, Optional, Tuple + +import torch +import torch.distributed as dist + +from torch.distributed._shard.sharded_tensor import Shard, ShardMetadata + +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torchrec.distributed.dist_data import TensorValuesAllToAll +from torchrec.distributed.sharding.rw_pool_sharding import ( + InferRwObjectPoolInputDist, + RwObjectPoolIDsDist, +) +from torchrec.distributed.tensor_sharding import ( + InferObjectPoolSharding, + ObjectPoolRwShardingContext, + ObjectPoolSharding, + TensorPoolRwShardingContext, +) +from torchrec.distributed.types import LazyAwaitable, ShardingEnv +from torchrec.modules.object_pool_lookups import TensorPoolLookup + + +class RwTensorPoolValuesDist(torch.nn.Module): + """ + Module to distribute torch.Tensor to all ranks after local pool lookup + + Args: + pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. + is_update (bool): Boolean indicating whether this is an update or not. + + Example: + dist = RwTensorPoolLookupValuesDist(pg) + # rank 0 + rank0_ctx = TensorPoolRwShardingContext( + num_ids_each_rank_to_send=2, + num_ids_each_rank_to_receive=3, + ) + rank0_values = torch.tensor([2,3,4,5]) + + # rank 1 + rank0_ctx = TensorPoolRwShardingContext( + num_ids_each_rank_to_send=3, + num_ids_each_rank_to_receive=2, + ) + rank1_values = torch.tensor([1,1,1,3,4]) + + rank0_out = dist(rank0_ctx, rank0_values).wait() + # rank0_out has values [2,3,1,1,1] + + rank1_out = dist(rank1_ctx, rank1_values).wait() + # rank1_out has values [4,5,3,4] + """ + + def __init__( + self, + pg: dist.ProcessGroup, + is_update: bool, + ) -> None: + super().__init__() + self._pg = pg + self._dist = TensorValuesAllToAll(pg=pg) + self._is_update = is_update + + def forward( + self, + ctx: TensorPoolRwShardingContext, + values: torch.Tensor, + ) -> LazyAwaitable[torch.Tensor]: + """ + Redistributes local tensor values after tensor pool lookup. + Will only permute values when updating. + + Args: + ctx (TensorPoolRwShardingContext): Context for RW sharding, containing + number of items to send and receive from each rank. + values (torch.Tensor): tensor to distribute. 
+ + Returns: + LazyAwaitable[torch.Tensor]: Lazy awaitable of tensor + """ + + if self._is_update: + with torch.no_grad(): + assert hasattr(ctx, "unbucketize_permute") + bucketize_permute = torch.ops.fbgemm.invert_permute( + ctx.unbucketize_permute + ) + values = values[bucketize_permute] + + return self._dist( + input=values, + input_splits=ctx.num_ids_each_rank_to_send, + output_splits=ctx.num_ids_each_rank_to_receive, + ) + + +class TensorPoolRwSharding(ObjectPoolSharding): + def __init__( + self, + pool_size: int, + dim: int, + env: ShardingEnv, + device: torch.device, + ) -> None: + self._env = env + # pyre-ignore + self._pg: dist.ProcessGroup = self._env.process_group + self._world_size: int = self._env.world_size + self._rank: int = self._env.rank + self._device = device + self._pool_size = pool_size + self._dim = dim + + self._block_size: int = ( + pool_size + self._env.world_size - 1 + ) // self._env.world_size + + self.local_pool_size: int = ( + self._block_size + if self._env.rank != self._env.world_size - 1 + else pool_size - self._block_size * (self._env.world_size - 1) + ) + + self._block_size_t: torch.Tensor = torch.tensor( + [ + self._block_size, + ], + dtype=torch.long, + device=self._device, + ) + + def create_update_ids_dist( + self, + ) -> RwObjectPoolIDsDist: + return RwObjectPoolIDsDist(self._pg, is_update=True) + + def create_update_values_dist( + self, + ) -> RwTensorPoolValuesDist: + """ + used in embedding A2A in update() + """ + return RwTensorPoolValuesDist(self._pg, is_update=True) + + def create_lookup_ids_dist(self) -> RwObjectPoolIDsDist: + return RwObjectPoolIDsDist(self._pg, is_update=False) + + def create_lookup_values_dist( + self, + ) -> RwTensorPoolValuesDist: + """ + used in embedding A2A in lookup() + """ + return RwTensorPoolValuesDist(self._pg, is_update=False) + + def get_sharded_states_to_register( + self, lookup: TensorPoolLookup + ) -> Iterable[Tuple[str, torch.Tensor]]: + for fqn, tensor in lookup.states_to_register(): + yield fqn, ShardedTensor._init_from_local_shards( + [ + Shard( + tensor=tensor, + metadata=ShardMetadata( + shard_offsets=[ + self._env.rank * self._block_size, + 0, + ], + shard_sizes=[ + tensor.shape[0], + tensor.shape[1], + ], + placement=f"rank:{self._env.rank}/{str(tensor.device)}", + ), + ) + ], + torch.Size([self._pool_size, tensor.shape[1]]), + process_group=self._env.process_group, + ) + + def create_context(self) -> ObjectPoolRwShardingContext: + return ObjectPoolRwShardingContext(block_size=self._block_size_t) + + +class InferRwTensorPoolOutputDist(torch.nn.Module): + """ + Collects local tensor values after tensor pool lookup + to one device during inference. 
+ + Args: + env (ShardingEnv): Sharding environment + device (torch.device): device to collect onto + + Example: + device = torch.device("cpu") + dist = InferRwTensorPoolOutputDist(env, device) + lookups = [ + torch.Tensor([1,2,3], device="rank0:cuda:0"), + torch.Tensor([4,5,6], device="rank1:cuda:0"), + ] + vals = dist(lookups) + # tensors merged and on CPU + vals = torch.Tensor([1,2,3,4,5,6], device=device) + """ + + def __init__( + self, + env: ShardingEnv, + device: torch.device, + ) -> None: + super().__init__() + self._device: Optional[torch.device] = device + self._world_size: int = env.world_size + self._cat_dim = 0 + self._placeholder: torch.Tensor = torch.ones(1, device=device) + + def forward( + self, + lookups: List[torch.Tensor], + ) -> torch.Tensor: + """ + Merge lookup values tensor on different devices onto a single device and rank + + Args: + lookups (List[torch.Tensor]): List of tensors placed on possibly different + devices / ranks. + + Returns: + torch.Tensor: Merged tensor on the requested device + """ + torch._assert(len(lookups) == self._world_size, "lookups size not world size") + + non_cat_size = lookups[0].size(1 - self._cat_dim) + return torch.ops.fbgemm.merge_pooled_embeddings( + lookups, + non_cat_size, + # syntax for torchscript + self._placeholder.device, + self._cat_dim, + ) + + +class InferRwTensorPoolSharding(InferObjectPoolSharding): + def __init__( + self, + pool_size: int, + env: ShardingEnv, + device: torch.device, + ) -> None: + super().__init__(pool_size, env, device) + + def create_lookup_ids_dist(self) -> InferRwObjectPoolInputDist: + return InferRwObjectPoolInputDist( + self._env, device=self._device, block_size=self._block_size_t + ) + + def create_lookup_values_dist( + self, + ) -> InferRwTensorPoolOutputDist: + return InferRwTensorPoolOutputDist(env=self._env, device=self._device) diff --git a/torchrec/distributed/sharding/sequence_sharding.py b/torchrec/distributed/sharding/sequence_sharding.py index 8a7b96e9c..ebffa5490 100644 --- a/torchrec/distributed/sharding/sequence_sharding.py +++ b/torchrec/distributed/sharding/sequence_sharding.py @@ -5,17 +5,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass, field +# pyre-strict + +from dataclasses import dataclass from typing import List, Optional import torch -import torch.distributed as dist # noqa from torchrec.distributed.embedding_sharding import EmbeddingShardingContext +from torchrec.distributed.embedding_types import KJTList from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.streamable import Multistreamable -@dataclass class SequenceShardingContext(EmbeddingShardingContext): """ Stores KJTAllToAll context and reuses it in SequenceEmbeddingsAllToAll. @@ -32,24 +33,52 @@ class SequenceShardingContext(EmbeddingShardingContext): input dist. 
""" - features_before_input_dist: Optional[KeyedJaggedTensor] = None - input_splits: List[int] = field(default_factory=list) - output_splits: List[int] = field(default_factory=list) - sparse_features_recat: Optional[torch.Tensor] = None - unbucketize_permute_tensor: Optional[torch.Tensor] = None - lengths_after_input_dist: Optional[torch.Tensor] = None + # Torch Dynamo does not support default_factory=list: + # https://github.com/pytorch/pytorch/issues/120108 + # TODO(ivankobzarev): Make this a dataclass once supported + + def __init__( + self, + # Fields of EmbeddingShardingContext + batch_size_per_rank: Optional[List[int]] = None, + batch_size_per_rank_per_feature: Optional[List[List[int]]] = None, + batch_size_per_feature_pre_a2a: Optional[List[int]] = None, + variable_batch_per_feature: bool = False, + # Fields of SequenceShardingContext + features_before_input_dist: Optional[KeyedJaggedTensor] = None, + input_splits: Optional[List[int]] = None, + output_splits: Optional[List[int]] = None, + sparse_features_recat: Optional[torch.Tensor] = None, + unbucketize_permute_tensor: Optional[torch.Tensor] = None, + lengths_after_input_dist: Optional[torch.Tensor] = None, + ) -> None: + super().__init__( + batch_size_per_rank, + batch_size_per_rank_per_feature, + batch_size_per_feature_pre_a2a, + variable_batch_per_feature, + ) + self.features_before_input_dist: Optional[KeyedJaggedTensor] = ( + features_before_input_dist + ) + self.input_splits: List[int] = input_splits if input_splits is not None else [] + self.output_splits: List[int] = ( + output_splits if output_splits is not None else [] + ) + self.sparse_features_recat: Optional[torch.Tensor] = sparse_features_recat + self.unbucketize_permute_tensor: Optional[torch.Tensor] = ( + unbucketize_permute_tensor + ) + self.lengths_after_input_dist: Optional[torch.Tensor] = lengths_after_input_dist - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + def record_stream(self, stream: torch.Stream) -> None: if self.features_before_input_dist is not None: self.features_before_input_dist.record_stream(stream) if self.sparse_features_recat is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self.sparse_features_recat.record_stream(stream) if self.unbucketize_permute_tensor is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self.unbucketize_permute_tensor.record_stream(stream) if self.lengths_after_input_dist is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self.lengths_after_input_dist.record_stream(stream) @@ -59,14 +88,26 @@ class InferSequenceShardingContext(Multistreamable): Stores inference context and reuses it in sequence embedding output_dist or result return. Attributes: - features (Optional[List[KeyedJaggedTensor]]): stores the original - shards of KJT after input dist. + features KJTList: stores the shards of KJT after input dist. + features_before_input_dist KJT: stores the original input KJT (before input dist). + unbucketize_permute_tensor Optional[torch.Tensor]: stores unbucketize tensor, only for RowWise sharding. 
""" - features: Optional[List[KeyedJaggedTensor]] = None + features: KJTList + features_before_input_dist: Optional[KeyedJaggedTensor] = None + unbucketize_permute_tensor: Optional[torch.Tensor] = None + bucket_mapping_tensor: Optional[torch.Tensor] = None + bucketized_length: Optional[torch.Tensor] = None + embedding_names_per_rank: Optional[List[List[str]]] = None - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - if self.features is not None: - # pyre-ignore [16] - for feature in self.features: - feature.record_stream(stream) + def record_stream(self, stream: torch.Stream) -> None: + for feature in self.features: + feature.record_stream(stream) + if self.features_before_input_dist is not None: + self.features_before_input_dist.record_stream(stream) + if self.unbucketize_permute_tensor is not None: + self.unbucketize_permute_tensor.record_stream(stream) + if self.bucket_mapping_tensor is not None: + self.bucket_mapping_tensor.record_stream(stream) + if self.bucketized_length is not None: + self.bucketized_length.record_stream(stream) diff --git a/torchrec/distributed/sharding/tw_sequence_sharding.py b/torchrec/distributed/sharding/tw_sequence_sharding.py index 32df26f26..77b2d2321 100644 --- a/torchrec/distributed/sharding/tw_sequence_sharding.py +++ b/torchrec/distributed/sharding/tw_sequence_sharding.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Any, Dict, List, Optional import torch @@ -24,8 +26,7 @@ ) from torchrec.distributed.embedding_types import ( BaseGroupedFeatureProcessor, - SparseFeatures, - SparseFeaturesList, + InputDistOutputs, ) from torchrec.distributed.sharding.sequence_sharding import ( InferSequenceShardingContext, @@ -37,6 +38,7 @@ TwSparseFeaturesDist, ) from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor class TwSequenceEmbeddingDist( @@ -64,11 +66,13 @@ def __init__( pg, features_per_rank, device, - codecs=qcomm_codecs_registry.get( - CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name, None - ) - if qcomm_codecs_registry - else None, + codecs=( + qcomm_codecs_registry.get( + CommOp.SEQUENCE_EMBEDDINGS_ALL_TO_ALL.name, None + ) + if qcomm_codecs_registry + else None + ), ) def forward( @@ -102,7 +106,7 @@ def forward( class TwSequenceEmbeddingSharding( BaseTwEmbeddingSharding[ - SequenceShardingContext, SparseFeatures, torch.Tensor, torch.Tensor + SequenceShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor ] ): """ @@ -113,15 +117,12 @@ class TwSequenceEmbeddingSharding( def create_input_dist( self, device: Optional[torch.device] = None, - ) -> BaseSparseFeaturesDist[SparseFeatures]: + ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: return TwSparseFeaturesDist( # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got # `Optional[ProcessGroup]`. 
self._pg, - self.id_list_features_per_rank(), - self.id_score_list_features_per_rank(), - device if device is not None else self._device, - variable_batch_size=self._variable_batch_size, + self.features_per_rank(), ) def create_lookup( @@ -144,7 +145,7 @@ def create_output_dist( assert self._pg is not None return TwSequenceEmbeddingDist( self._pg, - self.id_list_features_per_rank(), + self.features_per_rank(), device if device is not None else self._device, qcomm_codecs_registry=self.qcomm_codecs_registry, ) @@ -176,7 +177,7 @@ def forward( self, local_embs: List[torch.Tensor], sharding_ctx: Optional[InferSequenceShardingContext] = None, - ) -> Awaitable[List[torch.Tensor]]: + ) -> List[torch.Tensor]: """ Performs AlltoOne operation on sequence embeddings tensor. @@ -189,13 +190,13 @@ def forward( Returns: Awaitable[torch.Tensor]: awaitable of sequence embeddings. """ - return self._dist.forward(local_embs) + return self._dist(local_embs) class InferTwSequenceEmbeddingSharding( BaseTwEmbeddingSharding[ InferSequenceShardingContext, - SparseFeaturesList, + InputDistOutputs, List[torch.Tensor], List[torch.Tensor], ] @@ -207,11 +208,11 @@ class InferTwSequenceEmbeddingSharding( def create_input_dist( self, device: Optional[torch.device] = None - ) -> BaseSparseFeaturesDist[SparseFeaturesList]: + ) -> BaseSparseFeaturesDist[InputDistOutputs]: return InferTwSparseFeaturesDist( - self.id_list_features_per_rank(), - self.id_score_list_features_per_rank(), - self._world_size, + features_per_rank=self.features_per_rank(), + world_size=self._world_size, + device=device, ) def create_lookup( @@ -219,11 +220,12 @@ def create_lookup( device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None, - ) -> BaseEmbeddingLookup[SparseFeaturesList, List[torch.Tensor]]: + ) -> BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]]: return InferGroupedEmbeddingsLookup( grouped_configs_per_rank=self._grouped_embedding_configs_per_rank, world_size=self._world_size, fused_params=fused_params, + device=device, ) def create_output_dist( diff --git a/torchrec/distributed/sharding/tw_sharding.py b/torchrec/distributed/sharding/tw_sharding.py index 9913c602e..d5506f0a5 100644 --- a/torchrec/distributed/sharding/tw_sharding.py +++ b/torchrec/distributed/sharding/tw_sharding.py @@ -5,11 +5,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
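# A minimal, illustration-only sketch of what the `features_per_rank` splits used by the
# table-wise input dists here (TwSparseFeaturesDist / InferTwSparseFeaturesDist) mean: the
# incoming KeyedJaggedTensor is split by key count so each rank receives only the features
# whose tables it hosts. Keys and counts below are invented, and this uses
# KeyedJaggedTensor.split() directly rather than the KJTAllToAll / KJTOneToAll collectives;
# it assumes torchrec is installed.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f0", "f1", "f2"],
    values=torch.tensor([10, 11, 12, 13, 14, 15]),
    lengths=torch.tensor([2, 1, 1, 1, 1, 0]),  # per key, per sample (batch size 2)
)
features_per_rank = [2, 1]  # rank 0 hosts the tables behind f0/f1, rank 1 hosts f2
rank0_kjt, rank1_kjt = kjt.split(features_per_rank)
assert rank0_kjt.keys() == ["f0", "f1"] and rank1_kjt.keys() == ["f2"]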
-from typing import Any, Callable, Dict, List, Optional, TypeVar +# pyre-strict + +from typing import Any, Callable, cast, Dict, List, Optional, TypeVar, Union import torch import torch.distributed as dist -from torchrec.distributed.dist_data import EmbeddingsAllToOne, PooledEmbeddingsAllToAll +from torch.distributed._tensor.placement_types import Replicate +from torchrec.distributed.dist_data import ( + EmbeddingsAllToOne, + KJTAllToAll, + KJTOneToAll, + PooledEmbeddingsAllToAll, + VariableBatchPooledEmbeddingsAllToAll, +) from torchrec.distributed.embedding_lookup import ( GroupedPooledEmbeddingsLookup, InferGroupedPooledEmbeddingsLookup, @@ -22,27 +31,28 @@ EmbeddingShardingContext, EmbeddingShardingInfo, group_tables, - NullShardingContext, - SparseFeaturesAllToAll, - SparseFeaturesOneToAll, ) from torchrec.distributed.embedding_types import ( BaseGroupedFeatureProcessor, + DTensorMetadata, EmbeddingComputeKernel, GroupedEmbeddingConfig, + InputDistOutputs, ShardedEmbeddingTable, - SparseFeatures, - SparseFeaturesList, ) from torchrec.distributed.types import ( Awaitable, CommOp, - NoWait, + NullShardingContext, QuantizedCommCodecs, ShardedTensorMetadata, ShardingEnv, + ShardingEnv2D, + ShardingType, ShardMetadata, ) +from torchrec.distributed.utils import none_throws +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.streamable import Multistreamable @@ -63,32 +73,32 @@ def __init__( env: ShardingEnv, device: Optional[torch.device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, - variable_batch_size: bool = False, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) - self._env = env - self._device = device - self._pg: Optional[dist.ProcessGroup] = self._env.process_group + self._env: ShardingEnv = env + self._device: Optional[torch.device] = device + self._is_2D_parallel: bool = isinstance(env, ShardingEnv2D) + self._pg: Optional[dist.ProcessGroup] = ( + self._env.sharding_pg # pyre-ignore[16] + if self._is_2D_parallel + else self._env.process_group + ) self._world_size: int = self._env.world_size self._rank: int = self._env.rank + sharded_tables_per_rank = self._shard(sharding_infos) - self._grouped_embedding_configs_per_rank: List[ - List[GroupedEmbeddingConfig] - ] = [] - self._score_grouped_embedding_configs_per_rank: List[ - List[GroupedEmbeddingConfig] - ] = [] - ( - self._grouped_embedding_configs_per_rank, - self._score_grouped_embedding_configs_per_rank, - ) = group_tables(sharded_tables_per_rank) - self._grouped_embedding_configs: List[ - GroupedEmbeddingConfig - ] = self._grouped_embedding_configs_per_rank[self._rank] - self._score_grouped_embedding_configs: List[ - GroupedEmbeddingConfig - ] = self._score_grouped_embedding_configs_per_rank[self._rank] - self._variable_batch_size = variable_batch_size + + self._sharded_tables_per_rank: List[List[ShardedEmbeddingTable]] = ( + sharded_tables_per_rank + ) + + self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = ( + [] + ) + self._grouped_embedding_configs_per_rank = group_tables(sharded_tables_per_rank) + self._grouped_embedding_configs: List[GroupedEmbeddingConfig] = ( + self._grouped_embedding_configs_per_rank[self._rank] + ) def _shard( self, @@ -96,24 +106,54 @@ def _shard( ) -> List[List[ShardedEmbeddingTable]]: world_size = self._world_size tables_per_rank: List[List[ShardedEmbeddingTable]] = [ - [] for i in range(world_size) + [] for _ in range(world_size) ] for info in sharding_infos: # pyre-fixme [16] shards = 
info.param_sharding.sharding_spec.shards # construct the global sharded_tensor_metadata + global_metadata = ShardedTensorMetadata( shards_metadata=shards, size=torch.Size( [ - info.embedding_config.num_embeddings, + ( + info.embedding_config.num_embeddings_post_pruning + if info.embedding_config.num_embeddings_post_pruning + is not None + else info.embedding_config.num_embeddings + ), info.embedding_config.embedding_dim, ] ), ) - # pyre-fixme [16] - tables_per_rank[info.param_sharding.ranks[0]].append( + dtensor_metadata = None + if self._env.output_dtensor: + dtensor_metadata = DTensorMetadata( + mesh=( + self._env.device_mesh["replicate"] # pyre-ignore[16] + if self._is_2D_parallel + else self._env.device_mesh + ), + placements=(Replicate(),), + size=( + info.embedding_config.num_embeddings, + info.embedding_config.embedding_dim, + ), + stride=info.param.stride(), + ) + + rank = ( + # pyre-ignore [16] + self._env.remap_rank( + info.param_sharding.ranks[0], # pyre-ignore[16] + ShardingType.TABLE_WISE, + ) + if self._is_2D_parallel + else info.param_sharding.ranks[0] + ) + tables_per_rank[rank].append( ShardedEmbeddingTable( num_embeddings=info.embedding_config.num_embeddings, embedding_dim=info.embedding_config.embedding_dim, @@ -124,190 +164,142 @@ def _shard( pooling=info.embedding_config.pooling, is_weighted=info.embedding_config.is_weighted, has_feature_processor=info.embedding_config.has_feature_processor, - local_rows=info.embedding_config.num_embeddings, + local_rows=( + none_throws(info.embedding_config.num_embeddings_post_pruning) + if info.embedding_config.num_embeddings_post_pruning is not None + else info.embedding_config.num_embeddings + ), local_cols=info.embedding_config.embedding_dim, compute_kernel=EmbeddingComputeKernel( info.param_sharding.compute_kernel ), local_metadata=shards[0], global_metadata=global_metadata, + dtensor_metadata=dtensor_metadata, weight_init_max=info.embedding_config.weight_init_max, weight_init_min=info.embedding_config.weight_init_min, fused_params=info.fused_params, + num_embeddings_post_pruning=info.embedding_config.num_embeddings_post_pruning, ) ) return tables_per_rank def _dim_sum_per_rank(self) -> List[int]: dim_sum_per_rank = [] - for grouped_embedding_configs, score_grouped_embedding_configs in zip( - self._grouped_embedding_configs_per_rank, - self._score_grouped_embedding_configs_per_rank, - ): + for grouped_embedding_configs in self._grouped_embedding_configs_per_rank: dim_sum = 0 for grouped_config in grouped_embedding_configs: dim_sum += grouped_config.dim_sum() - for grouped_config in score_grouped_embedding_configs: - dim_sum += grouped_config.dim_sum() dim_sum_per_rank.append(dim_sum) return dim_sum_per_rank + def _emb_dim_per_rank_per_feature(self) -> List[List[int]]: + emb_dim_per_rank_per_feature = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_rank: + emb_dim_per_feature = [] + for grouped_config in grouped_embedding_configs: + emb_dim_per_feature += grouped_config.embedding_dims() + emb_dim_per_rank_per_feature.append(emb_dim_per_feature) + return emb_dim_per_rank_per_feature + def embedding_dims(self) -> List[int]: embedding_dims = [] - for grouped_embedding_configs, score_grouped_embedding_configs in zip( - self._grouped_embedding_configs_per_rank, - self._score_grouped_embedding_configs_per_rank, - ): + for grouped_embedding_configs in self._grouped_embedding_configs_per_rank: for grouped_config in grouped_embedding_configs: embedding_dims.extend(grouped_config.embedding_dims()) - for 
grouped_config in score_grouped_embedding_configs: - embedding_dims.extend(grouped_config.embedding_dims()) return embedding_dims def embedding_names(self) -> List[str]: embedding_names = [] - for grouped_embedding_configs, score_grouped_embedding_configs in zip( - self._grouped_embedding_configs_per_rank, - self._score_grouped_embedding_configs_per_rank, - ): + for grouped_embedding_configs in self._grouped_embedding_configs_per_rank: for grouped_config in grouped_embedding_configs: embedding_names.extend(grouped_config.embedding_names()) - for grouped_config in score_grouped_embedding_configs: - embedding_names.extend(grouped_config.embedding_names()) return embedding_names def embedding_names_per_rank(self) -> List[List[str]]: embedding_names = [] - for grouped_embedding_configs, score_grouped_embedding_configs in zip( - self._grouped_embedding_configs_per_rank, - self._score_grouped_embedding_configs_per_rank, - ): + for grouped_embedding_configs in self._grouped_embedding_configs_per_rank: embedding_names_per_rank = [] for grouped_config in grouped_embedding_configs: embedding_names_per_rank.extend(grouped_config.embedding_names()) - for grouped_config in score_grouped_embedding_configs: - embedding_names_per_rank.extend(grouped_config.embedding_names()) embedding_names.append(embedding_names_per_rank) return embedding_names def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: embedding_shard_metadata = [] - for grouped_embedding_configs, score_grouped_embedding_configs in zip( - self._grouped_embedding_configs_per_rank, - self._score_grouped_embedding_configs_per_rank, - ): + for grouped_embedding_configs in self._grouped_embedding_configs_per_rank: for grouped_config in grouped_embedding_configs: embedding_shard_metadata.extend( grouped_config.embedding_shard_metadata() ) - for grouped_config in score_grouped_embedding_configs: - embedding_shard_metadata.extend( - grouped_config.embedding_shard_metadata() - ) return embedding_shard_metadata - def id_list_feature_names(self) -> List[str]: - id_list_feature_names = [] + def feature_names(self) -> List[str]: + feature_names = [] for grouped_embedding_configs in self._grouped_embedding_configs_per_rank: for grouped_config in grouped_embedding_configs: - id_list_feature_names.extend(grouped_config.feature_names()) - return id_list_feature_names - - def id_score_list_feature_names(self) -> List[str]: - id_score_list_feature_names = [] - for ( - score_grouped_embedding_configs - ) in self._score_grouped_embedding_configs_per_rank: - for grouped_config in score_grouped_embedding_configs: - id_score_list_feature_names.extend(grouped_config.feature_names()) - return id_score_list_feature_names - - def id_list_feature_names_per_rank(self) -> List[List[str]]: - id_list_feature_names = [] + feature_names.extend(grouped_config.feature_names()) + return feature_names + + def embedding_tables(self) -> List[ShardedEmbeddingTable]: + embedding_tables = [] for grouped_embedding_configs in self._grouped_embedding_configs_per_rank: - id_list_feature_names_per_rank = [] for grouped_config in grouped_embedding_configs: - id_list_feature_names_per_rank.extend(grouped_config.feature_names()) - id_list_feature_names.append(id_list_feature_names_per_rank) - return id_list_feature_names - - def id_score_list_feature_names_per_rank(self) -> List[List[str]]: - id_score_list_feature_names = [] - for ( - score_grouped_embedding_configs - ) in self._score_grouped_embedding_configs_per_rank: - id_score_list_feature_names_per_rank = [] - for 
grouped_config in score_grouped_embedding_configs: - id_score_list_feature_names_per_rank.extend( - grouped_config.feature_names() - ) - id_score_list_feature_names.append(id_score_list_feature_names_per_rank) - return id_score_list_feature_names + embedding_tables.extend(grouped_config.embedding_tables) + return embedding_tables - def id_list_features_per_rank(self) -> List[int]: - id_list_features_per_rank = [] + def feature_names_per_rank(self) -> List[List[str]]: + feature_names = [] for grouped_embedding_configs in self._grouped_embedding_configs_per_rank: - num_features = 0 + feature_names_per_rank = [] for grouped_config in grouped_embedding_configs: - num_features += grouped_config.num_features() - id_list_features_per_rank.append(num_features) - return id_list_features_per_rank - - def id_score_list_features_per_rank(self) -> List[int]: - id_score_list_features_per_rank = [] - for ( - score_grouped_embedding_configs - ) in self._score_grouped_embedding_configs_per_rank: + feature_names_per_rank.extend(grouped_config.feature_names()) + feature_names.append(feature_names_per_rank) + return feature_names + + def features_per_rank(self) -> List[int]: + features_per_rank = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_rank: num_features = 0 - for grouped_config in score_grouped_embedding_configs: + for grouped_config in grouped_embedding_configs: num_features += grouped_config.num_features() - id_score_list_features_per_rank.append(num_features) - return id_score_list_features_per_rank + features_per_rank.append(num_features) + return features_per_rank -class TwSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeatures]): +class TwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]): """ Redistributes sparse features with an AlltoAll collective operation for table wise sharding. Args: pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. - id_list_features_per_rank (List[int]): number of id list features to send to - each rank. - id_score_list_features_per_rank (List[int]): number of id score list features to - send to each rank. - device (Optional[torch.device]): device on which buffers will be allocated. + features_per_rank (List[int]): number of features to send to each rank. """ def __init__( self, pg: dist.ProcessGroup, - id_list_features_per_rank: List[int], - id_score_list_features_per_rank: List[int], - device: Optional[torch.device] = None, - variable_batch_size: bool = False, + features_per_rank: List[int], ) -> None: super().__init__() - self._dist = SparseFeaturesAllToAll( + self._dist = KJTAllToAll( pg=pg, - id_list_features_per_rank=id_list_features_per_rank, - id_score_list_features_per_rank=id_score_list_features_per_rank, - device=device, - variable_batch_size=variable_batch_size, + splits=features_per_rank, ) def forward( self, - sparse_features: SparseFeatures, - ) -> Awaitable[Awaitable[SparseFeatures]]: + sparse_features: KeyedJaggedTensor, + ) -> Awaitable[Awaitable[KeyedJaggedTensor]]: """ Performs AlltoAll operation on sparse features. Args: - sparse_features (SparseFeatures): sparse features to redistribute. + sparse_features (KeyedJaggedTensor): sparse features to redistribute. Returns: - Awaitable[Awaitable[SparseFeatures]]: awaitable of awaitable of SparseFeatures. + Awaitable[Awaitable[KeyedJaggedTensor]]: awaitable of awaitable of KeyedJaggedTensor. """ return self._dist(sparse_features) @@ -324,29 +316,36 @@ class TwPooledEmbeddingDist( pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. 
dim_sum_per_rank (List[int]): number of features (sum of dimensions) of the embedding in each rank. + emb_dim_per_rank_per_feature (List[List[int]]): embedding dimension per rank per + feature, used for variable batch per feature. device (Optional[torch.device]): device on which buffers will be allocated. + callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]): + qcomm_codecs_registry (Optional[Dict[str, QuantizedCommCodecs]]): """ def __init__( self, pg: dist.ProcessGroup, dim_sum_per_rank: List[int], + emb_dim_per_rank_per_feature: List[List[int]], device: Optional[torch.device] = None, callbacks: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, ) -> None: super().__init__() - self._dist = PooledEmbeddingsAllToAll( - pg=pg, - dim_sum_per_rank=dim_sum_per_rank, - device=device, - callbacks=callbacks, - codecs=qcomm_codecs_registry.get( - CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None - ) + self._pg = pg + self._dim_sum_per_rank = dim_sum_per_rank + self._device = device + self._callbacks = callbacks + self._codecs: Optional[QuantizedCommCodecs] = ( + qcomm_codecs_registry.get(CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None) if qcomm_codecs_registry - else None, + else None ) + self._emb_dim_per_rank_per_feature = emb_dim_per_rank_per_feature + self._dist: Optional[ + Union[PooledEmbeddingsAllToAll, VariableBatchPooledEmbeddingsAllToAll] + ] = None def forward( self, @@ -358,21 +357,53 @@ def forward( Args: local_embs (torch.Tensor): tensor of values to distribute. + sharding_ctx (Optional[EmbeddingShardingContext]): shared context from + KJTAllToAll operation. Returns: Awaitable[torch.Tensor]: awaitable of pooled embeddings. """ + if self._dist is None: + self._create_output_dist_module(sharding_ctx) + if sharding_ctx is None: - return self._dist(local_embs) + return cast(PooledEmbeddingsAllToAll, self._dist)(local_embs) + elif sharding_ctx.variable_batch_per_feature: + return cast(VariableBatchPooledEmbeddingsAllToAll, self._dist)( + local_embs, + batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature, + batch_size_per_feature_pre_a2a=sharding_ctx.batch_size_per_feature_pre_a2a, + ) else: - return self._dist( - local_embs, batch_size_per_rank=sharding_ctx.batch_size_per_rank + return cast(PooledEmbeddingsAllToAll, self._dist)( + local_embs, + batch_size_per_rank=sharding_ctx.batch_size_per_rank, + ) + + def _create_output_dist_module( + self, sharding_ctx: Optional[EmbeddingShardingContext] = None + ) -> None: + if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature: + self._dist = VariableBatchPooledEmbeddingsAllToAll( + pg=self._pg, + emb_dim_per_rank_per_feature=self._emb_dim_per_rank_per_feature, + device=self._device, + callbacks=None, + codecs=self._codecs, + ) + else: + self._dist = PooledEmbeddingsAllToAll( + pg=self._pg, + dim_sum_per_rank=self._dim_sum_per_rank, + device=self._device, + callbacks=self._callbacks, + codecs=self._codecs, ) class TwPooledEmbeddingSharding( BaseTwEmbeddingSharding[ - EmbeddingShardingContext, SparseFeatures, torch.Tensor, torch.Tensor + EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor ] ): """ @@ -383,14 +414,11 @@ class TwPooledEmbeddingSharding( def create_input_dist( self, device: Optional[torch.device] = None, - ) -> BaseSparseFeaturesDist[SparseFeatures]: + ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: assert self._pg is not None return TwSparseFeaturesDist( self._pg, - 
self.id_list_features_per_rank(), - self.id_score_list_features_per_rank(), - device if device is not None else self._device, - self._variable_batch_size, + self.features_per_rank(), ) def create_lookup( @@ -401,7 +429,6 @@ def create_lookup( ) -> BaseEmbeddingLookup: return GroupedPooledEmbeddingsLookup( grouped_configs=self._grouped_embedding_configs, - grouped_score_configs=self._score_grouped_embedding_configs, pg=self._pg, device=device if device is not None else self._device, feature_processor=feature_processor, @@ -413,53 +440,51 @@ def create_output_dist( ) -> BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor]: assert self._pg is not None return TwPooledEmbeddingDist( - self._pg, - self._dim_sum_per_rank(), - device if device is not None else self._device, + pg=self._pg, + dim_sum_per_rank=self._dim_sum_per_rank(), + emb_dim_per_rank_per_feature=self._emb_dim_per_rank_per_feature(), + device=device if device is not None else self._device, qcomm_codecs_registry=self.qcomm_codecs_registry, ) -class InferTwSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeaturesList]): +class InferTwSparseFeaturesDist(BaseSparseFeaturesDist[InputDistOutputs]): """ Redistributes sparse features to all devices for inference. Args: - id_list_features_per_rank (List[int]): number of id list features to send - to each rank. - id_score_list_features_per_rank (List[int]): number of id score list features - to send to each rank. + features_per_rank (List[int]): number of features to send to each rank. world_size (int): number of devices in the topology. + fused_params (Dict[str, Any]): fused parameters of the model. """ def __init__( self, - id_list_features_per_rank: List[int], - id_score_list_features_per_rank: List[int], + features_per_rank: List[int], world_size: int, + device: Optional[torch.device] = None, ) -> None: super().__init__() - self._dist: SparseFeaturesOneToAll = SparseFeaturesOneToAll( - id_list_features_per_rank, - id_score_list_features_per_rank, - world_size, + self._dist = KJTOneToAll( + splits=features_per_rank, + world_size=world_size, + device=device, ) def forward( self, - sparse_features: SparseFeatures, - ) -> Awaitable[Awaitable[SparseFeaturesList]]: + sparse_features: KeyedJaggedTensor, + ) -> InputDistOutputs: """ Performs OnetoAll operation on sparse features. Args: - sparse_features (SparseFeatures): sparse features to redistribute. + sparse_features (KeyedJaggedTensor): sparse features to redistribute. Returns: - Awaitable[Awaitable[SparseFeatures]]: awaitable of awaitable of SparseFeatures. + Awaitable[Awaitable[KeyedJaggedTensor]]: awaitable of awaitable of KeyedJaggedTensor. """ - - return NoWait(self._dist.forward(sparse_features)) + return InputDistOutputs(features=self._dist.forward(sparse_features)) class InferTwPooledEmbeddingDist( @@ -485,7 +510,7 @@ def forward( self, local_embs: List[torch.Tensor], sharding_ctx: Optional[NullShardingContext] = None, - ) -> Awaitable[torch.Tensor]: + ) -> torch.Tensor: """ Performs AlltoOne operation on pooled embedding tensors. @@ -494,15 +519,15 @@ def forward( `len(local_embs) == world_size`. Returns: - Awaitable[torch.Tensor]: awaitable of merged pooled embedding tensor. + torch.Tensor: merged pooled embedding tensor. 
""" - return self._dist.forward(local_embs) + return self._dist(local_embs) class InferTwEmbeddingSharding( BaseTwEmbeddingSharding[ - NullShardingContext, SparseFeaturesList, List[torch.Tensor], torch.Tensor + NullShardingContext, InputDistOutputs, List[torch.Tensor], torch.Tensor ] ): """ @@ -510,12 +535,13 @@ class InferTwEmbeddingSharding( """ def create_input_dist( - self, device: Optional[torch.device] = None - ) -> BaseSparseFeaturesDist[SparseFeaturesList]: + self, + device: Optional[torch.device] = None, + ) -> BaseSparseFeaturesDist[InputDistOutputs]: return InferTwSparseFeaturesDist( - self.id_list_features_per_rank(), - self.id_score_list_features_per_rank(), - self._world_size, + features_per_rank=self.features_per_rank(), + world_size=self._world_size, + device=device, ) def create_lookup( @@ -523,12 +549,12 @@ def create_lookup( device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None, - ) -> BaseEmbeddingLookup[SparseFeaturesList, List[torch.Tensor]]: + ) -> BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]]: return InferGroupedPooledEmbeddingsLookup( grouped_configs_per_rank=self._grouped_embedding_configs_per_rank, - grouped_score_configs_per_rank=self._score_grouped_embedding_configs_per_rank, world_size=self._world_size, fused_params=fused_params, + device=device, ) def create_output_dist( diff --git a/torchrec/distributed/sharding/twcw_sharding.py b/torchrec/distributed/sharding/twcw_sharding.py index 96470c0ae..16a0e9d53 100644 --- a/torchrec/distributed/sharding/twcw_sharding.py +++ b/torchrec/distributed/sharding/twcw_sharding.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Dict, List, Optional import torch diff --git a/torchrec/distributed/sharding/twrw_sharding.py b/torchrec/distributed/sharding/twrw_sharding.py index 812f2c9b2..16df99c9f 100644 --- a/torchrec/distributed/sharding/twrw_sharding.py +++ b/torchrec/distributed/sharding/twrw_sharding.py @@ -5,16 +5,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import itertools import math -from typing import Any, cast, Dict, List, Optional, Tuple, TypeVar +from typing import Any, cast, Dict, List, Optional, Tuple, TypeVar, Union import torch import torch.distributed as dist -from torchrec.distributed.comm import get_local_size, intra_and_cross_node_pg +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.distributed_c10d import get_process_group_ranks +from torchrec.distributed.comm import ( + get_local_size, + intra_and_cross_node_pg, + intra_and_cross_node_pg_2D, +) from torchrec.distributed.dist_data import ( + KJTAllToAll, PooledEmbeddingsAllToAll, PooledEmbeddingsReduceScatter, + VariableBatchPooledEmbeddingsAllToAll, + VariableBatchPooledEmbeddingsReduceScatter, ) from torchrec.distributed.embedding_lookup import GroupedPooledEmbeddingsLookup from torchrec.distributed.embedding_sharding import ( @@ -26,14 +37,13 @@ EmbeddingShardingContext, EmbeddingShardingInfo, group_tables, - SparseFeaturesAllToAll, ) from torchrec.distributed.embedding_types import ( BaseGroupedFeatureProcessor, + DTensorMetadata, EmbeddingComputeKernel, GroupedEmbeddingConfig, ShardedEmbeddingTable, - SparseFeatures, ) from torchrec.distributed.types import ( Awaitable, @@ -41,8 +51,11 @@ QuantizedCommCodecs, ShardedTensorMetadata, ShardingEnv, + ShardingEnv2D, + ShardingType, ShardMetadata, ) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.streamable import Multistreamable C = TypeVar("C", bound=Multistreamable) @@ -63,16 +76,29 @@ def __init__( device: Optional[torch.device] = None, need_pos: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, - variable_batch_size: bool = False, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) self._env = env - self._pg: Optional[dist.ProcessGroup] = self._env.process_group + self._is_2D_parallel: bool = isinstance(env, ShardingEnv2D) + self._pg: Optional[dist.ProcessGroup] = ( + self._env.sharding_pg # pyre-ignore[16] + if self._is_2D_parallel + else self._env.process_group + ) self._world_size: int = self._env.world_size self._rank: int = self._env.rank self._device = device self._need_pos = need_pos - intra_pg, cross_pg = intra_and_cross_node_pg(device) + if self._is_2D_parallel: + intra_pg, cross_pg = intra_and_cross_node_pg_2D( + # pyre-fixme[6] + self._env, + device=device, + ) + else: + intra_pg, cross_pg = intra_and_cross_node_pg( + device, backend=dist.get_backend(self._pg) + ) self._intra_pg: Optional[dist.ProcessGroup] = intra_pg self._cross_pg: Optional[dist.ProcessGroup] = cross_pg self._local_size: int = ( @@ -80,39 +106,24 @@ def __init__( ) sharded_tables_per_rank = self._shard(sharding_infos) - self._grouped_embedding_configs_per_rank: List[ - List[GroupedEmbeddingConfig] - ] = [] - self._score_grouped_embedding_configs_per_rank: List[ - List[GroupedEmbeddingConfig] - ] = [] - self._grouped_embedding_configs_per_node: List[ - List[GroupedEmbeddingConfig] - ] = [] - self._score_grouped_embedding_configs_per_node: List[ - List[GroupedEmbeddingConfig] - ] = [] - ( - self._grouped_embedding_configs_per_rank, - self._score_grouped_embedding_configs_per_rank, - ) = group_tables(sharded_tables_per_rank) + self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = ( + [] + ) + self._grouped_embedding_configs_per_node: List[List[GroupedEmbeddingConfig]] = ( + [] + ) + self._grouped_embedding_configs_per_rank = group_tables(sharded_tables_per_rank) 
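# Quick illustration of the per-node grouping derived right after this call: under
# table-wise-row-wise sharding every rank in a node holds a shard of the same tables, so
# the first rank of each node (rank % local_size == 0) can stand in for the node's feature
# set. The config strings below are placeholders, not real grouped configs.
world_size, local_size = 8, 4
per_rank = [f"grouped_configs_rank{r}" for r in range(world_size)]
per_node = [per_rank[r] for r in range(world_size) if r % local_size == 0]
assert per_node == ["grouped_configs_rank0", "grouped_configs_rank4"]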
self._grouped_embedding_configs_per_node = [ self._grouped_embedding_configs_per_rank[rank] for rank in range(self._world_size) if rank % self._local_size == 0 ] - self._score_grouped_embedding_configs_per_node = [ - self._score_grouped_embedding_configs_per_rank[rank] - for rank in range(self._world_size) - if rank % self._local_size == 0 - ] self._has_feature_processor: bool = False - for group_config in self._score_grouped_embedding_configs_per_node[ + for group_config in self._grouped_embedding_configs_per_rank[ self._rank // self._local_size ]: if group_config.has_feature_processor: self._has_feature_processor = True - self._variable_batch_size = variable_batch_size def _shard( self, @@ -121,11 +132,23 @@ def _shard( world_size = self._world_size local_size = self._local_size tables_per_rank: List[List[ShardedEmbeddingTable]] = [ - [] for i in range(world_size) + [] for _ in range(world_size) ] + peer_group = ( + # pyre-ignore [6] + get_process_group_ranks(self._pg) + if self._is_2D_parallel + else None + ) for info in sharding_infos: - # pyre-ignore [16] - table_node = info.param_sharding.ranks[0] // local_size + # Under 2D parallelism we transform rank to the logical ordering in a regular parallelism scheme + rank = ( + # pyre-ignore [16] + peer_group.index(info.param_sharding.ranks[0]) + if peer_group is not None + else info.param_sharding.ranks[0] + ) + table_node = rank // local_size # pyre-fixme [16] shards = info.param_sharding.sharding_spec.shards @@ -140,6 +163,20 @@ def _shard( ), ) + dtensor_metadata = None + if self._env.output_dtensor: + dtensor_metadata = DTensorMetadata( + mesh=self._env.device_mesh, + placements=( + (Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),) + ), + size=( + info.embedding_config.num_embeddings, + info.embedding_config.embedding_dim, + ), + stride=info.param.stride(), + ) + for rank in range( table_node * local_size, (table_node + 1) * local_size, @@ -163,6 +200,7 @@ def _shard( ), local_metadata=shards[rank_idx], global_metadata=global_metadata, + dtensor_metadata=dtensor_metadata, weight_init_max=info.embedding_config.weight_init_max, weight_init_min=info.embedding_config.weight_init_min, fused_params=info.fused_params, @@ -173,26 +211,16 @@ def _shard( def embedding_dims(self) -> List[int]: embedding_dims = [] - for grouped_embedding_configs, score_grouped_embedding_configs in zip( - self._grouped_embedding_configs_per_node, - self._score_grouped_embedding_configs_per_node, - ): + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: for grouped_config in grouped_embedding_configs: embedding_dims.extend(grouped_config.embedding_dims()) - for grouped_config in score_grouped_embedding_configs: - embedding_dims.extend(grouped_config.embedding_dims()) return embedding_dims def embedding_names(self) -> List[str]: embedding_names = [] - for grouped_embedding_configs, score_grouped_embedding_configs in zip( - self._grouped_embedding_configs_per_node, - self._score_grouped_embedding_configs_per_node, - ): + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: for grouped_config in grouped_embedding_configs: embedding_names.extend(grouped_config.embedding_names()) - for grouped_config in score_grouped_embedding_configs: - embedding_names.extend(grouped_config.embedding_names()) return embedding_names def embedding_names_per_rank(self) -> List[List[str]]: @@ -203,52 +231,39 @@ def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: for grouped_config in 
self._grouped_embedding_configs_per_node: for config in grouped_config: embedding_shard_metadata.extend(config.embedding_shard_metadata()) - for grouped_config in self._score_grouped_embedding_configs_per_node: - for config in grouped_config: - embedding_shard_metadata.extend(config.embedding_shard_metadata()) return embedding_shard_metadata - def id_list_feature_names(self) -> List[str]: - id_list_feature_names = [] + def feature_names(self) -> List[str]: + feature_names = [] for grouped_config in self._grouped_embedding_configs_per_node: for config in grouped_config: - id_list_feature_names.extend(config.feature_names()) - return id_list_feature_names - - def id_score_list_feature_names(self) -> List[str]: - id_score_list_feature_names = [] - for grouped_config in self._score_grouped_embedding_configs_per_node: - for config in grouped_config: - id_score_list_feature_names.extend(config.feature_names()) - return id_score_list_feature_names + feature_names.extend(config.feature_names()) + return feature_names - def _get_id_list_features_hash_sizes(self) -> List[int]: - id_list_feature_hash_sizes: List[int] = [] + def _get_feature_hash_sizes(self) -> List[int]: + feature_hash_sizes: List[int] = [] for grouped_config in self._grouped_embedding_configs_per_node: for config in grouped_config: - id_list_feature_hash_sizes.extend(config.feature_hash_sizes()) - return id_list_feature_hash_sizes - - def _get_id_score_list_features_hash_sizes(self) -> List[int]: - id_score_list_feature_hash_sizes: List[int] = [] - for grouped_config in self._score_grouped_embedding_configs_per_node: - for config in grouped_config: - id_score_list_feature_hash_sizes.extend(config.feature_hash_sizes()) - return id_score_list_feature_hash_sizes + feature_hash_sizes.extend(config.feature_hash_sizes()) + return feature_hash_sizes def _dim_sum_per_node(self) -> List[int]: - dim_sum_per_rank = [] - for grouped_embedding_configs, score_grouped_embedding_configs in zip( - self._grouped_embedding_configs_per_node, - self._score_grouped_embedding_configs_per_node, - ): + dim_sum_per_node = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: dim_sum = 0 for grouped_config in grouped_embedding_configs: dim_sum += grouped_config.dim_sum() - for grouped_config in score_grouped_embedding_configs: - dim_sum += grouped_config.dim_sum() - dim_sum_per_rank.append(dim_sum) - return dim_sum_per_rank + dim_sum_per_node.append(dim_sum) + return dim_sum_per_node + + def _emb_dim_per_node_per_feature(self) -> List[List[int]]: + emb_dim_per_node_per_feature = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: + emb_dim_per_feature = [] + for grouped_config in grouped_embedding_configs: + emb_dim_per_feature += grouped_config.embedding_dims() + emb_dim_per_node_per_feature.append(emb_dim_per_feature) + return emb_dim_per_node_per_feature def _features_per_rank( self, group: List[List[GroupedEmbeddingConfig]] @@ -262,7 +277,7 @@ def _features_per_rank( return features_per_rank -class TwRwSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeatures]): +class TwRwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]): """ Bucketizes sparse features in TWRW fashion and then redistributes with an AlltoAll collective operation. 
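# Hedged sketch of the bucketization arithmetic TwRwSparseFeaturesDist applies below: rows
# are sharded inside a node, so ids map to one of `local_size` buckets with a block size of
# ceil(hash_size / local_size), and each id is rewritten relative to its owning shard.
# Numbers are invented for illustration.
import math

local_size = 4                                   # ranks per node
hash_size = 10                                   # rows in a toy embedding table
block_size = math.ceil(hash_size / local_size)   # -> 3
ids = [0, 2, 3, 7, 9]
buckets = [i // block_size for i in ids]         # owning local rank per id
local_ids = [i % block_size for i in ids]        # offset inside that shard
assert buckets == [0, 0, 1, 2, 3]
assert local_ids == [0, 2, 0, 1, 0]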
@@ -314,14 +329,11 @@ def __init__( self, pg: dist.ProcessGroup, local_size: int, - id_list_features_per_rank: List[int], - id_score_list_features_per_rank: List[int], - id_list_feature_hash_sizes: List[int], - id_score_list_feature_hash_sizes: List[int], + features_per_rank: List[int], + feature_hash_sizes: List[int], device: Optional[torch.device] = None, has_feature_processor: bool = False, need_pos: bool = False, - variable_batch_size: bool = False, ) -> None: super().__init__() assert pg.size() % local_size == 0, "currently group granularity must be node" @@ -329,108 +341,70 @@ def __init__( self._world_size: int = pg.size() self._local_size: int = local_size self._num_cross_nodes: int = self._world_size // self._local_size - id_list_feature_block_sizes = [ - math.ceil(hash_size / self._local_size) - for hash_size in id_list_feature_hash_sizes - ] - id_score_list_feature_block_sizes = [ - math.ceil(hash_size / self._local_size) - for hash_size in id_score_list_feature_hash_sizes + feature_block_sizes = [ + math.ceil(hash_size / self._local_size) for hash_size in feature_hash_sizes ] - self._id_list_sf_staggered_shuffle: List[int] = self._staggered_shuffle( - id_list_features_per_rank - ) - self._id_score_list_sf_staggered_shuffle: List[int] = self._staggered_shuffle( - id_score_list_features_per_rank - ) - self.register_buffer( - "_id_list_feature_block_sizes_tensor", - torch.tensor( - id_list_feature_block_sizes, - device=device, - dtype=torch.int32, - ), + self._sf_staggered_shuffle: List[int] = self._staggered_shuffle( + features_per_rank ) self.register_buffer( - "_id_score_list_feature_block_sizes_tensor", + "_feature_block_sizes_tensor", torch.tensor( - id_score_list_feature_block_sizes, + feature_block_sizes, device=device, dtype=torch.int32, ), ) self.register_buffer( - "_id_list_sf_staggered_shuffle_tensor", + "_sf_staggered_shuffle_tensor", torch.tensor( - self._id_list_sf_staggered_shuffle, + self._sf_staggered_shuffle, device=device, dtype=torch.int32, ), ) - self.register_buffer( - "_id_score_list_sf_staggered_shuffle_tensor", - torch.tensor( - self._id_score_list_sf_staggered_shuffle, - device=device, - dtype=torch.int32, - ), - ) - self._dist = SparseFeaturesAllToAll( + self._dist = KJTAllToAll( pg=pg, - id_list_features_per_rank=id_list_features_per_rank, - id_score_list_features_per_rank=id_score_list_features_per_rank, - device=device, + splits=features_per_rank, stagger=self._num_cross_nodes, - variable_batch_size=variable_batch_size, ) self._has_feature_processor = has_feature_processor self._need_pos = need_pos def forward( self, - sparse_features: SparseFeatures, - ) -> Awaitable[Awaitable[SparseFeatures]]: + sparse_features: KeyedJaggedTensor, + ) -> Awaitable[Awaitable[KeyedJaggedTensor]]: """ Bucketizes sparse feature values into local world size number of buckets, performs staggered shuffle on the sparse features, and then performs AlltoAll operation. Args: - sparse_features (SparseFeatures): sparse features to bucketize and + sparse_features (KeyedJaggedTensor): sparse features to bucketize and redistribute. Returns: - Awaitable[SparseFeatures]: awaitable of SparseFeatures. + Awaitable[KeyedJaggedTensor]: awaitable of KeyedJaggedTensor. 
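# A standalone sketch that mirrors the `_staggered_shuffle` helper used by
# TwRwSparseFeaturesDist above: it builds the feature permutation applied after
# bucketization so that the staggered KJTAllToAll delivers each node's features
# contiguously, bucket by bucket. Toy copy for illustration only.
import itertools
from typing import List


def staggered_shuffle(
    features_per_rank: List[int], world_size: int, local_size: int
) -> List[int]:
    nodes = world_size // local_size
    # Every rank on a node receives the same feature set, so sampling one rank
    # per node yields the per-node feature counts.
    features_per_node = [features_per_rank[node * local_size] for node in range(nodes)]
    node_offsets = [0] + list(itertools.accumulate(features_per_node))
    num_features = node_offsets[-1]
    return [
        bucket * num_features + feature
        for node in range(nodes)
        for bucket in range(local_size)
        for feature in range(node_offsets[node], node_offsets[node + 1])
    ]


# 3 features on 2 hosts x 2 devices: node 0 owns features {0, 1}, node 1 owns {2}.
assert staggered_shuffle([2, 2, 1, 1], world_size=4, local_size=2) == [0, 1, 3, 4, 2, 5]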
""" - bucketized_sparse_features = SparseFeatures( - id_list_features=bucketize_kjt_before_all2all( - sparse_features.id_list_features, - num_buckets=self._local_size, - block_sizes=self._id_list_feature_block_sizes_tensor, - output_permute=False, - bucketize_pos=self._has_feature_processor, - )[0].permute( - self._id_list_sf_staggered_shuffle, - self._id_list_sf_staggered_shuffle_tensor, - ) - if sparse_features.id_list_features is not None - else None, - id_score_list_features=bucketize_kjt_before_all2all( - sparse_features.id_score_list_features, - num_buckets=self._local_size, - block_sizes=self._id_score_list_feature_block_sizes_tensor, - output_permute=False, - bucketize_pos=self._need_pos, - )[0].permute( - self._id_score_list_sf_staggered_shuffle, - self._id_score_list_sf_staggered_shuffle_tensor, - ) - if sparse_features.id_score_list_features is not None - else None, + bucketized_features = bucketize_kjt_before_all2all( + sparse_features, + num_buckets=self._local_size, + block_sizes=self._feature_block_sizes_tensor, + output_permute=False, + bucketize_pos=( + self._has_feature_processor + if sparse_features.weights_or_none() is None + else self._need_pos + ), + )[0].permute( + self._sf_staggered_shuffle, + self._sf_staggered_shuffle_tensor, ) - return self._dist(bucketized_sparse_features) + + return self._dist(bucketized_features) def _staggered_shuffle(self, features_per_rank: List[int]) -> List[int]: """ @@ -468,7 +442,9 @@ class TwRwPooledEmbeddingDist( communication. dim_sum_per_node (List[int]): number of features (sum of dimensions) of the embedding for each host. + emb_dim_per_node_per_feature (List[List[int]]): device (Optional[torch.device]): device on which buffers will be allocated. + qcomm_codecs_registry (Optional[Dict[str, QuantizedCommCodecs]]): """ def __init__( @@ -477,6 +453,7 @@ def __init__( cross_pg: dist.ProcessGroup, intra_pg: dist.ProcessGroup, dim_sum_per_node: List[int], + emb_dim_per_node_per_feature: List[List[int]], device: Optional[torch.device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, ) -> None: @@ -484,25 +461,33 @@ def __init__( self._rank = rank self._intra_pg: dist.ProcessGroup = intra_pg self._cross_pg: dist.ProcessGroup = cross_pg - - self._intra_dist = PooledEmbeddingsReduceScatter( - intra_pg, - codecs=qcomm_codecs_registry.get( + self._dim_sum_per_node = dim_sum_per_node + self._emb_dim_per_node_per_feature = emb_dim_per_node_per_feature + self._device = device + self._intra_codecs: Optional[QuantizedCommCodecs] = ( + qcomm_codecs_registry.get( CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER.name, None ) if qcomm_codecs_registry - else None, + else None ) - self._cross_dist = PooledEmbeddingsAllToAll( - cross_pg, - dim_sum_per_node, - device, - codecs=qcomm_codecs_registry.get( - CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None - ) + self._cross_codecs: Optional[QuantizedCommCodecs] = ( + qcomm_codecs_registry.get(CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None) if qcomm_codecs_registry - else None, + else None ) + self._intra_dist: Optional[ + Union[ + PooledEmbeddingsReduceScatter, + VariableBatchPooledEmbeddingsReduceScatter, + ] + ] = None + self._cross_dist: Optional[ + Union[ + PooledEmbeddingsAllToAll, + VariableBatchPooledEmbeddingsAllToAll, + ] + ] = None def forward( self, @@ -519,7 +504,36 @@ def forward( Returns: Awaitable[torch.Tensor]: awaitable of pooled embeddings tensor. 
""" - if sharding_ctx is not None and len(set(sharding_ctx.batch_size_per_rank)) > 1: + if self._intra_dist is None or self._cross_dist is None: + self._create_output_dist_modules(sharding_ctx) + local_rank = self._rank % self._intra_pg.size() + current_node = self._rank // self._intra_pg.size() + if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature: + ( + batch_size_per_rank_per_feature_by_cross_group, + batch_size_per_feature_sum_by_cross_group, + ) = self._preprocess_batch_size_per_rank_per_feature( + self._intra_pg.size(), + self._cross_pg.size(), + sharding_ctx.batch_size_per_rank_per_feature, + ) + rs_result = cast( + VariableBatchPooledEmbeddingsReduceScatter, self._intra_dist + )( + local_embs, + batch_size_per_rank_per_feature=batch_size_per_feature_sum_by_cross_group, + embedding_dims=self._emb_dim_per_node_per_feature[current_node], + ).wait() + return cast(VariableBatchPooledEmbeddingsAllToAll, self._cross_dist)( + rs_result, + batch_size_per_rank_per_feature=batch_size_per_rank_per_feature_by_cross_group[ + local_rank + ], + batch_size_per_feature_pre_a2a=sharding_ctx.batch_size_per_feature_pre_a2a, + ) + elif ( + sharding_ctx is not None and len(set(sharding_ctx.batch_size_per_rank)) > 1 + ): # preprocess batch_size_per_rank ( batch_size_per_rank_by_cross_group, @@ -530,13 +544,17 @@ def forward( sharding_ctx.batch_size_per_rank, ) # Perform ReduceScatterV within one host - lengths = batch_size_sum_by_cross_group - local_rank = self._rank % self._intra_pg.size() - batch_size_per_rank = batch_size_per_rank_by_cross_group[local_rank] - rs_result = self._intra_dist(local_embs, input_splits=lengths).wait() - return self._cross_dist(rs_result, batch_size_per_rank=batch_size_per_rank) + rs_result = cast(PooledEmbeddingsReduceScatter, self._intra_dist)( + local_embs, input_splits=batch_size_sum_by_cross_group + ).wait() + return cast(PooledEmbeddingsAllToAll, self._cross_dist)( + rs_result, + batch_size_per_rank=batch_size_per_rank_by_cross_group[local_rank], + ) else: - return self._cross_dist(self._intra_dist(local_embs).wait()) + return cast(PooledEmbeddingsAllToAll, self._cross_dist)( + cast(PooledEmbeddingsReduceScatter, self._intra_dist)(local_embs).wait() + ) def _preprocess_batch_size_per_rank( self, local_size: int, nodes: int, batch_size_per_rank: List[int] @@ -560,10 +578,74 @@ def _preprocess_batch_size_per_rank( return batch_size_per_rank_by_cross_group, batch_size_sum_by_cross_group + def _preprocess_batch_size_per_rank_per_feature( + self, + local_size: int, + nodes: int, + batch_size_per_rank_per_feature_stagger: List[List[int]], + ) -> Tuple[List[List[List[int]]], List[List[int]]]: + """ + Reorders `batch_size_per_rank_per_feature_stagger` so it's aligned with + reordered features after AlltoAll. 
+ """ + if not batch_size_per_rank_per_feature_stagger: + return [[]] * local_size, [] + batch_size_per_rank_per_feature_by_cross_group: List[List[List[int]]] = [] + batch_size_per_feature_sum_by_cross_group: List[List[int]] = [] + for local_rank in range(local_size): + batch_size_by_node_per_rank_per_feature: List[List[int]] = [] + batch_size_per_feature_sum = [0] * len( + batch_size_per_rank_per_feature_stagger[0] + ) + for node in range(nodes): + batch_size = batch_size_per_rank_per_feature_stagger[ + local_rank * nodes + node + ] + batch_size_by_node_per_rank_per_feature.append(batch_size) + batch_size_per_feature_sum = [ + sum(x) for x in zip(batch_size_per_feature_sum, batch_size) + ] + batch_size_per_rank_per_feature_by_cross_group.append( + batch_size_by_node_per_rank_per_feature + ) + batch_size_per_feature_sum_by_cross_group.append(batch_size_per_feature_sum) + + return ( + batch_size_per_rank_per_feature_by_cross_group, + batch_size_per_feature_sum_by_cross_group, + ) + + def _create_output_dist_modules( + self, sharding_ctx: Optional[EmbeddingShardingContext] = None + ) -> None: + if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature: + self._intra_dist = VariableBatchPooledEmbeddingsReduceScatter( + pg=self._intra_pg, + codecs=self._intra_codecs, + ) + self._cross_dist = VariableBatchPooledEmbeddingsAllToAll( + pg=self._cross_pg, + emb_dim_per_rank_per_feature=self._emb_dim_per_node_per_feature, + device=self._device, + callbacks=None, # don't pass permute callback, handle in LazyAwaitable + codecs=self._cross_codecs, + ) + else: + self._intra_dist = PooledEmbeddingsReduceScatter( + pg=self._intra_pg, + codecs=self._intra_codecs, + ) + self._cross_dist = PooledEmbeddingsAllToAll( + pg=self._cross_pg, + dim_sum_per_rank=self._dim_sum_per_node, + device=self._device, + codecs=self._cross_codecs, + ) + class TwRwPooledEmbeddingSharding( BaseTwRwEmbeddingSharding[ - EmbeddingShardingContext, SparseFeatures, torch.Tensor, torch.Tensor + EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor ] ): """ @@ -572,28 +654,21 @@ class TwRwPooledEmbeddingSharding( def create_input_dist( self, device: Optional[torch.device] = None - ) -> BaseSparseFeaturesDist[SparseFeatures]: - id_list_features_per_rank = self._features_per_rank( + ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: + features_per_rank = self._features_per_rank( self._grouped_embedding_configs_per_rank ) - id_score_list_features_per_rank = self._features_per_rank( - self._score_grouped_embedding_configs_per_rank - ) - id_list_feature_hash_sizes = self._get_id_list_features_hash_sizes() - id_score_list_feature_hash_sizes = self._get_id_score_list_features_hash_sizes() + feature_hash_sizes = self._get_feature_hash_sizes() assert self._pg is not None assert self._intra_pg is not None return TwRwSparseFeaturesDist( pg=self._pg, local_size=self._intra_pg.size(), - id_list_features_per_rank=id_list_features_per_rank, - id_score_list_features_per_rank=id_score_list_features_per_rank, - id_list_feature_hash_sizes=id_list_feature_hash_sizes, - id_score_list_feature_hash_sizes=id_score_list_feature_hash_sizes, + features_per_rank=features_per_rank, + feature_hash_sizes=feature_hash_sizes, device=device if device is not None else self._device, has_feature_processor=self._has_feature_processor, need_pos=self._need_pos, - variable_batch_size=self._variable_batch_size, ) def create_lookup( @@ -604,12 +679,10 @@ def create_lookup( ) -> BaseEmbeddingLookup: return GroupedPooledEmbeddingsLookup( 
grouped_configs=self._grouped_embedding_configs_per_rank[self._rank], - grouped_score_configs=self._score_grouped_embedding_configs_per_rank[ - self._rank - ], pg=self._pg, device=device if device is not None else self._device, feature_processor=feature_processor, + sharding_type=ShardingType.TABLE_ROW_WISE, ) def create_output_dist( @@ -621,6 +694,7 @@ def create_output_dist( cross_pg=cast(dist.ProcessGroup, self._cross_pg), intra_pg=cast(dist.ProcessGroup, self._intra_pg), dim_sum_per_node=self._dim_sum_per_node(), + emb_dim_per_node_per_feature=self._emb_dim_per_node_per_feature(), device=device if device is not None else self._device, qcomm_codecs_registry=self.qcomm_codecs_registry, ) diff --git a/torchrec/distributed/sharding/vb_cw_sharding.py b/torchrec/distributed/sharding/vb_cw_sharding.py deleted file mode 100644 index 30b8e0a9f..000000000 --- a/torchrec/distributed/sharding/vb_cw_sharding.py +++ /dev/null @@ -1,94 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Any, Dict, Optional - -import torch -from fbgemm_gpu.permute_pooled_embedding_modules_split import ( - PermutePooledEmbeddingsSplit, -) -from torchrec.distributed.embedding_lookup import GroupedPooledEmbeddingsLookup -from torchrec.distributed.embedding_sharding import ( - BaseEmbeddingDist, - BaseEmbeddingLookup, - BaseSparseFeaturesDist, -) -from torchrec.distributed.embedding_types import ( - BaseGroupedFeatureProcessor, - SparseFeatures, -) -from torchrec.distributed.sharding.cw_sharding import BaseCwEmbeddingSharding -from torchrec.distributed.sharding.vb_sharding import VariableBatchShardingContext -from torchrec.distributed.sharding.vb_tw_sharding import ( - VariableBatchTwPooledEmbeddingDist, - VariableBatchTwSparseFeaturesDist, -) - - -class VariableBatchCwPooledEmbeddingSharding( - BaseCwEmbeddingSharding[ - VariableBatchShardingContext, SparseFeatures, torch.Tensor, torch.Tensor - ] -): - """ - Shards embedding bags column-wise, i.e.. a given embedding table is partitioned - along its columns and placed on specified ranks. - - Supports variable batch size. - """ - - def create_input_dist( - self, - device: Optional[torch.device] = None, - ) -> BaseSparseFeaturesDist[SparseFeatures]: - return VariableBatchTwSparseFeaturesDist( - # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got - # `Optional[ProcessGroup]`. 
- self._pg, - self.id_list_features_per_rank(), - self.id_score_list_features_per_rank(), - device if device is not None else self._device, - ) - - def create_lookup( - self, - device: Optional[torch.device] = None, - fused_params: Optional[Dict[str, Any]] = None, - feature_processor: Optional[BaseGroupedFeatureProcessor] = None, - ) -> BaseEmbeddingLookup: - return GroupedPooledEmbeddingsLookup( - grouped_configs=self._grouped_embedding_configs, - grouped_score_configs=self._score_grouped_embedding_configs, - pg=self._pg, - device=device if device is not None else self._device, - feature_processor=feature_processor, - ) - - def create_output_dist( - self, - device: Optional[torch.device] = None, - ) -> BaseEmbeddingDist[VariableBatchShardingContext, torch.Tensor, torch.Tensor]: - callbacks = None - if self._permute_embeddings and self._embedding_order != list( - range(len(self._embedding_order)) - ): - assert len(self._embedding_order) == len(self._embedding_dims) - embedding_permute_op = PermutePooledEmbeddingsSplit( - self._embedding_dims, - self._embedding_order, - ).to(device=device) - callbacks = [embedding_permute_op] - return VariableBatchTwPooledEmbeddingDist( - # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got - # `Optional[ProcessGroup]`. - self._pg, - self._dim_sum_per_rank(), - device if device is not None else self._device, - # pyre-ignore [6] - callbacks, - qcomm_codecs_registry=self.qcomm_codecs_registry, - ) diff --git a/torchrec/distributed/sharding/vb_rw_sharding.py b/torchrec/distributed/sharding/vb_rw_sharding.py deleted file mode 100644 index d437b779d..000000000 --- a/torchrec/distributed/sharding/vb_rw_sharding.py +++ /dev/null @@ -1,258 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -#!/usr/bin/env python3 - -from typing import Any, Dict, List, Optional - -import torch -import torch.distributed as dist -from torchrec.distributed.dist_data import PooledEmbeddingsReduceScatter -from torchrec.distributed.embedding_lookup import GroupedPooledEmbeddingsLookup -from torchrec.distributed.embedding_sharding import ( - BaseEmbeddingDist, - BaseEmbeddingLookup, - BaseSparseFeaturesDist, - bucketize_kjt_before_all2all, - SparseFeaturesAllToAll, -) -from torchrec.distributed.embedding_types import ( - BaseGroupedFeatureProcessor, - SparseFeatures, -) -from torchrec.distributed.sharding.rw_sharding import BaseRwEmbeddingSharding -from torchrec.distributed.sharding.vb_sharding import VariableBatchShardingContext -from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs - - -torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") -torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") - - -class VariableBatchRwSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeatures]): - """ - Bucketizes sparse features in RW fashion and then redistributes with an AlltoAll - collective operation. - - Supports variable batch size. - - Args: - pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. - intra_pg (dist.ProcessGroup): ProcessGroup within single host group for AlltoAll - communication. - num_id_list_features (int): total number of id list features. - num_id_score_list_features (int): total number of id score list features - id_list_feature_hash_sizes (List[int]): hash sizes of id list features. 
- id_score_list_feature_hash_sizes (List[int]): hash sizes of id score list features. - device (Optional[torch.device]): device on which buffers will be allocated. - has_feature_processor (bool): existence of feature processor (ie. position - weighted features). - """ - - def __init__( - self, - pg: dist.ProcessGroup, - num_id_list_features: int, - num_id_score_list_features: int, - id_list_feature_hash_sizes: List[int], - id_score_list_feature_hash_sizes: List[int], - device: Optional[torch.device] = None, - has_feature_processor: bool = False, - ) -> None: - super().__init__() - self._world_size: int = pg.size() - self._num_id_list_features = num_id_list_features - self._num_id_score_list_features = num_id_score_list_features - id_list_feature_block_sizes = [ - (hash_size + self._world_size - 1) // self._world_size - for hash_size in id_list_feature_hash_sizes - ] - id_score_list_feature_block_sizes = [ - (hash_size + self._world_size - 1) // self._world_size - for hash_size in id_score_list_feature_hash_sizes - ] - self.register_buffer( - "_id_list_feature_block_sizes_tensor", - torch.tensor( - id_list_feature_block_sizes, - device=device, - dtype=torch.int32, - ), - ) - self.register_buffer( - "_id_score_list_feature_block_sizes_tensor", - torch.tensor( - id_score_list_feature_block_sizes, - device=device, - dtype=torch.int32, - ), - ) - self._dist = SparseFeaturesAllToAll( - pg=pg, - id_list_features_per_rank=self._world_size * [self._num_id_list_features], - id_score_list_features_per_rank=self._world_size - * [self._num_id_score_list_features], - device=device, - variable_batch_size=True, - ) - self._has_feature_processor = has_feature_processor - self.unbucketize_permute_tensor: Optional[torch.Tensor] = None - - def forward( - self, - sparse_features: SparseFeatures, - ) -> Awaitable[Awaitable[SparseFeatures]]: - """ - Bucketizes sparse feature values into world size number of buckets, and then - performs AlltoAll operation. - - Args: - sparse_features (SparseFeatures): sparse features to bucketize and - redistribute. - - Returns: - Awaitable[SparseFeatures]: awaitable of SparseFeatures. 
- """ - - if self._num_id_list_features > 0: - assert sparse_features.id_list_features is not None - ( - id_list_features, - self.unbucketize_permute_tensor, - ) = bucketize_kjt_before_all2all( - sparse_features.id_list_features, - num_buckets=self._world_size, - block_sizes=self._id_list_feature_block_sizes_tensor, - output_permute=False, - bucketize_pos=self._has_feature_processor, - ) - else: - id_list_features = None - - if self._num_id_score_list_features > 0: - assert sparse_features.id_score_list_features is not None - id_score_list_features, _ = bucketize_kjt_before_all2all( - sparse_features.id_score_list_features, - num_buckets=self._world_size, - block_sizes=self._id_score_list_feature_block_sizes_tensor, - output_permute=False, - bucketize_pos=False, - ) - else: - id_score_list_features = None - - bucketized_sparse_features = SparseFeatures( - id_list_features=id_list_features, - id_score_list_features=id_score_list_features, - ) - return self._dist(bucketized_sparse_features) - - -class VariableBatchRwEmbeddingDistAwaitable(Awaitable[torch.Tensor]): - def __init__(self, awaitable: Awaitable[torch.Tensor], batch_size: int) -> None: - super().__init__() - self._awaitable = awaitable - self._batch_size = batch_size - - def _wait_impl(self) -> torch.Tensor: - embedding = self._awaitable.wait() - - return embedding - - -class VariableBatchRwPooledEmbeddingDist( - BaseEmbeddingDist[VariableBatchShardingContext, torch.Tensor, torch.Tensor] -): - def __init__( - self, - pg: dist.ProcessGroup, - qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, - ) -> None: - super().__init__() - self._workers: int = pg.size() - self._rank: int = pg.rank() - self._dist = PooledEmbeddingsReduceScatter( - pg, - codecs=qcomm_codecs_registry.get( - CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER.name, None - ) - if qcomm_codecs_registry - else None, - ) - - def forward( - self, - local_embs: torch.Tensor, - sharding_ctx: Optional[VariableBatchShardingContext] = None, - ) -> Awaitable[torch.Tensor]: - assert sharding_ctx is not None - batch_size_per_rank = sharding_ctx.batch_size_per_rank - batch_size = batch_size_per_rank[self._rank] - - awaitable_tensor = self._dist( - local_embs.view(sum(batch_size_per_rank), -1), - input_splits=batch_size_per_rank, - ) - return VariableBatchRwEmbeddingDistAwaitable(awaitable_tensor, batch_size) - - -class VariableBatchRwPooledEmbeddingSharding( - BaseRwEmbeddingSharding[ - VariableBatchShardingContext, SparseFeatures, torch.Tensor, torch.Tensor - ] -): - """ - Shards pooled embeddings row-wise, i.e.. a given embedding table is evenly - distributed by rows and table slices are placed on all ranks. - - Supports variable batch size. - """ - - def create_input_dist( - self, - device: Optional[torch.device] = None, - ) -> BaseSparseFeaturesDist[SparseFeatures]: - num_id_list_features = self._get_id_list_features_num() - num_id_score_list_features = self._get_id_score_list_features_num() - id_list_feature_hash_sizes = self._get_id_list_features_hash_sizes() - id_score_list_feature_hash_sizes = self._get_id_score_list_features_hash_sizes() - return VariableBatchRwSparseFeaturesDist( - # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got - # `Optional[ProcessGroup]`. 
- pg=self._pg, - num_id_list_features=num_id_list_features, - num_id_score_list_features=num_id_score_list_features, - id_list_feature_hash_sizes=id_list_feature_hash_sizes, - id_score_list_feature_hash_sizes=id_score_list_feature_hash_sizes, - device=self._device, - has_feature_processor=self._has_feature_processor, - ) - - def create_lookup( - self, - device: Optional[torch.device] = None, - fused_params: Optional[Dict[str, Any]] = None, - feature_processor: Optional[BaseGroupedFeatureProcessor] = None, - ) -> BaseEmbeddingLookup: - return GroupedPooledEmbeddingsLookup( - grouped_configs=self._grouped_embedding_configs, - grouped_score_configs=self._score_grouped_embedding_configs, - pg=self._pg, - device=device if device is not None else self._device, - feature_processor=feature_processor, - ) - - def create_output_dist( - self, - device: Optional[torch.device] = None, - ) -> BaseEmbeddingDist[VariableBatchShardingContext, torch.Tensor, torch.Tensor]: - return VariableBatchRwPooledEmbeddingDist( - # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got - # `Optional[ProcessGroup]`. - self._pg, - qcomm_codecs_registry=self.qcomm_codecs_registry, - ) diff --git a/torchrec/distributed/sharding/vb_sharding.py b/torchrec/distributed/sharding/vb_sharding.py deleted file mode 100644 index c738fb92f..000000000 --- a/torchrec/distributed/sharding/vb_sharding.py +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass, field -from typing import List, Optional - -import torch -from torchrec.streamable import Multistreamable - - -@dataclass -class VariableBatchShardingContext(Multistreamable): - """ - For variable batch size case, we need pass `batch_size_per_rank` to - PooledEmbeddingsAllToAll and it can be retrieved from SparseFeaturesAllToAll. - - Attributes: - batch_size_per_rank (List[int]): stores batch size in each rank. - batch_size_per_rank_tensor (Optional[torch.Tensor]): batch_size_per_rank stored in - tensor. - """ - - batch_size_per_rank: List[int] = field(default_factory=list) - batch_size_per_rank_tensor: Optional[torch.Tensor] = None - - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - if self.batch_size_per_rank_tensor is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. - self.batch_size_per_rank_tensor.record_stream(stream) diff --git a/torchrec/distributed/sharding/vb_tw_sharding.py b/torchrec/distributed/sharding/vb_tw_sharding.py deleted file mode 100644 index 986e9e972..000000000 --- a/torchrec/distributed/sharding/vb_tw_sharding.py +++ /dev/null @@ -1,164 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from typing import Any, Callable, Dict, List, Optional - -import torch -import torch.distributed as dist -from torchrec.distributed.dist_data import PooledEmbeddingsAllToAll -from torchrec.distributed.embedding_lookup import GroupedPooledEmbeddingsLookup -from torchrec.distributed.embedding_sharding import ( - BaseEmbeddingDist, - BaseEmbeddingLookup, - BaseSparseFeaturesDist, - SparseFeaturesAllToAll, -) -from torchrec.distributed.embedding_types import ( - BaseGroupedFeatureProcessor, - SparseFeatures, -) -from torchrec.distributed.sharding.tw_sharding import BaseTwEmbeddingSharding -from torchrec.distributed.sharding.vb_sharding import VariableBatchShardingContext -from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs - - -class VariableBatchTwSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeatures]): - """ - Redistributes sparse features in TW fashion with an AlltoAll collective - operation. - - Supports variable batch size. - - Args: - pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. - id_list_features_per_rank (List[int]): number of id list features to send to - each rank. - id_score_list_features_per_rank (List[int]): number of id score list features to - send to each rank - device (Optional[torch.device]): device on which buffers will be allocated. - """ - - def __init__( - self, - pg: dist.ProcessGroup, - id_list_features_per_rank: List[int], - id_score_list_features_per_rank: List[int], - device: Optional[torch.device] = None, - ) -> None: - super().__init__() - self._dist = SparseFeaturesAllToAll( - pg=pg, - id_list_features_per_rank=id_list_features_per_rank, - id_score_list_features_per_rank=id_score_list_features_per_rank, - device=device, - variable_batch_size=True, - ) - - def forward( - self, - sparse_features: SparseFeatures, - ) -> Awaitable[Awaitable[SparseFeatures]]: - """ - Performs AlltoAll operation on sparse features. - - Args: - sparse_features (SparseFeatures): sparse features to redistribute. - - Returns: - Awaitable[Awaitable[SparseFeatures]]: awaitable of awaitable of SparseFeatures. - """ - - return self._dist(sparse_features) - - -class VariableBatchTwPooledEmbeddingDist( - BaseEmbeddingDist[VariableBatchShardingContext, torch.Tensor, torch.Tensor] -): - def __init__( - self, - pg: dist.ProcessGroup, - dim_sum_per_rank: List[int], - device: Optional[torch.device] = None, - callbacks: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None, - qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, - ) -> None: - super().__init__() - self._dist = PooledEmbeddingsAllToAll( - pg, - dim_sum_per_rank, - device, - callbacks, - codecs=qcomm_codecs_registry.get( - CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None - ) - if qcomm_codecs_registry - else None, - ) - - def forward( - self, - local_embs: torch.Tensor, - sharding_ctx: Optional[VariableBatchShardingContext] = None, - ) -> Awaitable[torch.Tensor]: - assert sharding_ctx is not None - # do not remove the keyword for quantized communication hook injection. - return self._dist( - local_embs, batch_size_per_rank=sharding_ctx.batch_size_per_rank - ) - - -class VariableBatchTwPooledEmbeddingSharding( - BaseTwEmbeddingSharding[ - VariableBatchShardingContext, SparseFeatures, torch.Tensor, torch.Tensor - ] -): - """ - Shards pooled embeddings table-wise, i.e.. a given embedding table is entirely placed - on a selected rank. - - Supports variable batch size. 
- """ - - def create_input_dist( - self, - device: Optional[torch.device] = None, - ) -> BaseSparseFeaturesDist[SparseFeatures]: - return VariableBatchTwSparseFeaturesDist( - # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got - # `Optional[ProcessGroup]`. - self._pg, - self.id_list_features_per_rank(), - self.id_score_list_features_per_rank(), - device if device is not None else self._device, - ) - - def create_lookup( - self, - device: Optional[torch.device] = None, - fused_params: Optional[Dict[str, Any]] = None, - feature_processor: Optional[BaseGroupedFeatureProcessor] = None, - ) -> BaseEmbeddingLookup: - return GroupedPooledEmbeddingsLookup( - grouped_configs=self._grouped_embedding_configs, - grouped_score_configs=self._score_grouped_embedding_configs, - pg=self._pg, - device=device if device is not None else self._device, - feature_processor=feature_processor, - ) - - def create_output_dist( - self, - device: Optional[torch.device] = None, - ) -> BaseEmbeddingDist[VariableBatchShardingContext, torch.Tensor, torch.Tensor]: - return VariableBatchTwPooledEmbeddingDist( - # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got - # `Optional[ProcessGroup]`. - self._pg, - self._dim_sum_per_rank(), - device if device is not None else self._device, - qcomm_codecs_registry=self.qcomm_codecs_registry, - ) diff --git a/torchrec/distributed/sharding/vb_twrw_sharding.py b/torchrec/distributed/sharding/vb_twrw_sharding.py deleted file mode 100644 index 7c45f02ef..000000000 --- a/torchrec/distributed/sharding/vb_twrw_sharding.py +++ /dev/null @@ -1,382 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -#!/usr/bin/env python3 - -import itertools -import math -from typing import Any, cast, Dict, List, Optional, Tuple - -import torch -import torch.distributed as dist -from torchrec.distributed.dist_data import ( - PooledEmbeddingsAllToAll, - PooledEmbeddingsReduceScatter, -) -from torchrec.distributed.embedding_lookup import GroupedPooledEmbeddingsLookup -from torchrec.distributed.embedding_sharding import ( - BaseEmbeddingDist, - BaseEmbeddingLookup, - BaseSparseFeaturesDist, - bucketize_kjt_before_all2all, - SparseFeaturesAllToAll, -) -from torchrec.distributed.embedding_types import ( - BaseGroupedFeatureProcessor, - SparseFeatures, -) -from torchrec.distributed.sharding.twrw_sharding import BaseTwRwEmbeddingSharding -from torchrec.distributed.sharding.vb_sharding import VariableBatchShardingContext -from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs - -torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") -torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") - - -class VariableBatchTwRwSparseFeaturesDist(BaseSparseFeaturesDist[SparseFeatures]): - """ - Bucketizes sparse features in TWRW fashion and then redistributes with an AlltoAll - collective operation. - - Supports variable batch size. - - Args: - pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. - intra_pg (dist.ProcessGroup): ProcessGroup within single host group for AlltoAll - communication. - id_list_features_per_rank (List[int]): number of id list features to send to - each rank. 
- id_score_list_features_per_rank (List[int]): number of id score list features to - send to each rank - id_list_feature_hash_sizes (List[int]): hash sizes of id list features. - id_score_list_feature_hash_sizes (List[int]): hash sizes of id score list - features. - device (Optional[torch.device]): device on which buffers will be allocated. - has_feature_processor (bool): existence of feature processor (ie. position - weighted features). - - Example:: - - 3 features - 2 hosts with 2 devices each - - Bucketize each feature into 2 buckets - Staggered shuffle with feature splits [2, 1] - AlltoAll operation - - NOTE: result of staggered shuffle and AlltoAll operation look the same after - reordering in AlltoAll - - Result: - host 0 device 0: - feature 0 bucket 0 - feature 1 bucket 0 - - host 0 device 1: - feature 0 bucket 1 - feature 1 bucket 1 - - host 1 device 0: - feature 2 bucket 0 - - host 1 device 1: - feature 2 bucket 1 - """ - - def __init__( - self, - pg: dist.ProcessGroup, - intra_pg: dist.ProcessGroup, - id_list_features_per_rank: List[int], - id_score_list_features_per_rank: List[int], - id_list_feature_hash_sizes: List[int], - id_score_list_feature_hash_sizes: List[int], - device: Optional[torch.device] = None, - has_feature_processor: bool = False, - ) -> None: - super().__init__() - assert ( - pg.size() % intra_pg.size() == 0 - ), "currently group granularity must be node" - - self._world_size: int = pg.size() - self._local_size: int = intra_pg.size() - self._num_cross_nodes: int = self._world_size // self._local_size - id_list_feature_block_sizes = [ - math.ceil(hash_size / self._local_size) - for hash_size in id_list_feature_hash_sizes - ] - id_score_list_feature_block_sizes = [ - math.ceil(hash_size / self._local_size) - for hash_size in id_score_list_feature_hash_sizes - ] - - self._id_list_sf_staggered_shuffle: List[int] = self._staggered_shuffle( - id_list_features_per_rank - ) - self._id_score_list_sf_staggered_shuffle: List[int] = self._staggered_shuffle( - id_score_list_features_per_rank - ) - self.register_buffer( - "_id_list_feature_block_sizes_tensor", - torch.tensor( - id_list_feature_block_sizes, - device=device, - dtype=torch.int32, - ), - ) - self.register_buffer( - "_id_score_list_feature_block_sizes_tensor", - torch.tensor( - id_score_list_feature_block_sizes, - device=device, - dtype=torch.int32, - ), - ) - self.register_buffer( - "_id_list_sf_staggered_shuffle_tensor", - torch.tensor( - self._id_list_sf_staggered_shuffle, - device=device, - dtype=torch.int32, - ), - ) - self.register_buffer( - "_id_score_list_sf_staggered_shuffle_tensor", - torch.tensor( - self._id_score_list_sf_staggered_shuffle, - device=device, - dtype=torch.int32, - ), - ) - self._dist = SparseFeaturesAllToAll( - pg=pg, - id_list_features_per_rank=id_list_features_per_rank, - id_score_list_features_per_rank=id_score_list_features_per_rank, - device=device, - stagger=self._num_cross_nodes, - variable_batch_size=True, - ) - self._has_feature_processor = has_feature_processor - - def forward( - self, - sparse_features: SparseFeatures, - ) -> Awaitable[Awaitable[SparseFeatures]]: - """ - Bucketizes sparse feature values into local world size number of buckets, - performs staggered shuffle on the sparse features, and then performs AlltoAll - operation. - - Args: - sparse_features (SparseFeatures): sparse features to bucketize and - redistribute. - - Returns: - Awaitable[SparseFeatures]: awaitable of SparseFeatures. 
- """ - - bucketized_sparse_features = SparseFeatures( - id_list_features=bucketize_kjt_before_all2all( - sparse_features.id_list_features, - num_buckets=self._local_size, - block_sizes=self._id_list_feature_block_sizes_tensor, - output_permute=False, - bucketize_pos=self._has_feature_processor, - )[0].permute( - self._id_list_sf_staggered_shuffle, - self._id_list_sf_staggered_shuffle_tensor, - ) - if sparse_features.id_list_features is not None - else None, - id_score_list_features=bucketize_kjt_before_all2all( - sparse_features.id_score_list_features, - num_buckets=self._local_size, - block_sizes=self._id_score_list_feature_block_sizes_tensor, - output_permute=False, - bucketize_pos=False, - )[0].permute( - self._id_score_list_sf_staggered_shuffle, - self._id_score_list_sf_staggered_shuffle_tensor, - ) - if sparse_features.id_score_list_features is not None - else None, - ) - return self._dist(bucketized_sparse_features) - - def _staggered_shuffle(self, features_per_rank: List[int]) -> List[int]: - """ - Reorders sparse data such that data is in contiguous blocks and correctly - ordered for global TWRW layout. - """ - - nodes = self._world_size // self._local_size - features_per_node = [ - features_per_rank[node * self._local_size] for node in range(nodes) - ] - node_offsets = [0] + list(itertools.accumulate(features_per_node)) - num_features = node_offsets[-1] - - return [ - bucket * num_features + feature - for node in range(nodes) - for bucket in range(self._local_size) - for feature in range(node_offsets[node], node_offsets[node + 1]) - ] - - -class VariableBatchTwRwPooledEmbeddingDist( - BaseEmbeddingDist[VariableBatchShardingContext, torch.Tensor, torch.Tensor] -): - def __init__( - self, - rank: int, - cross_pg: dist.ProcessGroup, - intra_pg: dist.ProcessGroup, - dim_sum_per_node: List[int], - device: Optional[torch.device] = None, - qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, - ) -> None: - super().__init__() - self._rank = rank - self._intra_pg: dist.ProcessGroup = intra_pg - self._cross_pg: dist.ProcessGroup = cross_pg - self._device: Optional[torch.device] = device - self._intra_dist = PooledEmbeddingsReduceScatter( - intra_pg, - codecs=qcomm_codecs_registry.get( - CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER.name, None - ) - if qcomm_codecs_registry - else None, - ) - self._cross_dist = PooledEmbeddingsAllToAll( - cross_pg, - dim_sum_per_node, - device, - codecs=qcomm_codecs_registry.get( - CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None - ) - if qcomm_codecs_registry - else None, - ) - - def forward( - self, - local_embs: torch.Tensor, - sharding_ctx: Optional[VariableBatchShardingContext] = None, - ) -> Awaitable[torch.Tensor]: - assert sharding_ctx is not None - # preprocess batch_size_per_rank - ( - batch_size_per_rank_by_cross_group, - batch_size_sum_by_cross_group, - ) = self._preprocess_batch_size_per_rank( - self._intra_pg.size(), - self._cross_pg.size(), - sharding_ctx.batch_size_per_rank, - ) - - # Perform ReduceScatterV within one host - lengths = batch_size_sum_by_cross_group - rs_result = self._intra_dist( - local_embs.view(sum(lengths), -1), input_splits=lengths - ).wait() - - local_rank = self._rank % self._intra_pg.size() - - return self._cross_dist( - rs_result, - batch_size_per_rank=batch_size_per_rank_by_cross_group[local_rank], - ) - - def _preprocess_batch_size_per_rank( - self, local_size: int, nodes: int, batch_size_per_rank: List[int] - ) -> Tuple[List[List[int]], List[int]]: - """ - Reorders `batch_size_per_rank` so it's aligned 
with reordered features after - AlltoAll. - """ - batch_size_per_rank_by_cross_group: List[List[int]] = [] - batch_size_sum_by_cross_group: List[int] = [] - for local_rank in range(local_size): - batch_size_per_rank_: List[int] = [] - batch_size_sum = 0 - for node in range(nodes): - batch_size_per_rank_.append( - batch_size_per_rank[local_rank + node * local_size] - ) - batch_size_sum += batch_size_per_rank[local_rank + node * local_size] - batch_size_per_rank_by_cross_group.append(batch_size_per_rank_) - batch_size_sum_by_cross_group.append(batch_size_sum) - - return batch_size_per_rank_by_cross_group, batch_size_sum_by_cross_group - - -class VariableBatchTwRwPooledEmbeddingSharding( - BaseTwRwEmbeddingSharding[ - VariableBatchShardingContext, SparseFeatures, torch.Tensor, torch.Tensor - ] -): - """ - Shards embedding bags table-wise then row-wise. - - Supports variable batch size. - """ - - def create_input_dist( - self, device: Optional[torch.device] = None - ) -> BaseSparseFeaturesDist[SparseFeatures]: - id_list_features_per_rank = self._features_per_rank( - self._grouped_embedding_configs_per_rank - ) - id_score_list_features_per_rank = self._features_per_rank( - self._score_grouped_embedding_configs_per_rank - ) - id_list_feature_hash_sizes = self._get_id_list_features_hash_sizes() - id_score_list_feature_hash_sizes = self._get_id_score_list_features_hash_sizes() - return VariableBatchTwRwSparseFeaturesDist( - # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got - # `Optional[ProcessGroup]`. - pg=self._pg, - intra_pg=cast(dist.ProcessGroup, self._intra_pg), - id_list_features_per_rank=id_list_features_per_rank, - id_score_list_features_per_rank=id_score_list_features_per_rank, - id_list_feature_hash_sizes=id_list_feature_hash_sizes, - id_score_list_feature_hash_sizes=id_score_list_feature_hash_sizes, - device=device if device is not None else self._device, - has_feature_processor=self._has_feature_processor, - ) - - def create_lookup( - self, - device: Optional[torch.device] = None, - fused_params: Optional[Dict[str, Any]] = None, - feature_processor: Optional[BaseGroupedFeatureProcessor] = None, - ) -> BaseEmbeddingLookup: - return GroupedPooledEmbeddingsLookup( - grouped_configs=self._grouped_embedding_configs_per_rank[self._rank], - grouped_score_configs=self._score_grouped_embedding_configs_per_rank[ - self._rank - ], - pg=self._pg, - device=device if device is not None else self._device, - feature_processor=feature_processor, - ) - - def create_output_dist( - self, - device: Optional[torch.device] = None, - ) -> BaseEmbeddingDist[VariableBatchShardingContext, torch.Tensor, torch.Tensor]: - return VariableBatchTwRwPooledEmbeddingDist( - rank=self._rank, - cross_pg=cast(dist.ProcessGroup, self._cross_pg), - intra_pg=cast(dist.ProcessGroup, self._intra_pg), - dim_sum_per_node=self._dim_sum_per_node(), - device=device if device is not None else self._device, - qcomm_codecs_registry=self.qcomm_codecs_registry, - ) diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py index 68d2fd728..27b011300 100644 --- a/torchrec/distributed/sharding_plan.py +++ b/torchrec/distributed/sharding_plan.py @@ -5,10 +5,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + #!/usr/bin/env python3 import math -from typing import Callable, cast, Dict, List, Optional, Tuple, Type +import warnings +from typing import Callable, cast, Dict, List, Optional, Tuple, Type, Union import torch from torch import distributed as dist, nn @@ -16,29 +19,49 @@ from torchrec.distributed.embedding import EmbeddingCollectionSharder from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.fp_embeddingbag import ( + FeatureProcessedEmbeddingBagCollectionSharder, +) from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder -from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder +from torchrec.distributed.mc_embedding import ManagedCollisionEmbeddingCollectionSharder +from torchrec.distributed.mc_embeddingbag import ( + ManagedCollisionEmbeddingBagCollectionSharder, +) +from torchrec.distributed.mc_modules import InferManagedCollisionCollectionSharder +from torchrec.distributed.planner.constants import MIN_CW_DIM +from torchrec.distributed.quant_embedding import ( + QuantEmbeddingCollectionSharder, + QuantManagedCollisionEmbeddingCollectionSharder, +) from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder from torchrec.distributed.types import ( + EmbeddingModuleShardingPlan, EnumerableShardingSpec, ModuleSharder, - ModuleShardingPlan, ParameterSharding, ShardingType, ShardMetadata, ) from torchrec.distributed.utils import none_throws -MIN_CW_DIM: int = 128 - def get_default_sharders() -> List[ModuleSharder[nn.Module]]: return [ cast(ModuleSharder[nn.Module], EmbeddingBagCollectionSharder()), + cast(ModuleSharder[nn.Module], FeatureProcessedEmbeddingBagCollectionSharder()), cast(ModuleSharder[nn.Module], EmbeddingCollectionSharder()), cast(ModuleSharder[nn.Module], FusedEmbeddingBagCollectionSharder()), cast(ModuleSharder[nn.Module], QuantEmbeddingBagCollectionSharder()), cast(ModuleSharder[nn.Module], QuantEmbeddingCollectionSharder()), + cast(ModuleSharder[nn.Module], ManagedCollisionEmbeddingBagCollectionSharder()), + cast(ModuleSharder[nn.Module], ManagedCollisionEmbeddingCollectionSharder()), + cast( + ModuleSharder[nn.Module], + QuantManagedCollisionEmbeddingCollectionSharder( + QuantEmbeddingCollectionSharder(), + InferManagedCollisionCollectionSharder(), + ), + ), ] @@ -52,17 +75,27 @@ def placement( local_size: int, ) -> str: param_device = compute_device - if compute_device == "cuda": - param_device = torch.device("cuda", rank % local_size) + if compute_device in {"cuda", "mtia"}: + param_device = torch.device(compute_device, rank % local_size) return f"rank:{rank}/{param_device}" +# TODO: Consolidate placement and placement_helper into one function. +def placement_helper(device_type: str, index: int = 0, rank: int = 0) -> str: + if device_type == "cpu": + return f"rank:0/{device_type}" # cpu only use rank 0 + + result = f"rank:{rank}/{device_type}:{index}" + return result + + def calculate_shard_sizes_and_offsets( tensor: torch.Tensor, world_size: int, local_world_size: int, sharding_type: str, col_wise_shard_dim: Optional[int] = None, + device_memory_sizes: Optional[List[int]] = None, ) -> Tuple[List[List[int]], List[List[int]]]: """ Calculates sizes and offsets for tensor sharded according to provided sharding type. 
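# Torch-free sketch of the placement strings produced by `placement` and the new
# `placement_helper` above; these are standalone copies for illustration, not
# imports from torchrec.
def placement(compute_device: str, rank: int, local_size: int) -> str:
    device = compute_device
    if compute_device in {"cuda", "mtia"}:
        # torch.device(compute_device, rank % local_size) renders as "cuda:<i>"
        device = f"{compute_device}:{rank % local_size}"
    return f"rank:{rank}/{device}"


def placement_helper(device_type: str, index: int = 0, rank: int = 0) -> str:
    if device_type == "cpu":
        return f"rank:0/{device_type}"  # CPU shards always resolve to rank 0
    return f"rank:{rank}/{device_type}:{index}"


assert placement("cuda", rank=5, local_size=4) == "rank:5/cuda:1"
assert placement_helper("cpu", index=3, rank=2) == "rank:0/cpu"
assert placement_helper("cuda", index=1, rank=2) == "rank:2/cuda:1"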
@@ -88,7 +121,13 @@ def calculate_shard_sizes_and_offsets( elif sharding_type == ShardingType.TABLE_WISE.value: return [[rows, columns]], [[0, 0]] elif sharding_type == ShardingType.ROW_WISE.value: - return _calculate_rw_shard_sizes_and_offsets(rows, world_size, columns) + return ( + _calculate_rw_shard_sizes_and_offsets(rows, world_size, columns) + if not device_memory_sizes + else _calculate_uneven_rw_shard_sizes_and_offsets( + rows, world_size, columns, device_memory_sizes + ) + ) elif sharding_type == ShardingType.TABLE_ROW_WISE.value: return _calculate_rw_shard_sizes_and_offsets(rows, local_world_size, columns) elif ( @@ -96,19 +135,47 @@ def calculate_shard_sizes_and_offsets( or sharding_type == ShardingType.TABLE_COLUMN_WISE.value ): return _calculate_cw_shard_sizes_and_offsets(columns, rows, col_wise_shard_dim) + elif sharding_type == ShardingType.GRID_SHARD.value: + return _calculate_grid_shard_sizes_and_offsets( + rows, local_world_size, columns, col_wise_shard_dim + ) raise ValueError( f"Unrecognized or unsupported sharding type provided: {sharding_type}" ) +def _calculate_grid_shard_sizes_and_offsets( + hash_size: int, + num_device: int, + columns: int, + col_wise_shard_dim: Optional[int] = None, +) -> Tuple[List[List[int]], List[List[int]]]: + """ + Similar to row-wise case, but also splits columns into blocks of size `col_wise_shard_dim`. + """ + row_shard_sizes, row_shard_offsets = _calculate_rw_shard_sizes_and_offsets( + hash_size, num_device, columns + ) + block_size = _get_block_size_for_cw_shard(columns, col_wise_shard_dim) + num_col_wise_nodes, _residual = divmod(columns, block_size) + shard_sizes: List[List[int]] = [] + shard_offsets: List[List[int]] = [] + + for node in range(num_col_wise_nodes): + for row_shard_size, row_shard_offset in zip(row_shard_sizes, row_shard_offsets): + shard_sizes.append([row_shard_size[0], block_size]) + shard_offsets.append([row_shard_offset[0], block_size * node]) + return shard_sizes, shard_offsets + + def _calculate_rw_shard_sizes_and_offsets( hash_size: int, num_devices: int, columns: int ) -> Tuple[List[List[int]], List[List[int]]]: """ - Sets prefix of shard_sizes to be ceil(hash_size/num_devices). + Sets prefix of shard_sizes to be `math.ceil(hash_size/num_devices)`. - For example if hash_size = 10, num_devices = 3, we will allocate the rows as 3,3,3,1 + For example if hash_size = 10, num_devices = 4, we will allocate the rows as 3,3,3,1 (rather than 3,3,2,2). This is due to implementation in RW sharding that sets block_size_lists to be ceil. The balanced way is harder to support on GPU. 
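# A minimal sketch of the ceil-based row-wise split described in the docstring
# above (illustration only, not the TorchRec function): prefix shards get
# ceil(hash_size / num_devices) rows and the tail gets whatever remains,
# possibly zero.
import math
from typing import List


def rw_shard_sizes(hash_size: int, num_devices: int, columns: int) -> List[List[int]]:
    block = math.ceil(hash_size / num_devices)
    sizes: List[List[int]] = []
    remaining = hash_size
    for _ in range(num_devices):
        rows = min(block, remaining)
        sizes.append([rows, columns])
        remaining -= rows
    return sizes


# hash_size=10, num_devices=4 -> rows split as 3, 3, 3, 1 (not 3, 3, 2, 2).
assert rw_shard_sizes(10, 4, 16) == [[3, 16], [3, 16], [3, 16], [1, 16]]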
@@ -139,19 +206,87 @@ def _calculate_rw_shard_sizes_and_offsets( return shard_sizes, shard_offsets +def _calculate_uneven_rw_shard_sizes_and_offsets( + hash_size: int, num_devices: int, columns: int, device_memory_sizes: List[int] +) -> Tuple[List[List[int]], List[List[int]]]: + assert num_devices == len(device_memory_sizes), "must provide all the memory size" + total_size = sum(device_memory_sizes) + shard_sizes: List[List[int]] = [] + last_rank = num_devices - 1 + + processed_total_rows = 0 + + for rank in range(num_devices): + if rank < last_rank: + local_row: int = int(hash_size * (device_memory_sizes[rank] / total_size)) + processed_total_rows += local_row + elif rank == last_rank: + local_row: int = hash_size - processed_total_rows + else: + local_row: int = 0 + shard_sizes.append([local_row, columns]) + shard_offsets = [[0, 0]] + + for i in range(num_devices - 1): + shard_offsets.append([shard_sizes[i][0] + shard_offsets[i][0], 0]) + + return shard_sizes, shard_offsets + + +def _find_base_dim(lower_bound: int, dim: int) -> int: + for i in range(lower_bound, dim): + if dim % i == 0 and i % 4 == 0: + return i + return dim + + +def _get_block_size_for_cw_shard( + columns: int, column_wise_shard_dim: Optional[int] +) -> int: + block_size: int = min( + ( + _find_base_dim(column_wise_shard_dim, columns) + if column_wise_shard_dim + else _find_base_dim(MIN_CW_DIM, columns) + ), + columns, + ) + + if columns % block_size != 0: + warnings.warn( + f"Dim of {columns} cannot be evenly divided with column wise shard" + "dim {column_wise_shard_dim}, overriding block_size to embedding_dim={columns}", + UserWarning, + ) + block_size = columns + return block_size + + def _calculate_cw_shard_sizes_and_offsets( - hash_size: int, + columns: int, rows: int, col_wise_shard_dim: Optional[int] = None, ) -> Tuple[List[List[int]], List[List[int]]]: block_size: int = min( - col_wise_shard_dim if col_wise_shard_dim else MIN_CW_DIM, hash_size + ( + _find_base_dim(col_wise_shard_dim, columns) + if col_wise_shard_dim + else _find_base_dim(MIN_CW_DIM, columns) + ), + columns, ) - num_col_wise_shards, residual = divmod(hash_size, block_size) - shard_sizes: List[List[int]] = [[rows, block_size]] * (num_col_wise_shards - 1) - shard_sizes.append([rows, block_size + residual]) + if columns % block_size != 0: + warnings.warn( + f"Dim of {columns} cannot be evenly divided with column wise shard" + "dim {col_wise_shard_dim}, overriding block_size to embedding_dim={columns}", + UserWarning, + ) + block_size = columns + + num_col_wise_shards, _residual = divmod(columns, block_size) + shard_sizes: List[List[int]] = [[rows, block_size]] * num_col_wise_shards shard_offsets: List[List[int]] = [ [0, block_size * rank] for rank in range(num_col_wise_shards) ] @@ -165,7 +300,10 @@ def _get_parameter_size_offsets( world_size: int, col_wise_shard_dim: Optional[int] = None, ) -> List[Tuple[List[int], List[int]]]: - (shard_sizes, shard_offsets,) = calculate_shard_sizes_and_offsets( + ( + shard_sizes, + shard_offsets, + ) = calculate_shard_sizes_and_offsets( tensor=none_throws(param), world_size=world_size, local_world_size=local_size, @@ -182,22 +320,36 @@ def _get_compute_kernel( device_type: str, ) -> str: # TODO add placement support for compute_kernel - compute_kernels = sharder.compute_kernels(sharding_type, device_type) + compute_kernels = [EmbeddingComputeKernel.DENSE.value] + if sharding_type != ShardingType.DATA_PARALLEL.value: + compute_kernels += [ + EmbeddingComputeKernel.FUSED.value, + ] + if device_type in {"cuda"}: + 
compute_kernels += [ + EmbeddingComputeKernel.FUSED_UVM.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + ] - if sharding_type == ShardingType.DATA_PARALLEL or not hasattr( - param, "_optimizer_class" - ): + if sharding_type == ShardingType.DATA_PARALLEL.value: if EmbeddingComputeKernel.DENSE.value in compute_kernels: return EmbeddingComputeKernel.DENSE.value elif EmbeddingComputeKernel.QUANT.value in compute_kernels: return EmbeddingComputeKernel.QUANT.value else: - if EmbeddingComputeKernel.FUSED.value in compute_kernels: + if ( + hasattr(param, "_in_backward_optimizers") + and EmbeddingComputeKernel.FUSED.value in compute_kernels + ): return EmbeddingComputeKernel.FUSED.value + elif EmbeddingComputeKernel.DENSE.value in compute_kernels: + return EmbeddingComputeKernel.DENSE.value elif EmbeddingComputeKernel.QUANT.value in compute_kernels: return EmbeddingComputeKernel.QUANT.value - raise ValueError(f"Could not find compute kernel for sharding_type={sharding_type}") + raise ValueError( + f"Could not find compute kernel for sharding_type={sharding_type} in {compute_kernels}" + ) def _get_parameter_sharding( @@ -207,32 +359,54 @@ def _get_parameter_sharding( local_size: int, device_type: str, sharder: ModuleSharder[nn.Module], + placements: Optional[List[str]] = None, + compute_kernel: Optional[str] = None, ) -> ParameterSharding: return ParameterSharding( - sharding_spec=None - if sharding_type == ShardingType.DATA_PARALLEL.value - else EnumerableShardingSpec( - [ - ShardMetadata( - shard_sizes=size, - shard_offsets=offset, - placement=placement( - device_type, - none_throws(rank), - none_throws(local_size), - ), - ) - for (size, offset, rank) in (size_offset_ranks) - ] + sharding_spec=( + None + if sharding_type == ShardingType.DATA_PARALLEL.value + else EnumerableShardingSpec( + [ + ShardMetadata( + shard_sizes=size, + shard_offsets=offset, + placement=( + placement( + device_type, + none_throws(rank), + none_throws(local_size), + ) + if not device_placement + else device_placement + ), + ) + for (size, offset, rank), device_placement in zip( + size_offset_ranks, + placements if placements else [None] * len(size_offset_ranks), + ) + ] + ) ), sharding_type=sharding_type, - compute_kernel=_get_compute_kernel(sharder, param, sharding_type, device_type), + compute_kernel=( + compute_kernel + if compute_kernel + else _get_compute_kernel(sharder, param, sharding_type, device_type) + ), ranks=[rank for (_, _, rank) in size_offset_ranks], ) ParameterShardingGenerator = Callable[ - [nn.Parameter, int, int, str, ModuleSharder[nn.Module]], ParameterSharding + [ + nn.Parameter, + int, + int, + str, + ModuleSharder[nn.Module], + ], + ParameterSharding, ] @@ -284,12 +458,16 @@ def _parameter_sharding_generator( def table_wise( rank: int, + device: Optional[str] = None, + compute_kernel: Optional[str] = None, ) -> ParameterShardingGenerator: """ Returns a generator of ParameterShardingPlan for `ShardingType::TABLE_WISE` for construct_module_sharding_plan. 
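# Sketch of the column-wise block-size selection introduced above: pick the
# smallest divisor of the embedding dim that is >= the requested shard dim and a
# multiple of 4, falling back to the full dim (a single shard) when the dim
# cannot be divided evenly. Standalone copy for illustration; 128 matches the
# MIN_CW_DIM constant this file previously defined inline.
MIN_CW_DIM = 128


def find_base_dim(lower_bound: int, dim: int) -> int:
    for i in range(lower_bound, dim):
        if dim % i == 0 and i % 4 == 0:
            return i
    return dim


def cw_block_size(columns: int, col_wise_shard_dim: int = MIN_CW_DIM) -> int:
    block = min(find_base_dim(col_wise_shard_dim, columns), columns)
    # Mirrors the warning path above: if columns are not evenly divisible,
    # fall back to a single shard spanning the whole embedding dim.
    return block if columns % block == 0 else columns


assert cw_block_size(512) == 128       # 4 column-wise shards of 128
assert cw_block_size(384, 100) == 128  # first divisor >= 100 that is a multiple of 4
assert cw_block_size(200) == 200       # no divisor >= 128 -> keep one shard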
Args: rank (int): rank to place table when doing table wise + device (Optional[str]): device to place table when doing table_wise sharding + compute_kernel (Optional[str]): embedding compute kernel to use for the table Example:: @@ -326,15 +504,22 @@ def _parameter_sharding_generator( local_size, device_type, sharder, + placements=([placement_helper(device, rank, rank)] if device else None), + compute_kernel=compute_kernel, ) return _parameter_sharding_generator -def row_wise() -> ParameterShardingGenerator: +def row_wise( + sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None +) -> ParameterShardingGenerator: """ Returns a generator of ParameterShardingPlan for `ShardingType::ROW_WISE` for construct_module_sharding_plan. + Args: + sizes_placement (Optional[Tuple[List[int], str]]): Only use it in inference for uneven shardinglist of tuples of (sizes, placement); sizes is the row size list + Example:: ebc = EmbeddingBagCollection(...) @@ -342,10 +527,17 @@ def row_wise() -> ParameterShardingGenerator: ebc, { "table_1": row_wise(), + "table_2": row_wise([10, 5, 0, 3], "cpu") }, ) """ + if sizes_placement is not None and isinstance(sizes_placement[1], list): + assert len(sizes_placement[0]) == len( + sizes_placement[1] + ), "sizes_placement and device per placement (in case of sharding " + "across HBM and CPU host) must have the same length" + def _parameter_sharding_generator( param: nn.Parameter, local_size: int, @@ -353,16 +545,50 @@ def _parameter_sharding_generator( device_type: str, sharder: ModuleSharder[nn.Module], ) -> ParameterSharding: - size_and_offsets = _get_parameter_size_offsets( - param, - ShardingType.ROW_WISE, - local_size, - world_size, - ) - assert len(size_and_offsets) <= world_size - size_offset_ranks = [] - for (size, offset), rank in zip(size_and_offsets, range(world_size)): - size_offset_ranks.append((size, offset, rank)) + if sizes_placement is None: + size_and_offsets = _get_parameter_size_offsets( + param, + ShardingType.ROW_WISE, + local_size, + world_size, + ) + assert len(size_and_offsets) <= world_size + size_offset_ranks = [] + for (size, offset), rank in zip(size_and_offsets, range(world_size)): + size_offset_ranks.append((size, offset, rank)) + else: + size_offset_ranks = [] + sizes = sizes_placement[0] + (rows, cols) = param.shape + cur_offset = 0 + prev_offset = 0 + for rank, size in enumerate(sizes): + per_rank_row = size + cur_offset += per_rank_row + cur_offset = min(cur_offset, rows) + per_rank_row = cur_offset - prev_offset + size_offset_ranks.append(([per_rank_row, cols], [prev_offset, 0], rank)) + prev_offset = cur_offset + + if cur_offset < rows: + raise ValueError( + f"Cannot fit tensor of {rows, cols} into sizes_ranks_placements = {sizes_placement}" + ) + + index: int = 0 + placements: List[str] = [] + if sizes_placement is not None: + device_type = "" + for i in range(len(sizes_placement[0])): + if isinstance(sizes_placement[1], list): + device_type = sizes_placement[1][i] + placements.append(placement_helper(device_type, index, i)) + else: + device_type = str(sizes_placement[1]) + placements.append(placement_helper(device_type, index, i)) + + if device_type == "cuda": + index += 1 return _get_parameter_sharding( param, @@ -371,6 +597,10 @@ def _parameter_sharding_generator( local_size, device_type, sharder, + placements=placements if sizes_placement else None, + compute_kernel=( + EmbeddingComputeKernel.QUANT.value if sizes_placement else None + ), ) return _parameter_sharding_generator @@ -485,6 +715,58 @@ def 
_parameter_sharding_generator( return _parameter_sharding_generator +def grid_shard( + host_indexes: List[int], +) -> ParameterShardingGenerator: + """ + Returns a generator of ParameterShardingPlan for `ShardingType::GRID_SHARD` for construct_module_sharding_plan. + + Args: + host_indexes (List[int]): index of hosts (nodes) to do row wise + + Example:: + + ebc = EmbeddingBagCollection(...) + plan = construct_module_sharding_plan( + ebc, + { + "table_4": grid_shard(host_indexes=[1,2]), + }, + ) + """ + + def _parameter_sharding_generator( + param: nn.Parameter, + local_size: int, + world_size: int, + device_type: str, + sharder: ModuleSharder[nn.Module], + ) -> ParameterSharding: + size_and_offsets = _get_parameter_size_offsets( + param, + ShardingType.GRID_SHARD, + local_size, + world_size, + ) + size_offset_ranks = [] + for host_count, host_index in enumerate(host_indexes): + for rank in range(local_size): + (size, offset) = size_and_offsets[host_count * local_size + rank] + rank_offset = host_index * local_size + size_offset_ranks.append((size, offset, rank_offset + rank)) + + return _get_parameter_sharding( + param, + ShardingType.GRID_SHARD.value, + size_offset_ranks, + local_size, + device_type, + sharder, + ) + + return _parameter_sharding_generator + + def apply_to_all( module: nn.Module, parameter_sharding_generator: ParameterShardingGenerator, @@ -503,8 +785,7 @@ def apply_to_all( ) """ if sharder is None: - # pyre-ignore - sharder = get_module_to_default_sharders.get(type(module), None) + sharder = get_module_to_default_sharders().get(type(module), None) else: assert isinstance( module, sharder.module_type @@ -526,10 +807,10 @@ def construct_module_sharding_plan( sharder: Optional[ModuleSharder[nn.Module]] = None, local_size: Optional[int] = None, world_size: Optional[int] = None, - device_type: str = "cuda", -) -> ModuleShardingPlan: + device_type: Optional[str] = None, +) -> EmbeddingModuleShardingPlan: """ - Helper function to create module sharding plans (ModuleShardingPlan) for an module + Helper function to create module sharding plans (EmbeddingModuleShardingPlan) for an module Args: module (nn.Module): module to create plan for. 
per_param_sharding: Dict[str, Callable[[nn.Parameter, int, int, str], ParameterSharding]]: A mapping of parameter names to a generator function @@ -554,6 +835,8 @@ def construct_module_sharding_plan( }, ) """ + if device_type is None: + device_type = "cuda" if torch.cuda.is_available() else "cpu" if sharder is None: sharder = get_module_to_default_sharders().get(type(module), None) assert ( @@ -562,16 +845,17 @@ def construct_module_sharding_plan( assert isinstance( module, sharder.module_type - ), f"Incorrect sharder for module type {type(module)}" + ), f"Incorrect sharder {type(sharder)} for module type {type(module)}" shardable_parameters = sharder.shardable_parameters(module) - assert ( - shardable_parameters.keys() == per_param_sharding.keys() - ), "per_param_sharding_config doesn't match the shardable parameters of the module" + assert shardable_parameters.keys() == per_param_sharding.keys(), ( + "per_param_sharding_config doesn't match the shardable parameters of the module," + f"got {list(shardable_parameters.keys())} != {list(per_param_sharding.keys())}" + ) local_size = local_size or get_local_size() world_size = world_size or dist.get_world_size() - per_parameter_sharding: ModuleShardingPlan = {} + per_parameter_sharding = EmbeddingModuleShardingPlan() for table_name, sharding_plan_generator in per_param_sharding.items(): param = shardable_parameters[table_name] per_parameter_sharding[table_name] = sharding_plan_generator( diff --git a/torchrec/distributed/shards_wrapper.py b/torchrec/distributed/shards_wrapper.py new file mode 100644 index 000000000..e7fc1e52b --- /dev/null +++ b/torchrec/distributed/shards_wrapper.py @@ -0,0 +1,389 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +# COPY of the code from torch.distributed._tensor._shards_wrapper - for package compat + +from typing import Any, List, Tuple + +import torch +from torch.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + MetadataIndex, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import ( + TensorWriteData, + WriteItem, + WriteItemType, +) + +aten = torch.ops.aten # pyre-ignore[5] + + +class LocalShardsWrapper(torch.Tensor): + """ + A wrapper class to hold local shards of a DTensor. + This class is used largely for checkpointing purposes and implicity subtypes + the _Checkpointable protocol. + """ + + __slots__ = ["_local_shards", "_storage_meta"] + # pyre-fixme[13]: Attribute `_local_shards` is never initialized. + _local_shards: List[torch.Tensor] + # pyre-fixme[13]: Attribute `_storage_meta` is never initialized. 
+ _storage_meta: TensorStorageMetadata + + @staticmethod + def __new__( + cls, local_shards: List[torch.Tensor], local_offsets: List[Tuple[int, ...]] + ) -> "LocalShardsWrapper": + assert all( + tensor.device == local_shards[0].device for tensor in local_shards[1:] + ) + + # if empty shard, we create a empty tensor + if len(local_shards) == 0: + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]) + cls, + torch.Size([0, 0]), + ) + r._local_shards = [] + r._storage_meta = TensorStorageMetadata( + properties=TensorProperties(), + size=torch.Size([0, 0]), + chunks=[ + ChunkStorageMetadata( + offsets=torch.Size([0, 0]), sizes=torch.Size([0, 0]) + ) + ], + ) + return r + + # we calculate the total tensor size by "concat" on second tensor dimension + cat_tensor_shape = list(local_shards[0].size()) + if len(local_shards) > 1 and local_shards[0].ndim == 2: # column-wise sharding + for shard in local_shards[1:]: + cat_tensor_shape[1] += shard.size()[1] + + # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension + if len(local_shards) > 1 and local_shards[0].ndim == 1: # column-wise sharding + for shard in local_shards[1:]: + cat_tensor_shape[0] += shard.size()[0] + + wrapper_properties = TensorProperties.create_from_tensor(local_shards[0]) + wrapper_shape = torch.Size(cat_tensor_shape) + chunks_meta = [ + ChunkStorageMetadata( + offsets=torch.Size(offset), + sizes=shard.size(), + ) + for shard, offset in zip(local_shards, local_offsets) + ] + + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + torch.Size(cat_tensor_shape), + ) + r._local_shards = local_shards + r._storage_meta = TensorStorageMetadata( + properties=wrapper_properties, + size=wrapper_shape, + chunks=chunks_meta, + ) + + return r + + # necessary for ops dispatching from this subclass to its local shards + @classmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + dispatcher = { + torch.ops._c10d_functional.all_gather_into_tensor.default: cls.handle_all_gather_into_tensor, + torch.ops._c10d_functional.wait_tensor.default: cls.handle_wait_tensor, + aten._to_copy.default: cls.handle_to_copy, + aten.view.default: cls.handle_view, + aten.equal.default: cls.handle_equal, + aten.detach.default: cls.handle_detach, + aten.clone.default: cls.handle_clone, + aten.new_empty.default: cls.handle_new_empty, + } + + if func in dispatcher: + return dispatcher[func](args, kwargs) # pyre-ignore [29] + else: + raise NotImplementedError( + f"{func} is not supported for LocalShardsWrapper!" + ) + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_all_gather_into_tensor(args, kwargs): + dim = args[0].local_sizes()[0][1] + cat_tensor = torch.cat( + [t.view(-1) for t in args[0].local_shards()], dim=0 + ).view(-1, dim) + return torch.ops._c10d_functional.all_gather_into_tensor.default( + cat_tensor, *args[1:], **kwargs + ) + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_wait_tensor(args, kwargs): + return torch.ops._c10d_functional.wait_tensor(args[0]) + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. 
+ def handle_to_copy(args, kwargs): + res_shards_list = [ + aten._to_copy.default(shard, *args[1:], **kwargs) + for shard in args[0].local_shards() + ] + return LocalShardsWrapper(res_shards_list, args[0].local_offsets()) + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_view(args, kwargs): + view_shape = args[1] + res_shards_list = [] + if len(args[0].local_shards()) > 1: + if args[0].local_shards()[0].ndim == 2: + assert ( + args[0].storage_metadata().size[0] == view_shape[0] + and args[0].storage_metadata().size[1] == view_shape[1] + ) + # This accounts for a DTensor quirk, when multiple shards are present on a rank, DTensor on + # init calls view_as() on the global tensor shape + # will fail because the view shape is not applicable to individual shards. + res_shards_list = [ + aten.view.default(shard, shard.shape, **kwargs) + for shard in args[0].local_shards() + ] + elif args[0].local_shards()[0].ndim == 1: + assert args[0].storage_metadata().size[0] == view_shape[0] + # This case is for optimizer sharding as regardles of sharding type, optimizer state is row wise sharded + res_shards_list = [ + aten.view.default(shard, shard.shape, **kwargs) + for shard in args[0].local_shards() + ] + else: + raise NotImplementedError("No support for view on tensors ndim > 2") + else: + # view is called per shard + res_shards_list = [ + aten.view.default(shard, args[1], **kwargs) + for shard in args[0].local_shards() + ] + return LocalShardsWrapper(res_shards_list, args[0].local_offsets()) + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_equal(args, kwargs): + """ + LocalShardsWrapper equal impl also checks for equality of storage metadata + and the order of shards + """ + a, b = args[0], args[1] + if len(a.local_shards()) != len(b.local_shards()): + return False + if not all( + aten.equal.default(x, y) for x, y in zip(a.local_shards(), b.local_shards()) + ): + return False + if not a.storage_metadata() == b.storage_metadata(): + return False + return True + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_detach(args, kwargs): + self_ls = args[0] + deatched_local_shards = [ + aten.detach.default(shard) for shard in self_ls.local_shards() + ] + self_ls._local_shards = deatched_local_shards + self_ls._storage_meta.properties.requires_grad = False + return self_ls + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_clone(args, kwargs): + self_ls = args[0] + desired_memory_format = kwargs.get("memory_format", None) + if desired_memory_format and desired_memory_format != torch.preserve_format: + raise NotImplementedError( + f"{desired_memory_format} is not supported for LocalShardsWrapper!" + ) + cloned_local_shards = [ + shard.clone(memory_format=desired_memory_format) + for shard in self_ls._local_shards + ] + return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets()) + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. 
+ def handle_new_empty(args, kwargs): + self_ls = args[0] + return LocalShardsWrapper( + [torch.empty_like(shard) for shard in self_ls._local_shards], + self_ls.local_offsets(), + ) + + @property + def device(self) -> torch._C.device: # type: ignore[override] + return ( + self._local_shards[0].device if self._local_shards else torch.device("meta") + ) + + @property + def is_meta(self) -> bool: # type: ignore[override] + return self._local_shards[0].is_meta if self._local_shards else True + + # pyre-ignore[14] + def is_pinned(self) -> bool: # type: ignore[override] + return self._storage_meta.properties.pin_memory + + # pyre-ignore[14] + def requires_grad_(self, requires_grad: bool = True) -> "LocalShardsWrapper": + self._storage_meta.properties.requires_grad = requires_grad + [shard.requires_grad_(requires_grad) for shard in self._local_shards] + return self + + def local_shards(self) -> List[torch.Tensor]: + """ + Returns a list of :class:`torch.Tensor' corresponding to the + local shards for this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return self._local_shards + + def local_sizes(self) -> List[torch.Size]: + """ + Returns a list of :class:`torch.Size' corresponding to the + local sizes for the shards on this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return [chunk.sizes for chunk in self._storage_meta.chunks] + + def local_offsets(self) -> List[torch.Size]: + """ + Returns a list of :class:`torch.Size' corresponding to the + local offsets for the shards on this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return [chunk.offsets for chunk in self._storage_meta.chunks] + + @property + def local_chunks(self) -> List[ChunkStorageMetadata]: + """ + Returns a :class:`List[ChunkStorageMetadata]` object corresponding to the + metadata for each tensor shard + """ + return self._storage_meta.chunks + + def storage_metadata(self) -> TensorStorageMetadata: + """ + Returns a :class:`TensorStorageMetadata` object corresponding to the + metadata for the local tensor on current rank + """ + return self._storage_meta + + def is_empty_shard(self) -> bool: + """ + Returns a :class:`bool` object indicating if the local tensor on current rank + is an empty tensor + """ + return self._storage_meta.size[0] == 0 and self._storage_meta.size[1] == 0 + + def __create_write_items__( + self, fqn: str, object: Any # pyre-ignore[2] + ) -> List[WriteItem]: + """ + For compatibility with DCP, we support creation of WriteItems + such that they can be saved properly. + """ + return [ + WriteItem( + index=MetadataIndex(fqn, chunks.offsets), + type=WriteItemType.SHARD, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata( + offsets=chunks.offsets, + sizes=chunks.sizes, + ), + properties=self._storage_meta.properties, + size=object.size(), + ), + ) + for tensor, chunks in zip(self.local_shards(), self.local_chunks) + ] + + def __create_chunk_list__(self) -> List[ChunkStorageMetadata]: + """ + For compatibility with DCP, we support creation of chunk lists + such that they can be saved properly. + """ + return self._storage_meta.chunks + + def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor: + """ + For compatibility with DCP, we support finding shard based on index + Return a 'torch.Tensor' shard based on 'MetadataIndex'. 
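+ Lookup order: first the positional fast path via index.index (validated against the stored chunk offset), then a scan of the chunks by offset; an empty wrapper returns an empty tensor, and anything else raises a ValueError.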
+ """ + # Fast lookup path + if index.index is not None: + if ( + len(self._local_shards) > index.index + and self._storage_meta.chunks[index.index].offsets == index.offset + ): + return self._local_shards[index.index] + + if index.offset is not None: + for shard, chunk in zip(self._local_shards, self._storage_meta.chunks): + if chunk.offsets == index.offset: + return shard + + # Empty shard case + if len(self._local_shards) == 0 and self._storage_meta.chunks[ + 0 + ].sizes == torch.Size([0, 0]): + return torch.empty(0) + + raise ValueError( + f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'" + ) + + def _get_tensor_size_bytes(self) -> int: + object_size = 0 + for shard in self.local_shards(): + object_size += shard.nelement() * shard.element_size() + return object_size + + # pyre-fixme[3]: Return type must be annotated. + def __hash__(self): + return id(self) + + # pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently. + # pyre-fixme[3]: Return type must be annotated. + def __repr__(self): + return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" + + def __str__(self) -> str: + return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" diff --git a/torchrec/distributed/tensor_pool.py b/torchrec/distributed/tensor_pool.py new file mode 100644 index 000000000..436851b9d --- /dev/null +++ b/torchrec/distributed/tensor_pool.py @@ -0,0 +1,507 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import List, Optional, Tuple, Type, Union + +import torch +from torchrec.distributed.object_pool import ShardedObjectPool +from torchrec.distributed.sharding.rw_pool_sharding import ( + InferRwObjectPoolInputDist, + RwObjectPoolIDsDist, +) +from torchrec.distributed.sharding.rw_tensor_pool_sharding import ( + InferRwTensorPoolOutputDist, + InferRwTensorPoolSharding, + RwTensorPoolValuesDist, + TensorPoolRwSharding, +) +from torchrec.distributed.tensor_sharding import ObjectPoolShardingContext +from torchrec.distributed.types import ( + Awaitable, + LazyAwaitable, + ModuleSharder, + ObjectPoolShardingPlan, + ObjectPoolShardingType, + ShardingEnv, +) +from torchrec.modules.object_pool_lookups import TensorLookup, TensorPoolLookup +from torchrec.modules.tensor_pool import TensorPool +from torchrec.modules.utils import deterministic_dedup + + +class TensorPoolAwaitable(LazyAwaitable[torch.Tensor]): + def __init__( + self, + awaitable: Awaitable[torch.Tensor], + unbucketize_permute: torch.Tensor, + ) -> None: + super().__init__() + self._awaitable = awaitable + self._unbucketize_permute = unbucketize_permute + + def _wait_impl(self) -> torch.Tensor: + tensor = self._awaitable.wait() + + return tensor[self._unbucketize_permute] + + +class ShardedTensorPool( + ShardedObjectPool[torch.Tensor, torch.Tensor, ObjectPoolShardingContext] +): + """ + Sharded implementation of `TensorPool` + + When dealing with large pool of tensors that cannot fit in a single device memory + (i.e. HBM / UVM / CPU etc), this module handles sharding the pool row-wise, including + orchestrating the communication between ranks for distributed lookup and update. + + Args: + env (ShardingEnv): sharding environment (e.g. 
world_size, ranks, etc) + pool_size (int): total number of rows of tensors in the pool + dim (int): dimension that each tensor in the pool + dtype (torch.dtype): dtype of the tensors in the pool + sharding_plan (ObjectPoolShardingPlan): info about sharding strategy + device (Optional[torch.device]): default device + enable_uvm (bool): if set to true, the pool will be allocated on UVM + + Example:: + # Example on 2 GPUs + + # rank 0 + sharded_keyed_jagged_tensor_pool.update( + ids=torch.Tensor([2,0],dtype=torch.int,device="cuda:0") + values=torch.Tensor([ + [1,2,3], + [4,5,6], + ],dtype=torch.int,device="cuda:0") + ) + + # on rank 1 + sharded_keyed_jagged_tensor_pool.update( + ids=torch.Tensor([1,3],dtype=torch.int,device="cuda:1") + values=torch.Tensor([ + [7,8,9], + [10,11,12], + ],dtype=torch.int,device="cuda:1") + ) + + # At this point the global state is: + # ids tensor + # 0 [1,2,3] <- rank 0 + # 1 [7,8,9] <- rank 1 + # 2 [4,5,6] <- rank 0 + # 3 [10,11,12] <- rank 1 + + """ + + def __init__( + self, + env: ShardingEnv, + pool_size: int, + dim: int, + dtype: torch.dtype, + sharding_plan: ObjectPoolShardingPlan, + device: Optional[torch.device] = None, + enable_uvm: bool = False, + ) -> None: + super().__init__() + + # pyre-fixme[4]: Attribute must be annotated. + self._world_size = env.world_size + # pyre-fixme[4]: Attribute must be annotated. + self._rank = env.rank + self._pool_size = pool_size + self._sharding_env = env + self._dim: int = dim + # pyre-fixme[4]: Attribute must be annotated. + self._device = device if device is not None else torch.device("meta") + self._dtype = dtype + self._sharding_plan = sharding_plan + self._enable_uvm = enable_uvm + + if sharding_plan.sharding_type == ObjectPoolShardingType.ROW_WISE: + self._sharding: TensorPoolRwSharding = TensorPoolRwSharding( + env=self._sharding_env, + device=self._device, + pool_size=self._pool_size, + dim=dim, + ) + else: + raise NotImplementedError( + f"Sharding type {self._sharding_plan.sharding_type} is not implemented" + ) + + self._lookup: TensorPoolLookup = TensorLookup( + self._sharding.local_pool_size, + self._dim, + self._dtype, + self._device, + self._enable_uvm, + ) + + self._lookup_ids_dist_impl: RwObjectPoolIDsDist = ( + self._sharding.create_lookup_ids_dist() + ) + self._lookup_values_dist_impl: RwTensorPoolValuesDist = ( + self._sharding.create_lookup_values_dist() + ) + + self._update_ids_dist_impl: RwObjectPoolIDsDist = ( + self._sharding.create_update_ids_dist() + ) + self._update_values_dist_impl: RwTensorPoolValuesDist = ( + self._sharding.create_update_values_dist() + ) + + self._initialize_torch_state() + + @property + def pool_size(self) -> int: + return self._pool_size + + @property + def dim(self) -> int: + return self._dim + + @property + def dtype(self) -> torch.dtype: + return self._dtype + + @property + def device(self) -> torch.device: + torch._assert(self._device is not None, "self._device should already be set") + return self._device + + def _update_preproc(self, values: torch.Tensor) -> torch.Tensor: + assert values.dtype == self.dtype + assert values.size(1) == self._dim + assert values.device.type == self._device.type + return values + + def _update_ids_dist( + self, ctx: ObjectPoolShardingContext, ids: torch.Tensor + ) -> Awaitable[Awaitable[torch.Tensor]]: + return self._update_ids_dist_impl(ctx=ctx, ids=ids) + + def _update_values_dist( + self, ctx: ObjectPoolShardingContext, values: torch.Tensor + ) -> LazyAwaitable[torch.Tensor]: + return self._update_values_dist_impl(ctx=ctx, 
values=values) + + def _update_local( + self, ctx: ObjectPoolShardingContext, ids: torch.Tensor, values: torch.Tensor + ) -> None: + deduped_ids, dedup_permutation = deterministic_dedup(ids) + + self._lookup.update( + deduped_ids, + values[dedup_permutation], + ) + + def _lookup_ids_dist( + self, ctx: ObjectPoolShardingContext, ids: torch.Tensor + ) -> Awaitable[Awaitable[torch.Tensor]]: + return self._lookup_ids_dist_impl(ctx=ctx, ids=ids) + + def _lookup_local( + self, ctx: ObjectPoolShardingContext, ids: torch.Tensor + ) -> torch.Tensor: + return self._lookup.lookup(ids) + + def _lookup_values_dist( + self, + ctx: ObjectPoolShardingContext, + values: torch.Tensor, + ) -> LazyAwaitable[torch.Tensor]: + return TensorPoolAwaitable( + awaitable=self._lookup_values_dist_impl(ctx, values), + unbucketize_permute=ctx.unbucketize_permute, + ) + + def create_context(self) -> ObjectPoolShardingContext: + return self._sharding.create_context() + + def _initialize_torch_state(self) -> None: + for fqn, tensor in self._sharding.get_sharded_states_to_register(self._lookup): + self.register_buffer(fqn, tensor) + + +@torch.fx.wrap +def update( + shard: torch.nn.Parameter, rank_ids: torch.Tensor, values: torch.Tensor +) -> torch.Tensor: + if values.device != shard.device: + values = values.to(shard.device) + shard[rank_ids] = values + return torch.empty(0) + + +class LocalShardPool(torch.nn.Module): + """ + Module containing a single shard of a tensor pool as a parameter. + + Used to lookup and update the pool during inference. + + Args: + shard (torch.Tensor): Subset of the tensor pool. + + Example: + # shard containing 2 rows from tensor pool with dim=3 + shard = torch.tensor([ + [1,2,3], + [4,5,6], + ]) + pool = LocalShardPool(shard) + out = pool(torch.tensor([0])) + # out is tensor([1,2,3]) i.e. first row of the shard + """ + + def __init__( + self, + shard: torch.Tensor, + ) -> None: + super().__init__() + self._shard: torch.nn.Parameter = torch.nn.Parameter( + shard, + requires_grad=False, + ) + + def forward(self, rank_ids: torch.Tensor) -> torch.Tensor: + """ + Lookup the rows in the shard corresponding to the given rank ids. + + Args: + rank_ids (torch.Tensor): Tensor of rank ids to lookup. + + Returns: + torch.Tensor: Tensor of values corresponding to the given rank ids. + """ + return self._shard[rank_ids] + + def update(self, rank_ids: torch.Tensor, values: torch.Tensor) -> None: + _ = update(self._shard, rank_ids, values) + + +class ShardedInferenceTensorPool( + ShardedObjectPool[torch.Tensor, List[torch.Tensor], ObjectPoolShardingContext], +): + _local_shard_pools: torch.nn.ModuleList + _world_size: int + _device: torch.device + _rank: int + + def __init__( + self, + env: ShardingEnv, + pool_size: int, + dim: int, + dtype: torch.dtype, + plan: ObjectPoolShardingPlan, + module: TensorPool, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + + self._pool_size = pool_size + self._dtype = dtype + self._sharding_env = env + self._world_size = env.world_size + self._device = device or torch.device("cuda") + self._sharding_plan = plan + + self._rank = env.rank + self._dim = dim + + torch._assert( + self._sharding_plan.inference, "Plan needs to have inference enabled" + ) + + if self._sharding_plan.sharding_type == ObjectPoolShardingType.ROW_WISE: + # pyre-fixme[4]: Attribute must be annotated. 
+ self._sharding = InferRwTensorPoolSharding( + env=self._sharding_env, + device=self._device, + pool_size=self._pool_size, + ) + else: + raise NotImplementedError( + f"Sharding type {self._sharding_plan.sharding_type} is not implemented" + ) + + self._local_shard_pools: torch.nn.ModuleList = torch.nn.ModuleList() + offset = 0 + for rank, this_rank_size in zip( + range( + self._world_size, + ), + self._sharding.local_pool_size_per_rank, + ): + shard_device = ( + torch.device("cpu") + if device == torch.device("cpu") + else torch.device("cuda", rank) + ) + self._local_shard_pools.append( + LocalShardPool( + torch.empty( + ( + this_rank_size, + self._dim, + ), + dtype=self._dtype, + device=shard_device, + requires_grad=False, + ), + ) + ) + + if module._pool.device != torch.device("meta"): + local_shard = module._pool[offset : offset + this_rank_size] + self._local_shard_pools[rank]._shard.copy_(local_shard) + + offset += this_rank_size + + self._lookup_ids_dist_impl: InferRwObjectPoolInputDist = ( + self._sharding.create_lookup_ids_dist() + ) + self._lookup_values_dist_impl: InferRwTensorPoolOutputDist = ( + self._sharding.create_lookup_values_dist() + ) + + # TODO use DTensor that works with Inference Publishing. Right now ShardedTensor doesn't fit this shoe. + + @property + def pool_size(self) -> int: + return self._pool_size + + @property + def dim(self) -> int: + return self._dim + + @property + def dtype(self) -> torch.dtype: + return self._dtype + + @property + def device(self) -> torch.device: + torch._assert(self._device is not None, "self._device should already be set") + return self._device + + def create_context(self) -> ObjectPoolShardingContext: + raise NotImplementedError("create_context() is not implemented") + + # pyre-ignore + def _lookup_ids_dist( + self, + ids: torch.Tensor, + ) -> Tuple[List[torch.Tensor], torch.Tensor]: + return self._lookup_ids_dist_impl(ids) + + # pyre-ignore + def _update_ids_dist( + self, + ids: torch.Tensor, + values: torch.Tensor, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]: + return self._lookup_ids_dist_impl.update(ids, values) + + # pyre-ignore + def _lookup_local( + self, + dist_input: List[torch.Tensor], + ) -> List[torch.Tensor]: + ret = [] + for i, shard in enumerate(self._local_shard_pools): + ret.append(shard(dist_input[i])) + return ret + + # pyre-ignore + def _lookup_values_dist( + self, + lookups: List[torch.Tensor], + ) -> torch.Tensor: + return self._lookup_values_dist_impl(lookups) + + # pyre-ignore + def forward(self, ids: torch.Tensor) -> torch.Tensor: + dist_input, unbucketize_permute = self._lookup_ids_dist(ids) + lookup = self._lookup_local(dist_input) + + # Here we are playing a trick to workaround a fx tracing issue, + # as proxy is not iteratable. 
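+ # The workaround: index the proxy with plain ints over range(world_size) instead of iterating it, which keeps the accesses traceable.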
+ lookup_list = [] + for i in range(self._world_size): + lookup_list.append(lookup[i]) + + output = self._lookup_values_dist(lookup_list) + + return output[unbucketize_permute].view(-1, self._dim) + + # pyre-ignore + def _update_values_dist(self, ctx: ObjectPoolShardingContext, values: torch.Tensor): + raise NotImplementedError("Inference does not support update") + + # pyre-ignore + def _update_local( + self, + dist_input: List[torch.Tensor], + dist_values: List[torch.Tensor], + ) -> None: + for i, shard in enumerate(self._local_shard_pools): + ids = dist_input[i] + values = dist_values[i] + deduped_ids, dedup_permutation = deterministic_dedup(ids) + shard.update(deduped_ids, values[dedup_permutation]) + + # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`. + def _update_preproc(self, values: torch.Tensor) -> torch.Tensor: + pass + + def update(self, ids: torch.Tensor, values: torch.Tensor) -> None: + dist_input, dist_values, unbucketize_permute = self._update_ids_dist( + ids, values + ) + self._update_local(dist_input, dist_values) + + +class TensorPoolSharder(ModuleSharder[TensorPool]): + def __init__(self) -> None: + super().__init__() + + def shard( + self, + module: TensorPool, + plan: ObjectPoolShardingPlan, + env: ShardingEnv, + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> Union[ShardedTensorPool, ShardedInferenceTensorPool]: + if plan.inference: + return ShardedInferenceTensorPool( + env=env, + pool_size=module.pool_size, + dim=module.dim, + dtype=module.dtype, + plan=plan, + device=device, + module=module, + ) + return ShardedTensorPool( + env=env, + pool_size=module.pool_size, + dim=module.dim, + dtype=module.dtype, + sharding_plan=plan, + device=device, + enable_uvm=module._enable_uvm, + ) + + @property + def module_type(self) -> Type[TensorPool]: + return TensorPool diff --git a/torchrec/distributed/tensor_sharding.py b/torchrec/distributed/tensor_sharding.py new file mode 100644 index 000000000..6a7ca0715 --- /dev/null +++ b/torchrec/distributed/tensor_sharding.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +#!/usr/bin/env python3 + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import distributed as dist +from torchrec.distributed.types import Multistreamable, ShardingEnv + + +@dataclass +class ObjectPoolShardingContext(Multistreamable): + ids_before_input_dist: Optional[torch.Tensor] = None + num_ids_each_rank_to_receive: Optional[torch.Tensor] = None + num_ids_each_rank_to_send: Optional[torch.Tensor] = None + bucketize_permute: Optional[torch.Tensor] = None + unbucketize_permute: Optional[torch.Tensor] = None + + def record_stream(self, stream: torch.Stream) -> None: + if self.ids_before_input_dist is not None: + self.ids_before_input_dist.record_stream(stream) + if self.num_ids_each_rank_to_receive is not None: + self.num_ids_each_rank_to_receive.record_stream(stream) + if self.num_ids_each_rank_to_send is not None: + self.num_ids_each_rank_to_send.record_stream(stream) + if self.bucketize_permute is not None: + self.bucketize_permute.record_stream(stream) + if self.unbucketize_permute is not None: + self.unbucketize_permute.record_stream(stream) + + +@dataclass +class RwShardingContext(Multistreamable): + block_size: Optional[torch.Tensor] = None + + def record_stream(self, stream: torch.Stream) -> None: + if self.block_size is not None: + self.block_size.record_stream(stream) + + +@dataclass +class ObjectPoolRwShardingContext(ObjectPoolShardingContext, RwShardingContext): + def record_stream(self, stream: torch.Stream) -> None: + super().record_stream(stream) + + +@dataclass +class ObjectPoolReplicatedRwShardingContext(ObjectPoolRwShardingContext): + def record_stream(self, stream: torch.Stream) -> None: + super().record_stream(stream) + + +@dataclass +class TensorPoolRwShardingContext(ObjectPoolRwShardingContext): + """ + Placeholder for additional sharding context for TensorPool + """ + + def record_stream(self, stream: torch.Stream) -> None: + super().record_stream(stream) + + +class ObjectPoolSharding(ABC): + @abstractmethod + def create_update_ids_dist(self) -> torch.nn.Module: + pass + + @abstractmethod + def create_update_values_dist(self) -> torch.nn.Module: + pass + + @abstractmethod + def create_lookup_ids_dist(self) -> torch.nn.Module: + pass + + @abstractmethod + def create_lookup_values_dist(self) -> torch.nn.Module: + pass + + @abstractmethod + def get_sharded_states_to_register(self) -> Iterable[Tuple[str, torch.Tensor]]: + pass + + @abstractmethod + def create_context(self) -> ObjectPoolShardingContext: + pass + + +class InferObjectPoolSharding(ABC): + def __init__( + self, + pool_size: int, + env: ShardingEnv, + device: torch.device, + ) -> None: + self._pool_size = pool_size + self._env = env + # pyre-ignore + self._pg: dist.ProcessGroup = self._env.process_group + self._world_size: int = self._env.world_size + self._rank: int = self._env.rank + self._device = device + + self._block_size: int = ( + pool_size + self._env.world_size - 1 + ) // self._env.world_size + self._last_block_size: int = self._pool_size - self._block_size * ( + self._world_size - 1 + ) + self.local_pool_size_per_rank: List[int] = [self._block_size] * ( + self._world_size - 1 + ) + [self._last_block_size] + + self._block_size_t: torch.Tensor = torch.tensor( + [self._block_size], device=self._device, dtype=torch.long + ) + + @abstractmethod + def create_lookup_ids_dist(self) -> torch.nn.Module: + pass + + @abstractmethod + def create_lookup_values_dist(self) -> 
torch.nn.Module: + pass diff --git a/torchrec/distributed/test_utils/infer_utils.py b/torchrec/distributed/test_utils/infer_utils.py new file mode 100644 index 000000000..0c115f87f --- /dev/null +++ b/torchrec/distributed/test_utils/infer_utils.py @@ -0,0 +1,1113 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +import copy +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch + +import torchrec +from fbgemm_gpu import sparse_ops # noqa: F401, E402 +from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( + EmbeddingLocation, + PoolingMode, +) +from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( + IntNBitTableBatchedEmbeddingBagsCodegen, +) +from torch import nn, quantization as quant, Tensor +from torch.distributed._shard.sharding_spec import ShardingSpec +from torch.utils import _pytree as pytree +from torchrec import ( + EmbeddingCollection, + EmbeddingConfig, + KeyedJaggedTensor, + KeyedTensor, +) +from torchrec.distributed.embedding_types import ( + EmbeddingComputeKernel, + ModuleSharder, + ShardingType, +) +from torchrec.distributed.fused_params import ( + FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, + FUSED_PARAM_REGISTER_TBE_BOOL, +) +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology +from torchrec.distributed.planner.enumerators import EmbeddingEnumerator +from torchrec.distributed.planner.shard_estimators import ( + EmbeddingPerfEstimator, + EmbeddingStorageEstimator, +) +from torchrec.distributed.planner.types import ParameterConstraints +from torchrec.distributed.quant_embedding import ( + QuantEmbeddingCollectionSharder, + ShardedQuantEmbeddingCollection, +) +from torchrec.distributed.quant_embeddingbag import ( + QuantEmbeddingBagCollection, + QuantEmbeddingBagCollectionSharder, + QuantFeatureProcessedEmbeddingBagCollectionSharder, + ShardedQuantEmbeddingBagCollection, + ShardedQuantFeatureProcessedEmbeddingBagCollection, +) +from torchrec.distributed.quant_state import WeightSpec +from torchrec.distributed.shard import _shard_modules +from torchrec.distributed.test_utils.test_model import ModelInput, TestSparseNN +from torchrec.distributed.types import ( + ModuleShardingPlan, + ParameterSharding, + ShardingEnv, + ShardingPlan, +) +from torchrec.distributed.utils import CopyableMixin +from torchrec.inference.modules import set_pruning_data +from torchrec.modules.embedding_configs import ( + data_type_to_sparse_type, + dtype_to_data_type, + EmbeddingBagConfig, + QuantConfig, +) +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection +from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection +from torchrec.quant.embedding_modules import ( + EmbeddingCollection as QuantEmbeddingCollection, + FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection, + MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, + MODULE_ATTR_REGISTER_TBES_BOOL, + quant_prep_enable_quant_state_dict_split_scale_bias_for_types, + 
quant_prep_enable_register_tbes, + QuantManagedCollisionEmbeddingCollection, +) + + +@dataclass +class TestModelInfo: + sparse_device: torch.device + dense_device: torch.device + num_features: int + num_float_features: int + num_weighted_features: int + tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]] = field( + default_factory=list + ) + weighted_tables: List[EmbeddingBagConfig] = field(default_factory=list) + model: torch.nn.Module = torch.nn.Module() + quant_model: torch.nn.Module = torch.nn.Module() + sharders: List[ModuleSharder] = field(default_factory=list) + topology: Optional[Topology] = None + planner: Optional[EmbeddingShardingPlanner] = None + + +class KJTInputExportWrapper(torch.nn.Module): + def __init__( + self, + module_kjt_input: torch.nn.Module, + kjt_keys: List[str], + ) -> None: + super().__init__() + self._module_kjt_input = module_kjt_input + self._kjt_keys = kjt_keys + + # pyre-ignore + def forward( + self, + values: torch.Tensor, + lengths: torch.Tensor, + weights: Optional[torch.Tensor] = None, + # pyre-ignore + *args, + # pyre-ignore + **kwargs, + ): + kjt = KeyedJaggedTensor( + keys=self._kjt_keys, + values=values, + lengths=lengths, + weights=weights, + ) + output = self._module_kjt_input(kjt, *args, **kwargs) + # TODO(ivankobzarev): Support of None leaves in dynamo/export (e.g. KJT offsets) + return [leaf for leaf in pytree.tree_leaves(output) if leaf is not None] + + +class KJTInputExportDynamicShapeWrapper(torch.nn.Module): + def __init__( + self, + kjt_input_wrapper: KJTInputExportWrapper, + ) -> None: + super().__init__() + self.kjt_input_wrapper = kjt_input_wrapper + + # pyre-ignore + def forward( + self, + values: torch.Tensor, + lengths: torch.Tensor, + weights: Optional[torch.Tensor] = None, + # pyre-ignore + *args, + # pyre-ignore + **kwargs, + ): + # Generate unbacked symints to represent sizes + # for values and weights, constrain them reasonably + values_size = values[0].item() + torch._check_is_size(values_size) + torch._check(values_size >= lengths.shape[0]) + # pyre-ignore + values = torch.ones(values_size).to(values.device) + if weights is not None: + weights_size = weights.int()[0].item() + torch._check_is_size(weights_size) + torch._check(weights_size >= lengths.shape[0]) + # pyre-ignore + weights = torch.ones(weights_size).to(weights.device) + + return self.kjt_input_wrapper(values, lengths, weights, *args, **kwargs) + + +def prep_inputs( + model_info: TestModelInfo, + world_size: int, + batch_size: int = 1, + count: int = 5, + long_indices: bool = True, +) -> List[ModelInput]: + inputs = [] + if long_indices: + indices_dtype = torch.int64 + lengths_dtype = torch.int64 + else: + indices_dtype = torch.int32 + lengths_dtype = torch.int32 + for _ in range(count): + inputs.append( + ModelInput.generate( + batch_size=batch_size, + world_size=world_size, + num_float_features=model_info.num_float_features, + tables=model_info.tables, + weighted_tables=model_info.weighted_tables, + indices_dtype=indices_dtype, + lengths_dtype=lengths_dtype, + )[1][0], + ) + + return inputs + + +class KJTInputExportWrapperWithStrides(torch.nn.Module): + """ + Version of KJTInputExportWrapper with stride_per_key_per_rank_tensor argument for VB path. 
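+ "VB" refers to the variable-batch-size path, where each feature can carry a different stride per rank via stride_per_key_per_rank (a List[List[int]] laid out as [key][rank]).
+ Example (illustrative only; assumes two keys and two ranks)::
+ wrapper = KJTInputExportWrapperWithStrides(module, kjt_keys=["f1", "f2"])
+ out = wrapper(values, lengths, stride_per_key_per_rank=[[2, 3], [1, 4]])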
+ """ + + def __init__( + self, + module_kjt_input: torch.nn.Module, + kjt_keys: List[str], + ) -> None: + super().__init__() + self._module_kjt_input = module_kjt_input + self._kjt_keys = kjt_keys + + # pyre-ignore + def forward( + self, + values: torch.Tensor, + lengths: torch.Tensor, + stride_per_key_per_rank: Optional[List[List[int]]], + # pyre-ignore + *args, + # pyre-ignore + **kwargs, + ): + kjt = KeyedJaggedTensor( + keys=self._kjt_keys, + values=values, + lengths=lengths, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + output = self._module_kjt_input(kjt, *args, **kwargs) + return [leaf for leaf in pytree.tree_leaves(output) if leaf is not None] + + +def prep_inputs_multiprocess( + model_info: TestModelInfo, world_size: int, batch_size: int = 1, count: int = 5 +) -> List[Tuple[ModelInput, List[ModelInput]]]: + inputs = [] + for _ in range(count): + inputs.append( + ModelInput.generate( + batch_size=batch_size, + world_size=world_size, + num_float_features=model_info.num_float_features, + tables=model_info.tables, + weighted_tables=model_info.weighted_tables, + ) + ) + return inputs + + +def model_input_to_forward_args_kjt( + mi: ModelInput, +) -> Tuple[ + List[str], + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], +]: + kjt = mi.idlist_features + assert isinstance(kjt, KeyedJaggedTensor) + return ( + kjt._keys, + kjt._values, + kjt._weights, + kjt._lengths, + kjt._offsets, + ) + + +# We want to be torch types bound, args for TorchTypesModelInputWrapper +def model_input_to_forward_args( + mi: ModelInput, +) -> Tuple[ + torch.Tensor, + List[str], + torch.Tensor, + List[str], + torch.Tensor, + Optional[torch.Tensor], + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], +]: + idlist_kjt = mi.idlist_features + idscore_kjt = mi.idscore_features + assert isinstance(idlist_kjt, KeyedJaggedTensor) + assert isinstance(idscore_kjt, KeyedJaggedTensor) + return ( + mi.float_features, + idlist_kjt._keys, + idlist_kjt._values, + idscore_kjt._keys, + idscore_kjt._values, + idscore_kjt._weights, + mi.label, + idlist_kjt._lengths, + idlist_kjt._offsets, + idscore_kjt._lengths, + idscore_kjt._offsets, + ) + + +def create_cw_min_partition_constraints( + table_min_partition_pairs: List[Tuple[str, int]] +) -> Dict[str, ParameterConstraints]: + return { + name: ParameterConstraints( + sharding_types=[ShardingType.COLUMN_WISE.value], + min_partition=min_partition, + ) + for name, min_partition in table_min_partition_pairs + } + + +def quantize( + module: torch.nn.Module, + inplace: bool, + output_type: torch.dtype = torch.float, + register_tbes: bool = False, + quant_state_dict_split_scale_bias: bool = False, + weight_dtype: torch.dtype = torch.qint8, + per_table_weight_dtypes: Optional[Dict[str, torch.dtype]] = None, +) -> torch.nn.Module: + module_types: List[Type[torch.nn.Module]] = [ + torchrec.modules.embedding_modules.EmbeddingBagCollection, + torchrec.modules.embedding_modules.EmbeddingCollection, + torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingCollection, + ] + if register_tbes: + quant_prep_enable_register_tbes(module, module_types) + if quant_state_dict_split_scale_bias: + quant_prep_enable_quant_state_dict_split_scale_bias_for_types( + module, module_types + ) + + qconfig = quant.QConfig( + activation=quant.PlaceholderObserver.with_args(dtype=output_type), + weight=quant.PlaceholderObserver.with_args(dtype=weight_dtype), + ) + + if per_table_weight_dtypes: + 
qconfig = QuantConfig( + activation=quant.PlaceholderObserver.with_args(dtype=output_type), + weight=quant.PlaceholderObserver.with_args(dtype=torch.quint8), + per_table_weight_dtype=per_table_weight_dtypes, + ) + + return quant.quantize_dynamic( + module, + qconfig_spec={ + EmbeddingBagCollection: qconfig, + EmbeddingCollection: qconfig, + ManagedCollisionEmbeddingCollection: qconfig, + }, + mapping={ + EmbeddingBagCollection: QuantEmbeddingBagCollection, + EmbeddingCollection: QuantEmbeddingCollection, + ManagedCollisionEmbeddingCollection: QuantManagedCollisionEmbeddingCollection, + }, + inplace=inplace, + ) + + +def quantize_fpebc( + module: torch.nn.Module, + inplace: bool, + output_type: torch.dtype = torch.float, + register_tbes: bool = False, + quant_state_dict_split_scale_bias: bool = False, + weight_dtype: torch.dtype = torch.qint8, + per_table_weight_dtypes: Optional[Dict[str, torch.dtype]] = None, +) -> torch.nn.Module: + module_types: List[Type[torch.nn.Module]] = [ + torchrec.modules.fp_embedding_modules.FeatureProcessedEmbeddingBagCollection, + ] + if register_tbes: + quant_prep_enable_register_tbes(module, module_types) + if quant_state_dict_split_scale_bias: + quant_prep_enable_quant_state_dict_split_scale_bias_for_types( + module, module_types + ) + + qconfig = quant.QConfig( + activation=quant.PlaceholderObserver.with_args(dtype=output_type), + weight=quant.PlaceholderObserver.with_args(dtype=weight_dtype), + ) + + if per_table_weight_dtypes: + qconfig = QuantConfig( + activation=quant.PlaceholderObserver.with_args(dtype=output_type), + weight=quant.PlaceholderObserver.with_args(dtype=torch.quint8), + per_table_weight_dtype=per_table_weight_dtypes, + ) + + return quant.quantize_dynamic( + module, + qconfig_spec={ + FeatureProcessedEmbeddingBagCollection: qconfig, + }, + mapping={ + FeatureProcessedEmbeddingBagCollection: QuantFeatureProcessedEmbeddingBagCollection, + }, + inplace=inplace, + ) + + +class TestQuantFPEBCSharder(QuantFeatureProcessedEmbeddingBagCollectionSharder): + def __init__( + self, + sharding_type: str, + kernel_type: str, + fused_params: Optional[Dict[str, Any]] = None, + shardable_params: Optional[List[str]] = None, + ) -> None: + super().__init__(fused_params=fused_params, shardable_params=shardable_params) + self._sharding_type = sharding_type + self._kernel_type = kernel_type + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [self._sharding_type] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [self._kernel_type] + + def shard( + self, + module: QuantFeatureProcessedEmbeddingBagCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> ShardedQuantFeatureProcessedEmbeddingBagCollection: + fused_params = self.fused_params if self.fused_params else {} + fused_params["output_dtype"] = data_type_to_sparse_type( + dtype_to_data_type(module.output_dtype()) + ) + fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr( + module, MODULE_ATTR_REGISTER_TBES_BOOL, False + ) + fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] = getattr( + module, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False + ) + return ShardedQuantFeatureProcessedEmbeddingBagCollection( + module=module, + table_name_to_parameter_sharding=params, + env=env, + fused_params=fused_params, + device=device, + feature_processor=module.feature_processor, + ) + + +class 
TestQuantEBCSharder(QuantEmbeddingBagCollectionSharder): + def __init__( + self, + sharding_type: str, + kernel_type: str, + fused_params: Optional[Dict[str, Any]] = None, + shardable_params: Optional[List[str]] = None, + ) -> None: + super().__init__(fused_params=fused_params, shardable_params=shardable_params) + self._sharding_type = sharding_type + self._kernel_type = kernel_type + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [self._sharding_type] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [self._kernel_type] + + def shard( + self, + module: QuantEmbeddingBagCollection, + params: Dict[str, ParameterSharding], + env: Union[ShardingEnv, Dict[str, ShardingEnv]], + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> ShardedQuantEmbeddingBagCollection: + fused_params = self.fused_params if self.fused_params else {} + fused_params["output_dtype"] = data_type_to_sparse_type( + dtype_to_data_type(module.output_dtype()) + ) + fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr( + module, MODULE_ATTR_REGISTER_TBES_BOOL, False + ) + fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] = getattr( + module, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False + ) + return ShardedQuantEmbeddingBagCollection( + module=module, + table_name_to_parameter_sharding=params, + env=env, + fused_params=fused_params, + device=device, + ) + + +class TestQuantECSharder(QuantEmbeddingCollectionSharder): + def __init__( + self, + sharding_type: str, + kernel_type: str, + fused_params: Optional[Dict[str, Any]] = None, + shardable_params: Optional[List[str]] = None, + ) -> None: + super().__init__(fused_params=fused_params, shardable_params=shardable_params) + self._sharding_type = sharding_type + self._kernel_type = kernel_type + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [self._sharding_type] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [self._kernel_type] + + def shard( + self, + module: QuantEmbeddingCollection, + params: Dict[str, ParameterSharding], + env: Union[Dict[str, ShardingEnv], ShardingEnv], + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> ShardedQuantEmbeddingCollection: + fused_params = self.fused_params if self.fused_params else {} + fused_params["output_dtype"] = data_type_to_sparse_type( + dtype_to_data_type(module.output_dtype()) + ) + fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr( + module, MODULE_ATTR_REGISTER_TBES_BOOL, False + ) + fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] = getattr( + module, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False + ) + return ShardedQuantEmbeddingCollection( + module, params, env, fused_params, device + ) + + +class KJTInputWrapper(torch.nn.Module): + def __init__( + self, + module_kjt_input: torch.nn.Module, + ) -> None: + super().__init__() + self._module_kjt_input = module_kjt_input + self.add_module("_module_kjt_input", self._module_kjt_input) + + # pyre-ignore + def forward( + self, + keys: List[str], + values: torch.Tensor, + weights: Optional[torch.Tensor] = None, + lengths: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ): + kjt = KeyedJaggedTensor( + keys=keys, + values=values, + weights=weights, + lengths=lengths, + offsets=offsets, + ) + return self._module_kjt_input(kjt) + + +# Wrapper for module that accepts ModelInput to avoid jit scripting 
of ModelInput (dataclass) and be fully torch types bound. +class TorchTypesModelInputWrapper(CopyableMixin): + def __init__(self, module: torch.nn.Module) -> None: + super().__init__() + self._module = module + + def forward( + self, + float_features: torch.Tensor, + idlist_features_keys: List[str], + idlist_features_values: torch.Tensor, + idscore_features_keys: List[str], + idscore_features_values: torch.Tensor, + idscore_features_weights: torch.Tensor, + label: torch.Tensor, + idlist_features_lengths: Optional[torch.Tensor] = None, + idlist_features_offsets: Optional[torch.Tensor] = None, + idscore_features_lengths: Optional[torch.Tensor] = None, + idscore_features_offsets: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + idlist_kjt = KeyedJaggedTensor( + keys=idlist_features_keys, + values=idlist_features_values, + lengths=idlist_features_lengths, + offsets=idlist_features_offsets, + ) + idscore_kjt = KeyedJaggedTensor( + keys=idscore_features_keys, + values=idscore_features_values, + weights=idscore_features_weights, + lengths=idscore_features_lengths, + offsets=idscore_features_offsets, + ) + mi = ModelInput( + float_features=float_features, + idlist_features=idlist_kjt, + idscore_features=idscore_kjt, + label=label, + ) + return self._module(mi) + + +def create_test_model( + num_embeddings: int, + emb_dim: int, + world_size: int, + batch_size: int, + dense_device: torch.device, + sparse_device: torch.device, + quant_state_dict_split_scale_bias: bool = False, + num_features: int = 1, + num_float_features: int = 8, + num_weighted_features: int = 1, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + weight_dtype: torch.dtype = torch.qint8, + pruning_dict: Optional[Dict[str, int]] = None, +) -> TestModelInfo: + topology: Topology = Topology( + world_size=world_size, compute_device=sparse_device.type + ) + mi = TestModelInfo( + dense_device=dense_device, + sparse_device=sparse_device, + num_features=num_features, + num_float_features=num_float_features, + num_weighted_features=num_weighted_features, + topology=topology, + ) + + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + constraints=constraints, + ), + ) + + mi.tables = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(mi.num_features) + ] + + mi.weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(mi.num_weighted_features) + ] + + if pruning_dict: + for config in mi.tables + mi.weighted_tables: + if config.name in pruning_dict: + config.num_embeddings_post_pruning = pruning_dict[config.name] + + mi.model = TorchTypesModelInputWrapper( + TestSparseNN( + # pyre-ignore [6] + tables=mi.tables, + weighted_tables=mi.weighted_tables, + num_float_features=mi.num_float_features, + dense_device=dense_device, + sparse_device=sparse_device, + ) + ) + mi.model.training = False + + if pruning_dict: + set_pruning_data(mi.model, pruning_dict) + + mi.quant_model = quantize( + module=mi.model, + inplace=False, + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, 
+ weight_dtype=weight_dtype, + ) + return mi + + +def create_test_model_ebc_only_no_quantize( + num_embeddings: int, + emb_dim: int, + world_size: int, + batch_size: int, + dense_device: torch.device, + sparse_device: torch.device, + num_features: int = 1, + num_float_features: int = 8, + num_weighted_features: int = 1, + compute_device: str = "cuda", + feature_processor: bool = False, +) -> TestModelInfo: + topology: Topology = Topology(world_size=world_size, compute_device=compute_device) + mi = TestModelInfo( + dense_device=dense_device, + sparse_device=sparse_device, + num_features=num_features, + num_float_features=num_float_features, + num_weighted_features=num_weighted_features, + topology=topology, + ) + + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + mi.tables = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(mi.num_features) + ] + + mi.weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(mi.num_weighted_features) + ] + + if feature_processor: + max_feature_lengths = {config.feature_names[0]: 100 for config in mi.tables} + fp = PositionWeightedModuleCollection( + max_feature_lengths=max_feature_lengths, device=mi.sparse_device + ) + ebc = FeatureProcessedEmbeddingBagCollection( + embedding_bag_collection=EmbeddingBagCollection( + # pyre-ignore [6] + tables=mi.tables, + device=mi.sparse_device, + is_weighted=True, + ), + feature_processors=fp, + ) + else: + ebc = EmbeddingBagCollection( + tables=mi.tables, + device=mi.sparse_device, + ) + + mi.model = torch.nn.Sequential(ebc) + mi.model.training = False + return mi + + +def create_test_model_ebc_only( + num_embeddings: int, + emb_dim: int, + world_size: int, + batch_size: int, + dense_device: torch.device, + sparse_device: torch.device, + num_features: int = 1, + num_float_features: int = 8, + num_weighted_features: int = 1, + quant_state_dict_split_scale_bias: bool = False, + compute_device: str = "cuda", + feature_processor: bool = False, +) -> TestModelInfo: + mi = create_test_model_ebc_only_no_quantize( + num_embeddings=num_embeddings, + emb_dim=emb_dim, + world_size=world_size, + batch_size=batch_size, + dense_device=dense_device, + sparse_device=sparse_device, + num_features=num_features, + num_float_features=num_float_features, + num_weighted_features=num_weighted_features, + compute_device=compute_device, + feature_processor=feature_processor, + ) + + if feature_processor: + mi.quant_model = quantize_fpebc( + module=mi.model, + inplace=True, + register_tbes=True, + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, + ) + else: + mi.quant_model = quantize( + module=mi.model, + inplace=False, + register_tbes=True, + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, + ) + return mi + + +def shard_qebc( + mi: TestModelInfo, + sharding_type: ShardingType, + device: torch.device, + expected_shards: Optional[List[List[Tuple[Tuple[int, int, int, int], str]]]] = None, + plan: Optional[ShardingPlan] = None, + ebc_fqn: str = "_module.sparse.ebc", + shard_score_ebc: bool = 
False, + feature_processor: bool = False, +) -> torch.nn.Module: + if feature_processor: + sharder = TestQuantFPEBCSharder( + sharding_type=sharding_type.value, + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=( + [table.name for table in mi.tables] + + ([table.name for table in mi.weighted_tables]) + ), + ) + else: + sharder = TestQuantEBCSharder( + sharding_type=sharding_type.value, + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=( + [table.name for table in mi.tables] + + ( + [table.name for table in mi.weighted_tables] + if shard_score_ebc + else [] + ) + ), + ) + if not plan: + # pyre-ignore + plan = mi.planner.plan( + mi.quant_model, + [sharder], + ) + + if expected_shards is not None: + msp = plan.plan[ebc_fqn] + for i in range(mi.num_features): + ps: ParameterSharding = msp[f"table_{i}"] + assert ps.sharding_type == sharding_type.value + assert ps.sharding_spec is not None + sharding_spec: ShardingSpec = ps.sharding_spec + # pyre-ignore + assert len(sharding_spec.shards) == len(expected_shards[i]) + for shard, ((offset_r, offset_c, size_r, size_c), placement) in zip( + sharding_spec.shards, expected_shards[i] + ): + assert shard.shard_offsets == [offset_r, offset_c] + assert shard.shard_sizes == [size_r, size_c] + assert str(shard.placement) == placement + + # We want to leave quant_model unchanged to compare the results with it + quant_model_copy = copy.deepcopy(mi.quant_model) + sharded_model = _shard_modules( + module=quant_model_copy, + # pyre-fixme[6]: For 2nd argument expected + # `Optional[List[ModuleSharder[Module]]]` but got `List[TestQuantEBCSharder]`. + sharders=[sharder], + device=device, + plan=plan, + # pyre-ignore + env=ShardingEnv.from_local(world_size=mi.topology.world_size, rank=0), + ) + return sharded_model + + +def shard_qec( + mi: TestModelInfo, + sharding_type: ShardingType, + device: torch.device, + expected_shards: Optional[List[List[Tuple[Tuple[int, int, int, int], str]]]], + plan: Optional[ShardingPlan] = None, +) -> torch.nn.Module: + sharder = TestQuantECSharder( + sharding_type=sharding_type.value, + kernel_type=EmbeddingComputeKernel.QUANT.value, + ) + + if not plan: + # pyre-ignore + plan = mi.planner.plan( + mi.quant_model, + [sharder], + ) + + if expected_shards is not None: + msp: ModuleShardingPlan = plan.plan["_module_kjt_input.0"] # TODO: hardcoded + for i in range(mi.num_features): + # pyre-ignore + ps: ParameterSharding = msp[f"table_{i}"] + assert ps.sharding_type == sharding_type.value + assert ps.sharding_spec is not None + sharding_spec: ShardingSpec = ps.sharding_spec + # pyre-ignore + assert len(sharding_spec.shards) == len(expected_shards[i]) + for shard, ((offset_r, offset_c, size_r, size_c), placement) in zip( + sharding_spec.shards, expected_shards[i] + ): + assert shard.shard_offsets == [offset_r, offset_c] + assert shard.shard_sizes == [size_r, size_c] + assert str(shard.placement) == placement + + # We want to leave quant_model unchanged to compare the results with it + quant_model_copy = copy.deepcopy(mi.quant_model) + sharded_model = _shard_modules( + module=quant_model_copy, + # pyre-fixme[6]: For 2nd argument expected + # `Optional[List[ModuleSharder[Module]]]` but got `List[TestQuantECSharder]`. 
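Editor's note: the expected_shards argument consumed by shard_qebc above is a per-table list of ((row_offset, col_offset, row_size, col_size), placement) tuples that get checked against the planner's ShardingSpec. A minimal, hedged sketch for a single 100 x 8 table placed table-wise on rank 0 -- the sizes, the placement string, and the mi object (assumed to come from create_test_model) are illustrative, and the names used are the ones already imported in this module:

expected_shards = [
    # one inner list per table; one tuple per shard of that table
    [((0, 0, 100, 8), "rank:0/cuda:0")],
]
sharded_model = shard_qebc(
    mi=mi,
    sharding_type=ShardingType.TABLE_WISE,
    device=torch.device("cuda:0"),
    expected_shards=expected_shards,
)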
+ sharders=[sharder], + device=device, + plan=plan, + # pyre-ignore + env=ShardingEnv.from_local(world_size=mi.topology.world_size, rank=0), + ) + return sharded_model + + +# pyre-ignore +def assert_close(expected, actual) -> None: + if isinstance(expected, KeyedTensor): + assert isinstance(actual, KeyedTensor) + assert len(expected.keys()) == len(actual.keys()) + torch.testing.assert_close(expected.values(), actual.values()) + torch.testing.assert_close(expected.length_per_key(), actual.length_per_key()) + elif isinstance(expected, dict): + assert sorted(expected.keys()) == sorted(actual.keys()) + for feature, jt_e in expected.items(): + jt_got = actual[feature] + if isinstance(jt_e, torch.Tensor) and isinstance(jt_got, torch.Tensor): + if jt_got.device != jt_e.device: + jt_got = actual.to(jt_e.device) + assert_close(jt_e, jt_got) + else: + assert_close(jt_e.lengths(), jt_got.lengths()) + assert_close(jt_e.values(), jt_got.values()) + assert_close(jt_e.offsets(), jt_got.offsets()) + else: + if isinstance(expected, torch.Tensor) and isinstance(actual, torch.Tensor): + if actual.device != expected.device: + actual = actual.to(expected.device) + + torch.testing.assert_close(expected, actual) + + +def assert_weight_spec( + weights_spec: Dict[str, WeightSpec], + all_expected_shards: List[List[Tuple[Tuple[int, int, int, int], str]]], + ebc_fqn: str, + weights_prefix: str, + all_table_names: List[str], + sharding_type: str, +) -> None: + tbe_table_idxs = [0, 0] + for table_name, expected_shards in zip(all_table_names, all_expected_shards): + unsharded_weight_fqn = f"{ebc_fqn}.{weights_prefix}.{table_name}.weight" + for (offset_r, offset_c, size_r, size_c), placement in expected_shards: + tbe_idx: int = 0 + # Assumption of only one TBE per rank + if "rank:1" in placement: + tbe_idx = 1 + sharded_weight_fqn: str = ( + f"{ebc_fqn}.tbes.{tbe_idx}.{tbe_table_idxs[tbe_idx]}.{table_name}.weight" + ) + tbe_table_idxs[tbe_idx] += 1 + assert sharded_weight_fqn in weights_spec + wspec = weights_spec[sharded_weight_fqn] + assert wspec.fqn == unsharded_weight_fqn + assert wspec.shard_sizes == [size_r, size_c] + assert wspec.shard_offsets == [offset_r, offset_c] + assert wspec.sharding_type == sharding_type + + for qcomp in ["qscale", "qbias"]: + sharded_weight_qcomp_fqn: str = f"{sharded_weight_fqn}_{qcomp}" + assert sharded_weight_qcomp_fqn in weights_spec + wqcomp_spec = weights_spec[sharded_weight_qcomp_fqn] + assert wqcomp_spec.fqn == f"{unsharded_weight_fqn}_{qcomp}" + assert wqcomp_spec.shard_sizes == [size_r, 2] + assert wqcomp_spec.shard_offsets == [offset_r, 0] + assert wqcomp_spec.sharding_type == sharding_type + + +class MockTBE(nn.Module): + def __init__( + self, + embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]], + device: torch.device, + output_dtype: int, + pooling_mode: PoolingMode, + ) -> None: + super(MockTBE, self).__init__() + self.embedding_specs: List[ + Tuple[str, int, int, SparseType, EmbeddingLocation] + ] = embedding_specs + self.pooling_mode = pooling_mode + self.device = device + self.output_dtype: torch.dtype = SparseType.from_int(output_dtype).as_dtype() + self.D: int = max([D for _, _, D, _, _ in embedding_specs]) + + self.weights: List[torch.Tensor] = [ + torch.arange(N).view(N, 1).expand(N, D) for _, N, D, _, _ in embedding_specs + ] + self.split_embedding_weights: List[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ] = [ + ( + torch.zeros(N, D, dtype=torch.uint8), + torch.zeros(N, 2, dtype=torch.uint8), + torch.zeros(N, 2, dtype=torch.uint8), 
+ ) + for _, N, D, _, _ in embedding_specs + ] + + def forward( + self, + indices: Tensor, + offsets: Tensor, + per_sample_weights: Optional[Tensor] = None, + ) -> Tensor: + if self.pooling_mode == PoolingMode.SUM: + return torch.ones(1, self.D, device=self.device, dtype=self.output_dtype) + + return torch.zeros( + indices.size(0), self.D, device=self.device, dtype=self.output_dtype + ) + + def split_embedding_weights_with_scale_bias( + self, split_scale_bias_mode: int = 1 + ) -> List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]]: + if split_scale_bias_mode == 2: + # pyre-ignore + return self.split_embedding_weights + raise NotImplementedError() + + +def mock_tbe_from_tbe(tbe: IntNBitTableBatchedEmbeddingBagsCodegen) -> MockTBE: + return MockTBE( + tbe.embedding_specs, + tbe.current_device, + tbe.output_dtype, + tbe.pooling_mode, + ) + + +def replace_registered_tbes_with_mock_tbes(M: torch.nn.Module, path: str = "") -> None: + for child_name, child in M.named_children(): + child_path = f"{path}.{child_name}" if path else child_name + if isinstance(child, IntNBitTableBatchedEmbeddingBagsCodegen): + M.register_module( + child_name, + mock_tbe_from_tbe(child), + ) + else: + replace_registered_tbes_with_mock_tbes(child, child_path) + + +def replace_sharded_quant_modules_tbes_with_mock_tbes(M: torch.nn.Module) -> None: + for m in M.modules(): + if isinstance(m, ShardedQuantEmbeddingBagCollection): + for lookup in m._lookups: + # pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` is + # not a function. + for lookup_per_rank in lookup._embedding_lookups_per_rank: + replace_registered_tbes_with_mock_tbes(lookup_per_rank) diff --git a/torchrec/distributed/test_utils/multi_process.py b/torchrec/distributed/test_utils/multi_process.py index 31e7ca85e..af201bfa6 100644 --- a/torchrec/distributed/test_utils/multi_process.py +++ b/torchrec/distributed/test_utils/multi_process.py @@ -5,8 +5,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
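Editor's note: for tests that only exercise module wiring rather than real embedding lookups, the MockTBE helpers above can stand in for IntNBitTableBatchedEmbeddingBagsCodegen. A hedged sketch, assuming a sharded quant model produced by the helpers in this module:

sharded_model = shard_qebc(
    mi=mi, sharding_type=ShardingType.TABLE_WISE, device=torch.device("cuda:0")
)
# Swap every registered TBE inside the sharded quant EBCs for a MockTBE, so the
# forward pass returns cheap, deterministic constant tensors.
replace_sharded_quant_modules_tbes_with_mock_tbes(sharded_model)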
+# pyre-strict + #!/usr/bin/env python3 +import logging import multiprocessing import os import unittest @@ -29,27 +32,35 @@ def __init__( world_size: int, backend: str = "gloo", local_size: Optional[int] = None, + use_deterministic_algorithms: bool = True, + disable_cuda_tf_32: bool = True, ) -> None: self.rank = rank self.world_size = world_size self.backend = backend self.local_size = local_size + self.disable_cuda_tf_32 = disable_cuda_tf_32 + + if torch.cuda.is_available() and world_size <= torch.cuda.device_count(): + self.device: torch.device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(self.device) - if backend == "nccl": - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) + if self.disable_cuda_tf_32: + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False else: - device = torch.device("cpu") - self.device: torch.device = device - torch.use_deterministic_algorithms(True) - if torch.cuda.is_available(): - torch.backends.cudnn.allow_tf32 = False - torch.backends.cuda.matmul.allow_tf32 = False + self.device: torch.device = torch.device("cpu") + + if use_deterministic_algorithms: + if torch.cuda.is_available(): + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + torch.use_deterministic_algorithms(True) + self.pg: Optional[dist.ProcessGroup] = None - # pyre-ignore - def __enter__(self): + def __enter__(self) -> "MultiProcessContext": """ Override local_size after pg construction because unit test device count is larger than local_size setup. This can be problematic for twrw because we have @@ -78,11 +89,21 @@ def __exit__(self, exc_type, exc_instance, traceback) -> None: dist.destroy_process_group(_CROSS_PG) dist.destroy_process_group(self.pg) torch.use_deterministic_algorithms(False) - if torch.cuda.is_available(): + if torch.cuda.is_available() and self.disable_cuda_tf_32: torch.backends.cudnn.allow_tf32 = True class MultiProcessTestBase(unittest.TestCase): + def __init__( + self, methodName: str = "runTest", mp_init_mode: str = "forkserver" + ) -> None: + super().__init__(methodName) + + # AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail + # Therefore we use spawn for HIP runtime until AMD fixes the issue + self._mp_init_mode: str = mp_init_mode if torch.version.hip is None else "spawn" + logging.info(f"Using {self._mp_init_mode} for multiprocessing") + @seed_and_log def setUp(self) -> None: os.environ["MASTER_ADDR"] = str("localhost") @@ -111,11 +132,11 @@ def _run_multi_process_test( ..., None, ], - world_size: int, + world_size: int = 2, # pyre-ignore **kwargs, ) -> None: - ctx = multiprocessing.get_context("forkserver") + ctx = multiprocessing.get_context(self._mp_init_mode) processes = [] for rank in range(world_size): kwargs["rank"] = rank @@ -141,7 +162,7 @@ def _run_multi_process_test_per_rank( world_size: int, kwargs_per_rank: List[Dict[str, Any]], ) -> None: - ctx = multiprocessing.get_context("forkserver") + ctx = multiprocessing.get_context(self._mp_init_mode) processes = [] for rank in range(world_size): kwargs = {} @@ -158,3 +179,50 @@ def _run_multi_process_test_per_rank( for p in processes: p.join() self.assertEqual(0, p.exitcode) + + +def run_multi_process_func( + func: Callable[ + [int, int, ...], # rank, world_size, ... 
+ None, + ], + multiprocessing_method: str = "spawn", + use_deterministic_algorithms: bool = True, + world_size: int = 2, + # pyre-ignore + **kwargs, +) -> None: + """ """ + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP" + os.environ["NCCL_SOCKET_IFNAME"] = "lo" + + torch.use_deterministic_algorithms(use_deterministic_algorithms) + if torch.cuda.is_available(): + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + if world_size == 1: + kwargs["world_size"] = 1 + kwargs["rank"] = 0 + func(**kwargs) + return + ctx = multiprocessing.get_context(multiprocessing_method) + processes = [] + for rank in range(world_size): + kwargs["rank"] = rank + kwargs["world_size"] = world_size + p = ctx.Process( + target=func, + name=f"rank{rank}", + kwargs=kwargs, + ) + p.start() + processes.append(p) + + for p in processes: + p.join() + if p.exitcode != 0: + print(p) diff --git a/torchrec/distributed/test_utils/test_input.py b/torchrec/distributed/test_utils/test_input.py new file mode 100644 index 000000000..6f5ce7ef0 --- /dev/null +++ b/torchrec/distributed/test_utils/test_input.py @@ -0,0 +1,614 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from dataclasses import dataclass +from typing import cast, List, Optional, Tuple, Union + +import torch +from tensordict import TensorDict +from torchrec.distributed.embedding_types import EmbeddingTableConfig +from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.streamable import Pipelineable + + +@dataclass +class ModelInput(Pipelineable): + """ + basic model input for a simple standard RecSys model + the input is a training data batch that contains: + 1. a tensor for dense features + 2. a KJT for unweighted sparse features + 3. a KJT for weighted sparse features + 4. 
a tensor for the label + """ + + float_features: torch.Tensor + idlist_features: Optional[KeyedJaggedTensor] + idscore_features: Optional[KeyedJaggedTensor] + label: torch.Tensor + + def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput": + return ModelInput( + float_features=self.float_features.to( + device=device, non_blocking=non_blocking + ), + idlist_features=( + self.idlist_features.to(device=device, non_blocking=non_blocking) + if self.idlist_features is not None + else None + ), + idscore_features=( + self.idscore_features.to(device=device, non_blocking=non_blocking) + if self.idscore_features is not None + else None + ), + label=self.label.to(device=device, non_blocking=non_blocking), + ) + + def record_stream(self, stream: torch.Stream) -> None: + """ + need to explicitly call `record_stream` for non-pytorch native object (KJT) + """ + self.float_features.record_stream(stream) + if isinstance(self.idlist_features, KeyedJaggedTensor): + self.idlist_features.record_stream(stream) + if isinstance(self.idscore_features, KeyedJaggedTensor): + self.idscore_features.record_stream(stream) + self.label.record_stream(stream) + + @classmethod + def generate_global_and_local_batches( + cls, + world_size: int, + batch_size: int = 1, + tables: Optional[ + Union[ + List[EmbeddingTableConfig], + List[EmbeddingBagConfig], + List[EmbeddingConfig], + ] + ] = None, + weighted_tables: Optional[ + Union[ + List[EmbeddingTableConfig], + List[EmbeddingBagConfig], + List[EmbeddingConfig], + ] + ] = None, + num_float_features: int = 16, + pooling_avg: int = 10, + tables_pooling: Optional[List[int]] = None, + max_feature_lengths: Optional[List[int]] = None, + use_offsets: bool = False, + device: Optional[torch.device] = None, + indices_dtype: torch.dtype = torch.int64, + offsets_dtype: torch.dtype = torch.int64, + lengths_dtype: torch.dtype = torch.int64, + all_zeros: bool = False, + ) -> Tuple["ModelInput", List["ModelInput"]]: + """ + Returns a global (single-rank training) batch, and a list of local + (multi-rank training) batches of world_size. The data should be + consistent between the local batches and the global batch so that + they can be used for comparison and validation. 
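Editor's note: a hedged usage sketch of the global/local relationship described above; the single EmbeddingBagConfig is an illustrative assumption and the imports are those already present in this module:

tables = [
    EmbeddingBagConfig(
        num_embeddings=100, embedding_dim=8, name="table_0", feature_names=["feature_0"]
    )
]
global_batch, local_batches = ModelInput.generate_global_and_local_batches(
    world_size=2, batch_size=4, tables=tables, num_float_features=8
)
# the global dense batch is the concatenation of the per-rank dense batches
assert global_batch.float_features.shape[0] == sum(
    b.float_features.shape[0] for b in local_batches
)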
+ """ + + float_features_list = [ + ( + torch.zeros((batch_size, num_float_features), device=device) + if all_zeros + else torch.rand((batch_size, num_float_features), device=device) + ) + for _ in range(world_size) + ] + global_idlist_features, idlist_features_list = ( + ModelInput._create_batched_standard_kjts( + batch_size, + world_size, + tables, + pooling_avg, + tables_pooling, + False, # unweighted + max_feature_lengths, + use_offsets, + device, + indices_dtype, + offsets_dtype, + lengths_dtype, + all_zeros, + ) + if tables is not None and len(tables) > 0 + else (None, [None for _ in range(world_size)]) + ) + global_idscore_features, idscore_features_list = ( + ModelInput._create_batched_standard_kjts( + batch_size, + world_size, + weighted_tables, + pooling_avg, + tables_pooling, + True, # weighted + max_feature_lengths, + use_offsets, + device, + indices_dtype, + offsets_dtype, + lengths_dtype, + all_zeros, + ) + if weighted_tables is not None and len(weighted_tables) > 0 + else (None, [None for _ in range(world_size)]) + ) + label_list = [ + ( + torch.zeros((batch_size,), device=device) + if all_zeros + else torch.rand((batch_size,), device=device) + ) + for _ in range(world_size) + ] + global_input = ModelInput( + float_features=torch.cat(float_features_list), + idlist_features=global_idlist_features, + idscore_features=global_idscore_features, + label=torch.cat(label_list), + ) + local_inputs = [ + ModelInput( + float_features=float_features, + idlist_features=idlist_features, + idscore_features=idscore_features, + label=label, + ) + for float_features, idlist_features, idscore_features, label in zip( + float_features_list, + idlist_features_list, + idscore_features_list, + label_list, + ) + ] + return global_input, local_inputs + + @classmethod + def generate_local_batches( + cls, + world_size: int, + batch_size: int = 1, + tables: Optional[ + Union[ + List[EmbeddingTableConfig], + List[EmbeddingBagConfig], + List[EmbeddingConfig], + ] + ] = None, + weighted_tables: Optional[ + Union[ + List[EmbeddingTableConfig], + List[EmbeddingBagConfig], + List[EmbeddingConfig], + ] + ] = None, + num_float_features: int = 16, + pooling_avg: int = 10, + tables_pooling: Optional[List[int]] = None, + max_feature_lengths: Optional[List[int]] = None, + use_offsets: bool = False, + device: Optional[torch.device] = None, + indices_dtype: torch.dtype = torch.int64, + offsets_dtype: torch.dtype = torch.int64, + lengths_dtype: torch.dtype = torch.int64, + all_zeros: bool = False, + pin_memory: bool = False, # pin_memory is needed for training job qps benchmark + ) -> List["ModelInput"]: + """ + Returns multi-rank batches (ModelInput) of world_size + """ + return [ + cls.generate( + batch_size=batch_size, + tables=tables, + weighted_tables=weighted_tables, + num_float_features=num_float_features, + pooling_avg=pooling_avg, + tables_pooling=tables_pooling, + max_feature_lengths=max_feature_lengths, + use_offsets=use_offsets, + device=device, + indices_dtype=indices_dtype, + offsets_dtype=offsets_dtype, + lengths_dtype=lengths_dtype, + all_zeros=all_zeros, + pin_memory=pin_memory, + ) + for _ in range(world_size) + ] + + @classmethod + def generate( + cls, + batch_size: int = 1, + tables: Optional[ + Union[ + List[EmbeddingTableConfig], + List[EmbeddingBagConfig], + List[EmbeddingConfig], + ] + ] = None, + weighted_tables: Optional[ + Union[ + List[EmbeddingTableConfig], + List[EmbeddingBagConfig], + List[EmbeddingConfig], + ] + ] = None, + num_float_features: int = 16, + pooling_avg: int = 10, + 
tables_pooling: Optional[List[int]] = None, + max_feature_lengths: Optional[List[int]] = None, + use_offsets: bool = False, + device: Optional[torch.device] = None, + indices_dtype: torch.dtype = torch.int64, + offsets_dtype: torch.dtype = torch.int64, + lengths_dtype: torch.dtype = torch.int64, + all_zeros: bool = False, + pin_memory: bool = False, # pin_memory is needed for training job qps benchmark + ) -> "ModelInput": + """ + Returns a single batch of `ModelInput` + + The `pin_memory()` call for all KJT tensors are important for training benchmark, and + also valid argument for the prod training scenario: TrainModelInput should be created + on pinned memory for a fast transfer to gpu. For more on pin_memory: + https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html#pin-memory + """ + float_features = ( + torch.zeros((batch_size, num_float_features), device=device) + if all_zeros + else torch.rand((batch_size, num_float_features), device=device) + ) + idlist_features = ( + ModelInput.create_standard_kjt( + batch_size=batch_size, + tables=tables, + pooling_avg=pooling_avg, + tables_pooling=tables_pooling, + weighted=False, # unweighted + max_feature_lengths=max_feature_lengths, + use_offsets=use_offsets, + device=device, + indices_dtype=indices_dtype, + offsets_dtype=offsets_dtype, + lengths_dtype=lengths_dtype, + all_zeros=all_zeros, + pin_memory=pin_memory, + ) + if tables is not None and len(tables) > 0 + else None + ) + idscore_features = ( + ModelInput.create_standard_kjt( + batch_size=batch_size, + tables=weighted_tables, + pooling_avg=pooling_avg, + tables_pooling=tables_pooling, + weighted=False, # weighted + max_feature_lengths=max_feature_lengths, + use_offsets=use_offsets, + device=device, + indices_dtype=indices_dtype, + offsets_dtype=offsets_dtype, + lengths_dtype=lengths_dtype, + all_zeros=all_zeros, + pin_memory=pin_memory, + ) + if weighted_tables is not None and len(weighted_tables) > 0 + else None + ) + label = ( + torch.zeros((batch_size,), device=device) + if all_zeros + else torch.rand((batch_size,), device=device) + ) + if pin_memory: + float_features = float_features.pin_memory() + label = label.pin_memory() + return ModelInput( + float_features=float_features, + idlist_features=idlist_features, + idscore_features=idscore_features, + label=label, + ) + + @staticmethod + def _create_features_lengths_indices( + batch_size: int, + tables: Union[ + List[EmbeddingTableConfig], List[EmbeddingBagConfig], List[EmbeddingConfig] + ], + pooling_avg: int = 10, + tables_pooling: Optional[List[int]] = None, + max_feature_lengths: Optional[List[int]] = None, + device: Optional[torch.device] = None, + indices_dtype: torch.dtype = torch.int64, + lengths_dtype: torch.dtype = torch.int64, + all_zeros: bool = False, + ) -> Tuple[List[str], List[torch.Tensor], List[torch.Tensor]]: + """ + Create keys, lengths, and indices for a KeyedJaggedTensor from embedding table configs. + + Returns: + Tuple[List[str], List[torch.Tensor], List[torch.Tensor]]: + Feature names, per-feature lengths, and per-feature indices. 
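Editor's note: a hedged sketch of the invariants this helper provides; the single-table config is an illustrative assumption:

features, lengths_per_feature, indices_per_feature = (
    ModelInput._create_features_lengths_indices(
        batch_size=4,
        tables=[
            EmbeddingBagConfig(
                num_embeddings=100,
                embedding_dim=8,
                name="table_0",
                feature_names=["feature_0"],
            )
        ],
    )
)
# one lengths tensor and one indices tensor per feature name
assert len(features) == len(lengths_per_feature) == len(indices_per_feature)
# each feature's indices tensor holds exactly sum(lengths) ids in [0, num_embeddings)
assert int(lengths_per_feature[0].sum()) == indices_per_feature[0].numel()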
+ """ + pooling_factor_per_feature: List[int] = [] + num_embeddings_per_feature: List[int] = [] + max_length_per_feature: List[Optional[int]] = [] + features: List[str] = [] + for tid, table in enumerate(tables): + pooling_factor = ( + tables_pooling[tid] if tables_pooling is not None else pooling_avg + ) + max_feature_length = ( + max_feature_lengths[tid] if max_feature_lengths is not None else None + ) + features.extend(table.feature_names) + for _ in table.feature_names: + pooling_factor_per_feature.append(pooling_factor) + num_embeddings_per_feature.append( + table.num_embeddings_post_pruning or table.num_embeddings + ) + max_length_per_feature.append(max_feature_length) + + lengths_per_feature: List[torch.Tensor] = [] + indices_per_feature: List[torch.Tensor] = [] + + for pooling_factor, num_embeddings, max_length in zip( + pooling_factor_per_feature, + num_embeddings_per_feature, + max_length_per_feature, + ): + # lengths + _lengths = torch.max( + torch.normal( + pooling_factor, + pooling_factor / 10, # std + [batch_size], + device=device, + ), + torch.tensor(1.0, device=device), + ).to(lengths_dtype) + if max_length: + _lengths = torch.clamp(_lengths, max=max_length) + lengths_per_feature.append(_lengths) + + # indices + num_indices = cast(int, torch.sum(_lengths).item()) + _indices = ( + torch.zeros( + (num_indices,), + dtype=indices_dtype, + device=device, + ) + if all_zeros + else torch.randint( + 0, + num_embeddings, + (num_indices,), + dtype=indices_dtype, + device=device, + ) + ) + indices_per_feature.append(_indices) + return features, lengths_per_feature, indices_per_feature + + @staticmethod + def _assemble_kjt( + features: List[str], + lengths_per_feature: List[torch.Tensor], + indices_per_feature: List[torch.Tensor], + weighted: bool = False, + device: Optional[torch.device] = None, + use_offsets: bool = False, + offsets_dtype: torch.dtype = torch.int64, + pin_memory: bool = False, + ) -> KeyedJaggedTensor: + """ + Assembles a KeyedJaggedTensor (KJT) from the provided per-feature lengths and indices. + + This method is used to generate corresponding local_batches and global_batch KJTs. + It concatenates the lengths and indices for each feature to form a complete KJT. + + The `pin_memory()` call for all KJT tensors are important for training benchmark, and + also valid argument for the prod training scenario: TrainModelInput should be created + on pinned memory for a fast transfer to gpu. 
For more on pin_memory: + https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html#pin-memory + """ + + lengths = torch.cat(lengths_per_feature) + indices = torch.cat(indices_per_feature) + offsets = None + weights = torch.rand((indices.numel(),), device=device) if weighted else None + if use_offsets: + offsets = torch.cat( + [torch.tensor([0], device=device), lengths.cumsum(0)] + ).to(offsets_dtype) + lengths = None + if pin_memory: + indices = indices.pin_memory() + lengths = lengths.pin_memory() if lengths else None + weights = weights.pin_memory() if weights else None + offsets = offsets.pin_memory() if offsets else None + return KeyedJaggedTensor(features, indices, weights, lengths, offsets) + + @staticmethod + def create_standard_kjt( + batch_size: int, + tables: Union[ + List[EmbeddingTableConfig], List[EmbeddingBagConfig], List[EmbeddingConfig] + ], + pooling_avg: int = 10, + tables_pooling: Optional[List[int]] = None, + weighted: bool = False, + max_feature_lengths: Optional[List[int]] = None, + use_offsets: bool = False, + device: Optional[torch.device] = None, + indices_dtype: torch.dtype = torch.int64, + offsets_dtype: torch.dtype = torch.int64, + lengths_dtype: torch.dtype = torch.int64, + all_zeros: bool = False, + pin_memory: bool = False, + ) -> KeyedJaggedTensor: + features, lengths_per_feature, indices_per_feature = ( + ModelInput._create_features_lengths_indices( + batch_size=batch_size, + tables=tables, + pooling_avg=pooling_avg, + tables_pooling=tables_pooling, + max_feature_lengths=max_feature_lengths, + device=device, + indices_dtype=indices_dtype, + lengths_dtype=lengths_dtype, + all_zeros=all_zeros, + ) + ) + return ModelInput._assemble_kjt( + features=features, + lengths_per_feature=lengths_per_feature, + indices_per_feature=indices_per_feature, + weighted=weighted, + device=device, + use_offsets=use_offsets, + offsets_dtype=offsets_dtype, + pin_memory=pin_memory, + ) + + @staticmethod + def _create_batched_standard_kjts( + batch_size: int, + world_size: int, + tables: Union[ + List[EmbeddingTableConfig], List[EmbeddingBagConfig], List[EmbeddingConfig] + ], + pooling_avg: int = 10, + tables_pooling: Optional[List[int]] = None, + weighted: bool = False, + max_feature_lengths: Optional[List[int]] = None, + use_offsets: bool = False, + device: Optional[torch.device] = None, + indices_dtype: torch.dtype = torch.int64, + offsets_dtype: torch.dtype = torch.int64, + lengths_dtype: torch.dtype = torch.int64, + all_zeros: bool = False, + ) -> Tuple[KeyedJaggedTensor, List[KeyedJaggedTensor]]: + """ + generate a global KJT and corresponding per-rank KJTs, the data are the same + so that they can be used for result comparison. 
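Editor's note: a hedged sketch of the global/local KJT consistency this helper provides, with tables as in the earlier sketches:

global_kjt, local_kjts = ModelInput._create_batched_standard_kjts(
    batch_size=2, world_size=2, tables=tables
)
assert len(local_kjts) == 2
assert global_kjt.keys() == local_kjts[0].keys()
# the global KJT carries exactly the union of the per-rank values
assert global_kjt.values().numel() == sum(kjt.values().numel() for kjt in local_kjts)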
+ """ + data_per_rank = [ + ModelInput._create_features_lengths_indices( + batch_size, + tables, + pooling_avg, + tables_pooling, + max_feature_lengths, + device, + indices_dtype, + lengths_dtype, + all_zeros, + ) + for _ in range(world_size) + ] + features = data_per_rank[0][0] + local_kjts = [ + ModelInput._assemble_kjt( + features, + lengths_per_feature, + indices_per_feature, + weighted, + device, + use_offsets, + offsets_dtype, + ) + for _, lengths_per_feature, indices_per_feature in data_per_rank + ] + global_lengths = [ + data_per_rank[r][1][f] + for f in range(len(features)) + for r in range(world_size) + ] + global_indices = [ + data_per_rank[r][2][f] + for f in range(len(features)) + for r in range(world_size) + ] + global_kjt = ModelInput._assemble_kjt( + features, + global_lengths, + global_indices, + weighted, + device, + use_offsets, + offsets_dtype, + ) + return global_kjt, local_kjts + + +# @dataclass +# class VbModelInput(ModelInput): +# pass + +# @staticmethod +# def _create_variable_batch_kjt() -> KeyedJaggedTensor: +# pass + +# @staticmethod +# def _merge_variable_batch_kjts(kjts: List[KeyedJaggedTensor]) -> KeyedJaggedTensor: +# pass + + +@dataclass +class TdModelInput(ModelInput): + idlist_features: TensorDict # pyre-ignore + + +@dataclass +class TestSparseNNInputConfig: + batch_size: int = 8192 + num_float_features: int = 10 + feature_pooling_avg: int = 10 + use_offsets: bool = False + dev_str: str = "" + long_kjt_indices: bool = True + long_kjt_offsets: bool = True + long_kjt_lengths: bool = True + pin_memory: bool = True + + def generate_model_input( + self, + tables: Union[ + List[EmbeddingTableConfig], List[EmbeddingBagConfig], List[EmbeddingConfig] + ], + weighted_tables: Union[ + List[EmbeddingTableConfig], List[EmbeddingBagConfig], List[EmbeddingConfig] + ], + ) -> ModelInput: + return ModelInput.generate( + batch_size=self.batch_size, + tables=tables, + weighted_tables=weighted_tables, + num_float_features=self.num_float_features, + pooling_avg=self.feature_pooling_avg, + use_offsets=self.use_offsets, + device=torch.device(self.dev_str) if self.dev_str else None, + indices_dtype=torch.int64 if self.long_kjt_indices else torch.int32, + offsets_dtype=torch.int64 if self.long_kjt_offsets else torch.int32, + lengths_dtype=torch.int64 if self.long_kjt_lengths else torch.int32, + pin_memory=self.pin_memory, + ) diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index c91131724..e06821e47 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -5,11 +5,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
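Editor's note: TestSparseNNInputConfig above is a thin convenience wrapper over ModelInput.generate. A hedged usage sketch, where tables and weighted_tables are assumed lists of EmbeddingBagConfig as in the earlier sketches, and pin_memory is disabled to keep the example CPU-only:

config = TestSparseNNInputConfig(batch_size=128, pin_memory=False)
batch = config.generate_model_input(tables=tables, weighted_tables=weighted_tables)
assert batch.float_features.shape == (128, config.num_float_features)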
+# pyre-strict + +import copy +import random from dataclasses import dataclass -from typing import Any, cast, Dict, List, Optional, Tuple, Union +from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn +from tensordict import TensorDict from torchrec.distributed.embedding_tower_sharding import ( EmbeddingTowerCollectionSharder, EmbeddingTowerSharder, @@ -21,8 +26,20 @@ ) from torchrec.distributed.fused_embedding import FusedEmbeddingCollectionSharder from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder -from torchrec.distributed.types import QuantizedCommCodecs -from torchrec.inference.modules import CopyableMixin +from torchrec.distributed.mc_embedding_modules import ( + BaseManagedCollisionEmbeddingCollectionSharder, +) +from torchrec.distributed.mc_embeddingbag import ( + ShardedManagedCollisionEmbeddingBagCollection, +) +from torchrec.distributed.mc_modules import ManagedCollisionCollectionSharder +from torchrec.distributed.types import ( + ParameterSharding, + QuantizedCommCodecs, + ShardingEnv, +) +from torchrec.distributed.utils import CopyableMixin +from torchrec.modules.activation import SwishLayerNorm from torchrec.modules.embedding_configs import ( BaseEmbeddingConfig, EmbeddingBagConfig, @@ -31,15 +48,24 @@ from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.modules.embedding_tower import EmbeddingTower, EmbeddingTowerCollection from torchrec.modules.feature_processor import PositionWeightedProcessor -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor +from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection +from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection +from torchrec.modules.mc_modules import ( + DistanceLFU_EvictionPolicy, + ManagedCollisionCollection, + MCHManagedCollisionModule, +) +from torchrec.modules.regroup import KTRegroupAsDict +from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor from torchrec.streamable import Pipelineable @dataclass class ModelInput(Pipelineable): float_features: torch.Tensor - idlist_features: KeyedJaggedTensor - idscore_features: KeyedJaggedTensor + idlist_features: Union[KeyedJaggedTensor, TensorDict] + idscore_features: Optional[Union[KeyedJaggedTensor, TensorDict]] label: torch.Tensor @staticmethod @@ -62,6 +88,16 @@ def generate( ] ] = None, variable_batch_size: bool = False, + tables_pooling: Optional[List[int]] = None, + weighted_tables_pooling: Optional[List[int]] = None, + randomize_indices: bool = True, + device: Optional[torch.device] = None, + max_feature_lengths: Optional[List[int]] = None, + input_type: str = "kjt", + use_offsets: bool = False, + indices_dtype: torch.dtype = torch.int64, + offsets_dtype: torch.dtype = torch.int64, + lengths_dtype: torch.dtype = torch.int64, ) -> Tuple["ModelInput", List["ModelInput"]]: """ Returns a global (single-rank training) batch @@ -74,10 +110,40 @@ def generate( for r in range(world_size) ] + def _validate_pooling_factor( + tables: Union[ + List[EmbeddingTableConfig], + List[EmbeddingBagConfig], + List[EmbeddingConfig], + ], + pooling_factor: Optional[List[int]], + ) -> None: + if pooling_factor and len(pooling_factor) != len(tables): + raise ValueError( + "tables_pooling and tables must have the same length. " + f"Got {len(pooling_factor)} and {len(tables)}." 
+ ) + + _validate_pooling_factor(tables, tables_pooling) + _validate_pooling_factor(weighted_tables, weighted_tables_pooling) + idlist_features_to_num_embeddings = {} - for table in tables: - for feature in table.feature_names: - idlist_features_to_num_embeddings[feature] = table.num_embeddings + idlist_features_to_pooling_factor = {} + idlist_features_to_max_length = {} + feature_idx = 0 + for idx in range(len(tables)): + for feature in tables[idx].feature_names: + idlist_features_to_num_embeddings[feature] = ( + tables[idx].num_embeddings_post_pruning + if tables[idx].num_embeddings_post_pruning is not None + else tables[idx].num_embeddings + ) + idlist_features_to_max_length[feature] = ( + max_feature_lengths[feature_idx] if max_feature_lengths else None + ) + if tables_pooling is not None: + idlist_features_to_pooling_factor[feature] = tables_pooling[idx] + feature_idx += 1 idlist_features = list(idlist_features_to_num_embeddings.keys()) idscore_features = [ @@ -85,83 +151,207 @@ def generate( ] idlist_ind_ranges = list(idlist_features_to_num_embeddings.values()) - idscore_ind_ranges = [table.num_embeddings for table in weighted_tables] + idscore_ind_ranges = [ + ( + table.num_embeddings_post_pruning + if table.num_embeddings_post_pruning is not None + else table.num_embeddings + ) + for table in weighted_tables + ] + + idlist_pooling_factor = list(idlist_features_to_pooling_factor.values()) + idscore_pooling_factor = weighted_tables_pooling + idlist_max_lengths = list(idlist_features_to_max_length.values()) # Generate global batch. global_idlist_lengths = [] global_idlist_indices = [] + global_idlist_offsets = [] + global_idscore_lengths = [] global_idscore_indices = [] + global_idscore_offsets = [] global_idscore_weights = [] - for ind_range in idlist_ind_ranges: - lengths_ = torch.abs( - torch.randn(batch_size * world_size) + pooling_avg - ).int() + for idx in range(len(idlist_ind_ranges)): + ind_range = idlist_ind_ranges[idx] + + if idlist_pooling_factor: + lengths_ = torch.max( + torch.normal( + idlist_pooling_factor[idx], + idlist_pooling_factor[idx] / 10, + [batch_size * world_size], + device=device, + ), + torch.tensor(1.0, device=device), + ).to(lengths_dtype) + else: + lengths_ = torch.abs( + torch.randn(batch_size * world_size, device=device) + pooling_avg, + ).to(lengths_dtype) + + if idlist_max_lengths[idx]: + lengths_ = torch.clamp(lengths_, max=idlist_max_lengths[idx]) + if variable_batch_size: - lengths = torch.zeros(batch_size * world_size).int() + lengths = torch.zeros(batch_size * world_size, device=device).to( + lengths_dtype + ) for r in range(world_size): - lengths[ - r * batch_size : r * batch_size + batch_size_by_rank[r] - ] = lengths_[ - r * batch_size : r * batch_size + batch_size_by_rank[r] - ] + lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]] = ( + lengths_[ + r * batch_size : r * batch_size + batch_size_by_rank[r] + ] + ) else: lengths = lengths_ + num_indices = cast(int, torch.sum(lengths).item()) - indices = torch.randint(0, ind_range, (num_indices,)) + + if randomize_indices: + indices = torch.randint( + 0, + ind_range, + (num_indices,), + dtype=indices_dtype, + device=device, + ) + else: + indices = torch.zeros( + (num_indices,), + dtype=indices_dtype, + device=device, + ) + + # Calculate offsets from lengths + offsets = torch.cat( + [torch.tensor([0], device=device), lengths.cumsum(0)] + ).to(offsets_dtype) + global_idlist_lengths.append(lengths) global_idlist_indices.append(indices) - global_idlist_kjt = KeyedJaggedTensor( - 
keys=idlist_features, - values=torch.cat(global_idlist_indices), - lengths=torch.cat(global_idlist_lengths), - ) + global_idlist_offsets.append(offsets) - for ind_range in idscore_ind_ranges: + for idx, ind_range in enumerate(idscore_ind_ranges): lengths_ = torch.abs( - torch.randn(batch_size * world_size) + pooling_avg - ).int() + torch.randn(batch_size * world_size, device=device) + + ( + idscore_pooling_factor[idx] + if idscore_pooling_factor + else pooling_avg + ) + ).to(lengths_dtype) + if variable_batch_size: - lengths = torch.zeros(batch_size * world_size).int() + lengths = torch.zeros(batch_size * world_size, device=device).to( + lengths_dtype + ) for r in range(world_size): - lengths[ - r * batch_size : r * batch_size + batch_size_by_rank[r] - ] = lengths_[ - r * batch_size : r * batch_size + batch_size_by_rank[r] - ] + lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]] = ( + lengths_[ + r * batch_size : r * batch_size + batch_size_by_rank[r] + ] + ) else: lengths = lengths_ + num_indices = cast(int, torch.sum(lengths).item()) - indices = torch.randint(0, ind_range, (num_indices,)) - weights = torch.rand((num_indices,)) + + if randomize_indices: + indices = torch.randint( + 0, + # pyre-ignore [6] + ind_range, + (num_indices,), + dtype=indices_dtype, + device=device, + ) + else: + indices = torch.zeros( + (num_indices,), + dtype=indices_dtype, + device=device, + ) + weights = torch.rand((num_indices,), device=device) + # Calculate offsets from lengths + offsets = torch.cat( + [torch.tensor([0], device=device), lengths.cumsum(0)] + ).to(offsets_dtype) + global_idscore_lengths.append(lengths) global_idscore_indices.append(indices) global_idscore_weights.append(weights) - global_idscore_kjt = ( - KeyedJaggedTensor( - keys=idscore_features, - values=torch.cat(global_idscore_indices), - lengths=torch.cat(global_idscore_lengths), - weights=torch.cat(global_idscore_weights), - ) - if global_idscore_indices - else None - ) + global_idscore_offsets.append(offsets) + + if input_type == "kjt": + global_idlist_input = KeyedJaggedTensor( + keys=idlist_features, + values=torch.cat(global_idlist_indices), + offsets=torch.cat(global_idlist_offsets) if use_offsets else None, + lengths=torch.cat(global_idlist_lengths) if not use_offsets else None, + ) + + global_idscore_input = ( + KeyedJaggedTensor( + keys=idscore_features, + values=torch.cat(global_idscore_indices), + offsets=torch.cat(global_idscore_offsets) if use_offsets else None, + lengths=( + torch.cat(global_idscore_lengths) if not use_offsets else None + ), + weights=torch.cat(global_idscore_weights), + ) + if global_idscore_indices + else None + ) + elif input_type == "td": + dict_of_nt = { + k: torch.nested.nested_tensor_from_jagged( + values=values, + lengths=lengths, + ) + for k, values, lengths in zip( + idlist_features, global_idlist_indices, global_idlist_lengths + ) + } + global_idlist_input = TensorDict(source=dict_of_nt) + + assert ( + len(idscore_features) == 0 + ), "TensorDict does not support weighted features" + global_idscore_input = None + else: + raise ValueError(f"For weighted features, unknown input type {input_type}") - global_float = torch.rand((batch_size * world_size, num_float_features)) - global_label = torch.rand(batch_size * world_size) + if randomize_indices: + global_float = torch.rand( + (batch_size * world_size, num_float_features), device=device + ) + global_label = torch.rand(batch_size * world_size, device=device) + else: + global_float = torch.zeros( + (batch_size * world_size, 
num_float_features), device=device + ) + global_label = torch.zeros(batch_size * world_size, device=device) # Split global batch into local batches. local_inputs = [] + for r in range(world_size): local_idlist_lengths = [] local_idlist_indices = [] + local_idlist_offsets = [] + local_idscore_lengths = [] local_idscore_indices = [] local_idscore_weights = [] + local_idscore_offsets = [] - for lengths, indices in zip(global_idlist_lengths, global_idlist_indices): + for lengths, indices, offsets in zip( + global_idlist_lengths, global_idlist_indices, global_idlist_offsets + ): local_idlist_lengths.append( lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]] ) @@ -171,9 +361,15 @@ def generate( local_idlist_indices.append( indices[lengths_cumsum[r] : lengths_cumsum[r + 1]] ) + local_idlist_offsets.append( + offsets[r * batch_size : r * batch_size + batch_size_by_rank[r] + 1] + ) - for lengths, indices, weights in zip( - global_idscore_lengths, global_idscore_indices, global_idscore_weights + for lengths, indices, weights, offsets in zip( + global_idscore_lengths, + global_idscore_indices, + global_idscore_weights, + global_idscore_offsets, ): local_idscore_lengths.append( lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]] @@ -188,29 +384,63 @@ def generate( weights[lengths_cumsum[r] : lengths_cumsum[r + 1]] ) - local_idlist_kjt = KeyedJaggedTensor( - keys=idlist_features, - values=torch.cat(local_idlist_indices), - lengths=torch.cat(local_idlist_lengths), - ) + local_idscore_offsets.append( + offsets[r * batch_size : r * batch_size + batch_size_by_rank[r] + 1] + ) - local_idscore_kjt = ( - KeyedJaggedTensor( - keys=idscore_features, - values=torch.cat(local_idscore_indices), - lengths=torch.cat(local_idscore_lengths), - weights=torch.cat(local_idscore_weights), + if input_type == "kjt": + local_idlist_input = KeyedJaggedTensor( + keys=idlist_features, + values=torch.cat(local_idlist_indices), + offsets=torch.cat(local_idlist_offsets) if use_offsets else None, + lengths=( + torch.cat(local_idlist_lengths) if not use_offsets else None + ), + ) + + local_idscore_input = ( + KeyedJaggedTensor( + keys=idscore_features, + values=torch.cat(local_idscore_indices), + offsets=( + torch.cat(local_idscore_offsets) if use_offsets else None + ), + lengths=( + torch.cat(local_idscore_lengths) + if not use_offsets + else None + ), + weights=torch.cat(local_idscore_weights), + ) + if local_idscore_indices + else None + ) + elif input_type == "td": + dict_of_nt = { + k: torch.nested.nested_tensor_from_jagged( + values=values, + lengths=lengths, + ) + for k, values, lengths in zip( + idlist_features, + local_idlist_indices, + local_idlist_lengths, + ) + } + local_idlist_input = TensorDict(source=dict_of_nt) + assert ( + len(idscore_features) == 0 + ), "TensorDict does not support weighted features" + local_idscore_input = None + else: + raise ValueError( + f"For weighted features, unknown input type {input_type}" ) - if local_idscore_indices - else None - ) local_input = ModelInput( float_features=global_float[r * batch_size : (r + 1) * batch_size], - idlist_features=local_idlist_kjt, - # pyre-fixme[6]: For 3rd param expected `KeyedJaggedTensor` but got - # `Optional[KeyedJaggedTensor]`. 
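Editor's note: a small numeric sketch of the per-rank split arithmetic used above, with world_size=2 and batch_size=2; the tensors are illustrative and not taken from the patch:

lengths = torch.tensor([2, 1, 3, 2])  # per-sample lengths for one feature, 2 ranks x 2 samples
values = torch.arange(8)              # lengths.sum() == 8 ids laid out contiguously
lengths_cumsum = [0, int(lengths[:2].sum()), int(lengths.sum())]  # [0, 3, 8]
rank0_values = values[lengths_cumsum[0] : lengths_cumsum[1]]  # first 3 ids go to rank 0
rank1_values = values[lengths_cumsum[1] : lengths_cumsum[2]]  # remaining 5 ids go to rank 1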
- idscore_features=local_idscore_kjt, + idlist_features=local_idlist_input, + idscore_features=local_idscore_input, label=global_label[r * batch_size : (r + 1) * batch_size], ) local_inputs.append(local_input) @@ -218,15 +448,377 @@ def generate( return ( ModelInput( float_features=global_float, - idlist_features=global_idlist_kjt, - # pyre-fixme[6]: For 3rd param expected `KeyedJaggedTensor` but got - # `Optional[KeyedJaggedTensor]`. - idscore_features=global_idscore_kjt, + idlist_features=global_idlist_input, + idscore_features=global_idscore_input, label=global_label, ), local_inputs, ) + @staticmethod + def _generate_variable_batch_local_features( + feature_num_embeddings: Dict[str, int], + average_batch_size: int, + world_size: int, + dedup_factor: int, + values_per_rank_per_feature: Dict[int, Dict[str, torch.Tensor]], + lengths_per_rank_per_feature: Dict[int, Dict[str, torch.Tensor]], + strides_per_rank_per_feature: Dict[int, Dict[str, int]], + inverse_indices_per_rank_per_feature: Dict[int, Dict[str, torch.Tensor]], + weights_per_rank_per_feature: Optional[Dict[int, Dict[str, torch.Tensor]]], + use_offsets: bool, + indices_dtype: torch.dtype, + offsets_dtype: torch.dtype, + lengths_dtype: torch.dtype, + ) -> List[KeyedJaggedTensor]: + local_kjts = [] + keys = list(feature_num_embeddings.keys()) + + for rank in range(world_size): + lengths_per_rank_per_feature[rank] = {} + values_per_rank_per_feature[rank] = {} + strides_per_rank_per_feature[rank] = {} + inverse_indices_per_rank_per_feature[rank] = {} + + if weights_per_rank_per_feature is not None: + weights_per_rank_per_feature[rank] = {} + + for key, num_embeddings in feature_num_embeddings.items(): + batch_size = random.randint(1, average_batch_size * dedup_factor - 1) + lengths = torch.randint( + low=0, high=5, size=(batch_size,), dtype=lengths_dtype + ) + lengths_per_rank_per_feature[rank][key] = lengths + lengths_sum = sum(lengths.tolist()) + values = torch.randint( + 0, num_embeddings, (lengths_sum,), dtype=indices_dtype + ) + values_per_rank_per_feature[rank][key] = values + if weights_per_rank_per_feature is not None: + weights_per_rank_per_feature[rank][key] = torch.rand(lengths_sum) + strides_per_rank_per_feature[rank][key] = batch_size + inverse_indices_per_rank_per_feature[rank][key] = torch.randint( + 0, + batch_size, + (dedup_factor * average_batch_size,), + dtype=indices_dtype, + ) + + values = torch.cat(list(values_per_rank_per_feature[rank].values())) + lengths = torch.cat(list(lengths_per_rank_per_feature[rank].values())) + weights = ( + torch.cat(list(weights_per_rank_per_feature[rank].values())) + if weights_per_rank_per_feature is not None + else None + ) + + if use_offsets: + offsets = torch.cat( + [torch.tensor([0], dtype=offsets_dtype), lengths.cumsum(0)] + ) + local_kjts.append( + KeyedJaggedTensor( + keys=keys, + values=values, + offsets=offsets, + weights=weights, + ) + ) + else: + stride_per_key_per_rank = [ + [stride] for stride in strides_per_rank_per_feature[rank].values() + ] + inverse_indices = ( + keys, + torch.stack( + list(inverse_indices_per_rank_per_feature[rank].values()) + ), + ) + local_kjts.append( + KeyedJaggedTensor( + keys=keys, + values=values, + lengths=lengths, + weights=weights, + stride_per_key_per_rank=stride_per_key_per_rank, + inverse_indices=inverse_indices, + ) + ) + + return local_kjts + + @staticmethod + def _generate_variable_batch_global_features( + keys: List[str], + world_size: int, + global_constant_batch: bool, + values_per_rank_per_feature: Dict[int, Dict[str, 
torch.Tensor]], + lengths_per_rank_per_feature: Dict[int, Dict[str, torch.Tensor]], + strides_per_rank_per_feature: Dict[int, Dict[str, int]], + inverse_indices_per_rank_per_feature: Dict[int, Dict[str, torch.Tensor]], + weights_per_rank_per_feature: Optional[Dict[int, Dict[str, torch.Tensor]]], + use_offsets: bool, + indices_dtype: torch.dtype, + offsets_dtype: torch.dtype, + lengths_dtype: torch.dtype, + ) -> KeyedJaggedTensor: + global_values = [] + global_lengths = [] + global_stride_per_key_per_rank = [] + inverse_indices_per_feature_per_rank = [] + global_weights = [] if weights_per_rank_per_feature is not None else None + + for key in keys: + sum_stride = 0 + for rank in range(world_size): + global_values.append(values_per_rank_per_feature[rank][key]) + global_lengths.append(lengths_per_rank_per_feature[rank][key]) + if weights_per_rank_per_feature is not None: + assert global_weights is not None + global_weights.append(weights_per_rank_per_feature[rank][key]) + sum_stride += strides_per_rank_per_feature[rank][key] + inverse_indices_per_feature_per_rank.append( + inverse_indices_per_rank_per_feature[rank][key] + ) + + global_stride_per_key_per_rank.append([sum_stride]) + + inverse_indices_list: List[torch.Tensor] = [] + + for key in keys: + accum_batch_size = 0 + inverse_indices = [] + + for rank in range(world_size): + inverse_indices.append( + inverse_indices_per_rank_per_feature[rank][key] + accum_batch_size + ) + accum_batch_size += strides_per_rank_per_feature[rank][key] + + inverse_indices_list.append(torch.cat(inverse_indices)) + + global_inverse_indices = (keys, torch.stack(inverse_indices_list)) + + if global_constant_batch: + global_offsets = [] + + for length in global_lengths: + global_offsets.append(_to_offsets(length)) + + reindexed_lengths = [] + + for length, indices in zip( + global_lengths, inverse_indices_per_feature_per_rank + ): + reindexed_lengths.append(torch.index_select(length, 0, indices)) + + lengths = torch.cat(reindexed_lengths) + reindexed_values, reindexed_weights = [], [] + + for i, (values, offsets, indices) in enumerate( + zip(global_values, global_offsets, inverse_indices_per_feature_per_rank) + ): + for idx in indices: + reindexed_values.append(values[offsets[idx] : offsets[idx + 1]]) + if global_weights is not None: + reindexed_weights.append( + global_weights[i][offsets[idx] : offsets[idx + 1]] + ) + + values = torch.cat(reindexed_values) + weights = ( + torch.cat(reindexed_weights) if global_weights is not None else None + ) + global_stride_per_key_per_rank = None + global_inverse_indices = None + + else: + values = torch.cat(global_values) + lengths = torch.cat(global_lengths) + weights = torch.cat(global_weights) if global_weights is not None else None + + if use_offsets: + offsets = torch.cat( + [torch.tensor([0], dtype=offsets_dtype), lengths.cumsum(0)] + ) + return KeyedJaggedTensor( + keys=keys, + values=values, + offsets=offsets, + weights=weights, + stride_per_key_per_rank=global_stride_per_key_per_rank, + inverse_indices=global_inverse_indices, + ) + else: + return KeyedJaggedTensor( + keys=keys, + values=values, + lengths=lengths, + weights=weights, + stride_per_key_per_rank=global_stride_per_key_per_rank, + inverse_indices=global_inverse_indices, + ) + + @staticmethod + def _generate_variable_batch_features( + tables: Union[ + List[EmbeddingTableConfig], List[EmbeddingBagConfig], List[EmbeddingConfig] + ], + average_batch_size: int, + world_size: int, + dedup_factor: int, + global_constant_batch: bool, + use_offsets: bool, + 
indices_dtype: torch.dtype, + offsets_dtype: torch.dtype, + lengths_dtype: torch.dtype, + ) -> Tuple[KeyedJaggedTensor, List[KeyedJaggedTensor]]: + is_weighted = ( + True if tables and getattr(tables[0], "is_weighted", False) else False + ) + + feature_num_embeddings = {} + + for table in tables: + for feature_name in table.feature_names: + feature_num_embeddings[feature_name] = ( + table.num_embeddings_post_pruning + if table.num_embeddings_post_pruning + else table.num_embeddings + ) + + local_kjts = [] + + values_per_rank_per_feature = {} + lengths_per_rank_per_feature = {} + strides_per_rank_per_feature = {} + inverse_indices_per_rank_per_feature = {} + weights_per_rank_per_feature = {} if is_weighted else None + + local_kjts = ModelInput._generate_variable_batch_local_features( + feature_num_embeddings=feature_num_embeddings, + average_batch_size=average_batch_size, + world_size=world_size, + dedup_factor=dedup_factor, + values_per_rank_per_feature=values_per_rank_per_feature, + lengths_per_rank_per_feature=lengths_per_rank_per_feature, + strides_per_rank_per_feature=strides_per_rank_per_feature, + inverse_indices_per_rank_per_feature=inverse_indices_per_rank_per_feature, + weights_per_rank_per_feature=weights_per_rank_per_feature, + use_offsets=use_offsets, + indices_dtype=indices_dtype, + offsets_dtype=offsets_dtype, + lengths_dtype=lengths_dtype, + ) + + global_kjt = ModelInput._generate_variable_batch_global_features( + keys=list(feature_num_embeddings.keys()), + world_size=world_size, + global_constant_batch=global_constant_batch, + values_per_rank_per_feature=values_per_rank_per_feature, + lengths_per_rank_per_feature=lengths_per_rank_per_feature, + strides_per_rank_per_feature=strides_per_rank_per_feature, + inverse_indices_per_rank_per_feature=inverse_indices_per_rank_per_feature, + weights_per_rank_per_feature=weights_per_rank_per_feature, + use_offsets=use_offsets, + indices_dtype=indices_dtype, + offsets_dtype=offsets_dtype, + lengths_dtype=lengths_dtype, + ) + + return (global_kjt, local_kjts) + + @staticmethod + def generate_variable_batch_input( + average_batch_size: int, + world_size: int, + num_float_features: int, + tables: Union[ + List[EmbeddingTableConfig], List[EmbeddingBagConfig], List[EmbeddingConfig] + ], + weighted_tables: Optional[ + Union[ + List[EmbeddingTableConfig], + List[EmbeddingBagConfig], + List[EmbeddingConfig], + ] + ] = None, + pooling_avg: int = 10, + global_constant_batch: bool = False, + use_offsets: bool = False, + indices_dtype: torch.dtype = torch.int64, + offsets_dtype: torch.dtype = torch.int64, + lengths_dtype: torch.dtype = torch.int64, + ) -> Tuple["ModelInput", List["ModelInput"]]: + torch.manual_seed(100) + random.seed(100) + dedup_factor = 2 + + global_kjt, local_kjts = ModelInput._generate_variable_batch_features( + tables=tables, + average_batch_size=average_batch_size, + world_size=world_size, + dedup_factor=dedup_factor, + global_constant_batch=global_constant_batch, + use_offsets=use_offsets, + indices_dtype=indices_dtype, + offsets_dtype=offsets_dtype, + lengths_dtype=lengths_dtype, + ) + + if weighted_tables: + global_score_kjt, local_score_kjts = ( + ModelInput._generate_variable_batch_features( + tables=weighted_tables, + average_batch_size=average_batch_size, + world_size=world_size, + dedup_factor=dedup_factor, + global_constant_batch=global_constant_batch, + use_offsets=use_offsets, + indices_dtype=indices_dtype, + offsets_dtype=offsets_dtype, + lengths_dtype=lengths_dtype, + ) + ) + else: + global_score_kjt, 
local_score_kjts = None, [] + + global_float = torch.rand( + (dedup_factor * average_batch_size * world_size, num_float_features) + ) + + local_model_input = [] + label_per_rank = [] + + for rank in range(world_size): + label_per_rank.append(torch.rand(dedup_factor * average_batch_size)) + local_float = global_float[ + rank + * dedup_factor + * average_batch_size : (rank + 1) + * dedup_factor + * average_batch_size + ] + local_model_input.append( + ModelInput( + idlist_features=local_kjts[rank], + idscore_features=( + local_score_kjts[rank] if local_score_kjts else None + ), + label=label_per_rank[rank], + float_features=local_float, + ), + ) + + global_model_input = ModelInput( + idlist_features=global_kjt, + idscore_features=global_score_kjt, + label=torch.cat(label_per_rank), + float_features=global_float, + ) + + return (global_model_input, local_model_input) + def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput": return ModelInput( float_features=self.float_features.to( @@ -235,22 +827,20 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput": idlist_features=self.idlist_features.to( device=device, non_blocking=non_blocking ), - # pyre-ignore [6] - idscore_features=self.idscore_features.to( - device=device, non_blocking=non_blocking - ) - if self.idscore_features is not None - else None, + idscore_features=( + self.idscore_features.to(device=device, non_blocking=non_blocking) + if self.idscore_features is not None + else None + ), label=self.label.to(device=device, non_blocking=non_blocking), ) - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. + def record_stream(self, stream: torch.Stream) -> None: self.float_features.record_stream(stream) - self.idlist_features.record_stream(stream) - if self.idscore_features is not None: + if isinstance(self.idlist_features, KeyedJaggedTensor): + self.idlist_features.record_stream(stream) + if isinstance(self.idscore_features, KeyedJaggedTensor): self.idscore_features.record_stream(stream) - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self.label.record_stream(stream) @@ -284,17 +874,67 @@ def __init__( in_features=num_float_features, out_features=8, device=device ) - self.dummy_param = torch.nn.Parameter(torch.empty(2, device=device)) + self.dummy_param = torch.nn.Parameter(torch.zeros(2, device=device)) self.register_buffer( "dummy_buffer", - torch.nn.Parameter(torch.empty(1, device=device)), + torch.nn.Parameter(torch.zeros(1, device=device)), ) def forward(self, dense_input: torch.Tensor) -> torch.Tensor: return self.linear(dense_input) -class TestOverArch(nn.Module): +class TestDHNArch(nn.Module): + """ + Simple version of a model with two linear layers. + We use this to test out recursively wrapped FSDP + + Args: + in_feature: the size of input dim + device: the device on which this module will be placed. 
+ + Call Args: + input: input tensor, + + Returns: + torch.Tensor + + Example:: + + TestDHNArch() + """ + + def __init__( + self, + in_features: int, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + if device is None: + device = torch.device("cpu") + + self.device = device + self.linear0 = nn.Linear( + in_features=in_features, out_features=16, device=device + ) + self.linear1 = nn.Linear(in_features=16, out_features=16, device=device) + + def forward( + self, + input: torch.Tensor, + ) -> torch.Tensor: + return self.linear1(self.linear0(input)) + + +@torch.fx.wrap +def _concat( + dense: torch.Tensor, + sparse_embeddings: List[torch.Tensor], +) -> torch.Tensor: + return torch.cat([dense] + sparse_embeddings, dim=1) + + +class TestOverArchRegroupModule(nn.Module): """ Basic nn.Module for testing @@ -317,14 +957,80 @@ def __init__( self, tables: List[EmbeddingBagConfig], weighted_tables: List[EmbeddingBagConfig], + embedding_names: Optional[List[str]] = None, device: Optional[torch.device] = None, ) -> None: super().__init__() if device is None: device = torch.device("cpu") - self._features: List[str] = [ - feature for table in tables for feature in table.feature_names + self._embedding_names: List[str] = ( + embedding_names + if embedding_names + else [feature for table in tables for feature in table.feature_names] + ) + self._weighted_features: List[str] = [ + feature for table in weighted_tables for feature in table.feature_names ] + in_features = ( + 8 + + sum([table.embedding_dim * len(table.feature_names) for table in tables]) + + sum( + [ + table.embedding_dim * len(table.feature_names) + for table in weighted_tables + ] + ) + ) + self.dhn_arch: nn.Module = TestDHNArch(in_features, device) + self.regroup_module = KTRegroupAsDict( + [self._embedding_names, self._weighted_features], + ["unweighted", "weighted"], + ) + + def forward( + self, + dense: torch.Tensor, + sparse: KeyedTensor, + ) -> torch.Tensor: + pooled_emb = self.regroup_module([sparse]) + values = list(pooled_emb.values()) + return self.dhn_arch(_concat(dense, values)) + + +class TestOverArch(nn.Module): + """ + Basic nn.Module for testing + + Args: + device + + Call Args: + dense: torch.Tensor, + sparse: KeyedTensor, + + Returns: + torch.Tensor + + Example:: + + TestOverArch() + """ + + def __init__( + self, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + embedding_names: Optional[List[str]] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + if device is None: + device = torch.device("cpu") + self._embedding_names: List[str] = ( + embedding_names + if embedding_names + else [feature for table in tables for feature in table.feature_names] + ) self._weighted_features: List[str] = [ feature for table in weighted_tables for feature in table.feature_names ] @@ -338,35 +1044,98 @@ def __init__( ] ) ) - self.linear: nn.modules.Linear = nn.Linear( - in_features=in_features, out_features=16, device=device + self.dhn_arch: nn.Module = TestDHNArch(in_features, device) + + def forward( + self, + dense: torch.Tensor, + sparse: KeyedTensor, + ) -> torch.Tensor: + sparse_regrouped: List[torch.Tensor] = KeyedTensor.regroup( + [sparse], [self._embedding_names, self._weighted_features] ) + return self.dhn_arch(_concat(dense, sparse_regrouped)) + + +class TestOverArchLarge(nn.Module): + """ + Basic nn.Module for testing, w 5/ layers. 
+ """ + + def __init__( + self, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + embedding_names: Optional[List[str]] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + if device is None: + device = torch.device("cpu") + self._embedding_names: List[str] = ( + embedding_names + if embedding_names + else [feature for table in tables for feature in table.feature_names] + ) + self._weighted_features: List[str] = [ + feature for table in weighted_tables for feature in table.feature_names + ] + in_features = ( + 8 + + sum([table.embedding_dim * len(table.feature_names) for table in tables]) + + sum( + [ + table.embedding_dim * len(table.feature_names) + for table in weighted_tables + ] + ) + ) + out_features = 1000 + layers = [ + torch.nn.Linear( + in_features=in_features, + out_features=out_features, + ), + SwishLayerNorm([out_features]), + ] + + for _ in range(5): + layers += [ + torch.nn.Linear( + in_features=out_features, + out_features=out_features, + ), + SwishLayerNorm([out_features]), + ] + + self.overarch = torch.nn.Sequential(*layers) + def forward( self, dense: torch.Tensor, sparse: KeyedTensor, ) -> torch.Tensor: - ret_list = [] - ret_list.append(dense) - for feature_name in self._features: - ret_list.append(sparse[feature_name]) - for feature_name in self._weighted_features: - ret_list.append(sparse[feature_name]) - return self.linear(torch.cat(ret_list, dim=1)) + ret_list = [dense] + ret_list.extend( + KeyedTensor.regroup( + [sparse], [self._embedding_names, self._weighted_features] + ) + ) + return self.overarch(torch.cat(ret_list, dim=1)) @torch.fx.wrap def _post_sparsenn_forward( ebc: KeyedTensor, - fp_ebc: KeyedTensor, - w_ebc: KeyedTensor, + fp_ebc: Optional[KeyedTensor], + w_ebc: Optional[KeyedTensor], batch_size: Optional[int] = None, ) -> KeyedTensor: if batch_size is None or ebc.values().size(0) == batch_size: ebc_values = ebc.values() fp_ebc_values = fp_ebc.values() if fp_ebc is not None else None - w_ebc_values = w_ebc.values() + w_ebc_values = w_ebc.values() if w_ebc is not None else None else: ebc_values = torch.zeros( batch_size, @@ -385,30 +1154,56 @@ def _post_sparsenn_forward( fp_ebc_values[: fp_ebc.values().size(0), :] = fp_ebc.values() else: fp_ebc_values = None - w_ebc_values = torch.zeros( - batch_size, - w_ebc.values().size(1), - dtype=w_ebc.values().dtype, - device=w_ebc.values().device, + if w_ebc is not None: + w_ebc_values = torch.zeros( + batch_size, + w_ebc.values().size(1), + dtype=w_ebc.values().dtype, + device=w_ebc.values().device, + ) + w_ebc_values[: w_ebc.values().size(0), :] = w_ebc.values() + else: + w_ebc_values = None + + if fp_ebc is None and w_ebc is None: + return KeyedTensor( + keys=ebc.keys(), + length_per_key=ebc.length_per_key(), + values=ebc_values, ) - w_ebc_values[: w_ebc.values().size(0), :] = w_ebc.values() - result = ( - KeyedTensor( + elif fp_ebc is None and w_ebc is not None: + return KeyedTensor( keys=ebc.keys() + w_ebc.keys(), length_per_key=ebc.length_per_key() + w_ebc.length_per_key(), - values=torch.cat([ebc_values, w_ebc_values], dim=1), + values=torch.cat( + [ebc_values, torch.jit._unwrap_optional(w_ebc_values)], dim=1 + ), ) - if fp_ebc is None - else KeyedTensor( + elif fp_ebc is not None and w_ebc is None: + return KeyedTensor( + keys=ebc.keys() + fp_ebc.keys(), + length_per_key=ebc.length_per_key() + fp_ebc.length_per_key(), + values=torch.cat( + [ebc_values, torch.jit._unwrap_optional(fp_ebc_values)], dim=1 + ), + ) + else: + assert fp_ebc is not 
None and w_ebc is not None + return KeyedTensor( keys=ebc.keys() + fp_ebc.keys() + w_ebc.keys(), length_per_key=ebc.length_per_key() + fp_ebc.length_per_key() + w_ebc.length_per_key(), - # pyre-ignore[6] - values=torch.cat([ebc_values, fp_ebc_values, w_ebc_values], dim=1), + # Comment to torch.jit._unwrap_optional fp_ebc_values is inferred as Optional[Tensor] as it can be None when fp_ebc is None. But at this point we now that it has a value and doing jit._unwrap_optional will tell jit to treat it as Tensor type. + values=torch.cat( + [ + ebc_values, + torch.jit._unwrap_optional(fp_ebc_values), + torch.jit._unwrap_optional(w_ebc_values), + ], + dim=1, + ), ) - ) - return result class TestSparseArch(nn.Module): @@ -431,60 +1226,57 @@ def __init__( tables: List[EmbeddingBagConfig], weighted_tables: List[EmbeddingBagConfig], device: Optional[torch.device] = None, - max_feature_lengths_list: Optional[List[Dict[str, int]]] = None, + max_feature_lengths: Optional[Dict[str, int]] = None, ) -> None: super().__init__() if device is None: device = torch.device("cpu") self.fps: Optional[nn.ModuleList] = None - self.fp_ebc: Optional[EmbeddingBagCollection] = None - if max_feature_lengths_list is not None: - self.fps = nn.ModuleList( - [ - PositionWeightedProcessor( - max_feature_lengths=max_feature_lengths, - device=device - if device != torch.device("meta") - else torch.device("cpu"), - ) - for max_feature_lengths in max_feature_lengths_list - ] - ) - normal_id_list_tables = [] - fp_id_list_tables = [] - for table in tables: - # the key set of feature_processor is either subset or none in the feature_names - if set(table.feature_names).issubset( - set(max_feature_lengths_list[0].keys()) - ): - fp_id_list_tables.append(table) - else: - normal_id_list_tables.append(table) + self.fp_ebc: Optional[FeatureProcessedEmbeddingBagCollection] = None + + if max_feature_lengths is not None: + fp_tables_names = set(max_feature_lengths.keys()) + normal_tables_names = {table.name for table in tables} - fp_tables_names self.ebc: EmbeddingBagCollection = EmbeddingBagCollection( - tables=normal_id_list_tables, + tables=[table for table in tables if table.name in normal_tables_names], device=device, ) - self.fp_ebc: EmbeddingBagCollection = EmbeddingBagCollection( - tables=fp_id_list_tables, - device=device, - is_weighted=True, + + fp = PositionWeightedModuleCollection( + max_feature_lengths=max_feature_lengths, + device=( + device if device != torch.device("meta") else torch.device("cpu") + ), + ) + self.fp_ebc = FeatureProcessedEmbeddingBagCollection( + embedding_bag_collection=EmbeddingBagCollection( + tables=[table for table in tables if table.name in fp_tables_names], + device=device, + is_weighted=True, + ), + feature_processors=fp, ) else: self.ebc: EmbeddingBagCollection = EmbeddingBagCollection( tables=tables, device=device, ) - self.weighted_ebc: EmbeddingBagCollection = EmbeddingBagCollection( - tables=weighted_tables, - is_weighted=True, - device=device, + + self.weighted_ebc: Optional[EmbeddingBagCollection] = ( + EmbeddingBagCollection( + tables=weighted_tables, + is_weighted=True, + device=device, + ) + if weighted_tables + else None ) def forward( self, features: KeyedJaggedTensor, - weighted_features: KeyedJaggedTensor, + weighted_features: Optional[KeyedJaggedTensor] = None, batch_size: Optional[int] = None, ) -> KeyedTensor: fp_features = features @@ -493,8 +1285,14 @@ def forward( for fp in self.fps: fp_features = fp(fp_features) ebc = self.ebc(features) - fp_ebc = self.fp_ebc(fp_features) if 
self.fp_ebc is not None else None - w_ebc = self.weighted_ebc(weighted_features) + fp_ebc: Optional[KeyedTensor] = ( + self.fp_ebc(fp_features) if self.fp_ebc is not None else None + ) + w_ebc = ( + self.weighted_ebc(weighted_features) + if self.weighted_ebc is not None and weighted_features is not None + else None + ) result = _post_sparsenn_forward(ebc, fp_ebc, w_ebc, batch_size) return result @@ -556,7 +1354,11 @@ def __init__( embedding_groups: Optional[Dict[str, List[str]]] = None, dense_device: Optional[torch.device] = None, sparse_device: Optional[torch.device] = None, - max_feature_lengths_list: Optional[List[Dict[str, int]]] = None, + max_feature_lengths: Optional[Dict[str, int]] = None, + feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, + over_arch_clazz: Type[nn.Module] = TestOverArch, + postproc_module: Optional[nn.Module] = None, + zch: bool = False, ) -> None: super().__init__( tables=cast(List[BaseEmbeddingConfig], tables), @@ -565,27 +1367,57 @@ def __init__( dense_device=dense_device, sparse_device=sparse_device, ) - if weighted_tables is None: - weighted_tables = [] - self.dense = TestDenseArch(num_float_features, dense_device) - self.sparse = TestSparseArch( - tables, - weighted_tables, - sparse_device, - max_feature_lengths_list if max_feature_lengths_list is not None else None, + if weighted_tables is None: + weighted_tables = [] + self.dense = TestDenseArch(num_float_features, dense_device) + if zch: + self.sparse: nn.Module = TestSparseArchZCH( + tables, + weighted_tables, + torch.device("meta"), + return_remapped=True, + ) + else: + self.sparse = TestSparseArch( + tables, + weighted_tables, + sparse_device, + max_feature_lengths, + ) + + embedding_names = ( + list(embedding_groups.values())[0] if embedding_groups else None + ) + self._embedding_names: List[str] = ( + embedding_names + if embedding_names + else [feature for table in tables for feature in table.feature_names] + ) + self._weighted_features: List[str] = [ + feature for table in weighted_tables for feature in table.feature_names + ] + self.over: nn.Module = over_arch_clazz( + tables, weighted_tables, embedding_names, dense_device + ) + self.register_buffer( + "dummy_ones", + torch.ones(1, device=dense_device), ) - self.over = TestOverArch(tables, weighted_tables, dense_device) + self.postproc_module = postproc_module - def forward( - self, - input: ModelInput, + def sparse_forward(self, input: ModelInput) -> KeyedTensor: + return self.sparse( + features=input.idlist_features, + weighted_features=input.idscore_features, + batch_size=input.float_features.size(0), + ) + + def dense_forward( + self, input: ModelInput, sparse_output: KeyedTensor ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: dense_r = self.dense(input.float_features) - sparse_r = self.sparse( - input.idlist_features, input.idscore_features, input.float_features.size(0) - ) - over_r = self.over(dense_r, sparse_r) - pred = torch.sigmoid(torch.mean(over_r, dim=1)) + over_r = self.over(dense_r, sparse_output) + pred = torch.sigmoid(torch.mean(over_r, dim=1)) + self.dummy_ones if self.training: return ( torch.nn.functional.binary_cross_entropy_with_logits(pred, input.label), @@ -594,6 +1426,14 @@ def forward( else: return pred + def forward( + self, + input: ModelInput, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if self.postproc_module: + input = self.postproc_module(input) + return self.dense_forward(input, self.sparse_forward(input)) + class TestTowerInteraction(nn.Module): """ @@ -671,6 
+1511,7 @@ def __init__( embedding_groups: Optional[Dict[str, List[str]]] = None, dense_device: Optional[torch.device] = None, sparse_device: Optional[torch.device] = None, + feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, ) -> None: super().__init__( tables=cast(List[BaseEmbeddingConfig], tables), @@ -704,9 +1545,11 @@ def __init__( self.over = nn.Linear( in_features=8 - # pyre-ignore [16] + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `out_features`. + self.tower_0.interaction.linear.out_features - # pyre-ignore [16] + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `out_features`. + self.tower_1.interaction.linear.out_features + tables[1].embedding_dim * len(tables[1].feature_names) + weighted_tables[0].embedding_dim * len(weighted_tables[0].feature_names), @@ -766,6 +1609,7 @@ def __init__( embedding_groups: Optional[Dict[str, List[str]]] = None, dense_device: Optional[torch.device] = None, sparse_device: Optional[torch.device] = None, + feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, ) -> None: super().__init__( tables=cast(List[BaseEmbeddingConfig], tables), @@ -797,11 +1641,14 @@ def __init__( self.tower_arch = EmbeddingTowerCollection(towers=[tower_0, tower_1, tower_2]) self.over = nn.Linear( in_features=8 - # pyre-ignore [16] + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `out_features`. + tower_0.interaction.linear.out_features - # pyre-ignore [16] + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `out_features`. + tower_1.interaction.linear.out_features - # pyre-ignore [16] + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `out_features`. + tower_2.interaction.linear.out_features, out_features=16, device=dense_device, @@ -831,14 +1678,13 @@ def __init__( kernel_type: str, fused_params: Optional[Dict[str, Any]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, - variable_batch_size: bool = False, ) -> None: if fused_params is None: fused_params = {} self._sharding_type = sharding_type self._kernel_type = kernel_type - super().__init__(fused_params, qcomm_codecs_registry, variable_batch_size) + super().__init__(fused_params, qcomm_codecs_registry) """ Restricts sharding to single type only. 
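Note between hunks: because `TestEBCSharder` pins both the sharding type and the compute kernel, a test that uses it exercises exactly one sharding code path. The sketch below is only an illustration of how such a restricted sharder can be handed to the sharding test helpers; `TestEBCSharder` is assumed to be in scope (it is defined in this file), and everything else is a hypothetical usage, not part of this diff:

    from typing import cast

    import torch.nn as nn
    from torchrec.distributed.embedding_types import EmbeddingComputeKernel
    from torchrec.distributed.types import ModuleSharder, ShardingType

    # Restrict the test model to table-wise sharding with the fused kernel.
    sharder = TestEBCSharder(
        sharding_type=ShardingType.TABLE_WISE.value,
        kernel_type=EmbeddingComputeKernel.FUSED.value,
    )
    # The sharding test utilities accept a list of ModuleSharder instances.
    sharders = [cast(ModuleSharder[nn.Module], sharder)]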
@@ -857,6 +1703,64 @@ def compute_kernels( return [self._kernel_type] +class TestMCSharder(ManagedCollisionCollectionSharder): + def __init__( + self, + sharding_type: str, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + self._sharding_type = sharding_type + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [self._sharding_type] + + +class TestEBCSharderMCH( + BaseManagedCollisionEmbeddingCollectionSharder[ + ManagedCollisionEmbeddingBagCollection + ] +): + def __init__( + self, + sharding_type: str, + kernel_type: str, + fused_params: Optional[Dict[str, Any]] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__( + TestEBCSharder( + sharding_type, kernel_type, fused_params, qcomm_codecs_registry + ), + TestMCSharder(sharding_type, qcomm_codecs_registry), + qcomm_codecs_registry=qcomm_codecs_registry, + ) + + @property + def module_type(self) -> Type[ManagedCollisionEmbeddingBagCollection]: + return ManagedCollisionEmbeddingBagCollection + + def shard( + self, + module: ManagedCollisionEmbeddingBagCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> ShardedManagedCollisionEmbeddingBagCollection: + if device is None: + device = torch.device("cuda") + return ShardedManagedCollisionEmbeddingBagCollection( + module, + params, + # pyre-ignore [6] + ebc_sharder=self._e_sharder, + mc_sharder=self._mc_sharder, + env=env, + device=device, + ) + + class TestFusedEBCSharder(FusedEmbeddingBagCollectionSharder): def __init__( self, @@ -1004,3 +1908,467 @@ def _get_default_rtol_and_atol( actual_rtol, actual_atol = _DTYPE_PRECISIONS.get(actual.dtype, (0.0, 0.0)) expected_rtol, expected_atol = _DTYPE_PRECISIONS.get(expected.dtype, (0.0, 0.0)) return max(actual_rtol, expected_rtol), max(actual_atol, expected_atol) + + +class TestPreprocNonWeighted(nn.Module): + """ + Basic module for testing + + Args: None + Examples: + >>> TestPreprocNonWeighted() + Returns: + List[KeyedJaggedTensor] + """ + + def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]: + """ + Selects 3 features from a specific KJT + """ + # split + jt_0 = kjt["feature_0"] + jt_1 = kjt["feature_1"] + jt_2 = kjt["feature_2"] + + # merge only features 0,1,2, removing feature 3 + return [ + KeyedJaggedTensor.from_jt_dict( + { + "feature_0": jt_0, + "feature_1": jt_1, + "feature_2": jt_2, + } + ) + ] + + +class TestPreprocWeighted(nn.Module): + """ + Basic module for testing + + Args: None + Examples: + >>> TestPreprocWeighted() + Returns: + List[KeyedJaggedTensor] + """ + + def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]: + """ + Selects 1 feature from specific weighted KJT + """ + + # split + jt_0 = kjt["weighted_feature_0"] + + # keep only weighted_feature_0 + return [ + KeyedJaggedTensor.from_jt_dict( + { + "weighted_feature_0": jt_0, + } + ) + ] + + +class TestModelWithPreproc(nn.Module): + """ + Basic module with up to 3 postproc modules: + - postproc on idlist_features for non-weighted EBC + - postproc on idscore_features for weighted EBC + - optional postproc on model input shared by both EBCs + + Args: + tables, + weighted_tables, + device, + postproc_module, + num_float_features, + run_postproc_inline, + + Example: + >>> TestModelWithPreproc(tables, weighted_tables, device) + + Returns: + Tuple[torch.Tensor, 
torch.Tensor] + """ + + def __init__( + self, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + device: torch.device, + postproc_module: Optional[nn.Module] = None, + num_float_features: int = 10, + run_postproc_inline: bool = False, + ) -> None: + super().__init__() + self.dense = TestDenseArch(num_float_features, device) + + self.ebc: EmbeddingBagCollection = EmbeddingBagCollection( + tables=tables, + device=device, + ) + self.weighted_ebc = EmbeddingBagCollection( + tables=weighted_tables, + is_weighted=True, + device=device, + ) + self.postproc_nonweighted = TestPreprocNonWeighted() + self.postproc_weighted = TestPreprocWeighted() + self._postproc_module = postproc_module + self._run_postproc_inline = run_postproc_inline + + def forward( + self, + input: ModelInput, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Runs preprco for EBC and weighted EBC, optionally runs postproc for input + + Args: + input + Returns: + Tuple[torch.Tensor, torch.Tensor] + """ + modified_input = input + + if self._postproc_module is not None: + modified_input = self._postproc_module(modified_input) + elif self._run_postproc_inline: + idlist_features = modified_input.idlist_features + modified_input.idlist_features = KeyedJaggedTensor.from_lengths_sync( + idlist_features.keys(), # pyre-ignore [6] + idlist_features.values(), # pyre-ignore [6] + idlist_features.lengths(), # pyre-ignore [16] + ) + + modified_idlist_features = self.postproc_nonweighted( + modified_input.idlist_features + ) + modified_idscore_features = self.postproc_weighted( + modified_input.idscore_features + ) + ebc_out = self.ebc(modified_idlist_features[0]) + weighted_ebc_out = self.weighted_ebc(modified_idscore_features[0]) + + pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1) + return pred.sum(), pred + + +class TestModelWithPreprocCollectionArgs(nn.Module): + """ + Basic module with up to 3 postproc modules: + - postproc on idlist_features for non-weighted EBC + - postproc on idscore_features for weighted EBC + - postproc_inner on model input shared by both EBCs + - postproc_outer providing input to postproc_b (aka nested postproc) + + Args: + tables, + weighted_tables, + device, + postproc_module_outer, + postproc_module_nested, + num_float_features, + + Example: + >>> TestModelWithPreprocWithListArg(tables, weighted_tables, device) + + Returns: + Tuple[torch.Tensor, torch.Tensor] + """ + + CONST_DICT_KEY = "const" + INPUT_TENSOR_DICT_KEY = "tensor_from_input" + POSTPTOC_TENSOR_DICT_KEY = "tensor_from_postproc" + + def __init__( + self, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + device: torch.device, + postproc_module_outer: nn.Module, + postproc_module_nested: nn.Module, + num_float_features: int = 10, + ) -> None: + super().__init__() + self.dense = TestDenseArch(num_float_features, device) + + self.ebc: EmbeddingBagCollection = EmbeddingBagCollection( + tables=tables, + device=device, + ) + self.weighted_ebc = EmbeddingBagCollection( + tables=weighted_tables, + is_weighted=True, + device=device, + ) + self.postproc_nonweighted = TestPreprocNonWeighted() + self.postproc_weighted = TestPreprocWeighted() + self._postproc_module_outer = postproc_module_outer + self._postproc_module_nested = postproc_module_nested + + def forward( + self, + input: ModelInput, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Runs preproc for EBC and weighted EBC, optionally runs postproc for input + + Args: + input + Returns: + Tuple[torch.Tensor, torch.Tensor] 
+ """ + modified_input = input + + outer_postproc_input = self._postproc_module_outer(modified_input) + + preproc_input_list = [ + 1, + modified_input.float_features, + outer_postproc_input, + ] + preproc_input_dict = { + self.CONST_DICT_KEY: 1, + self.INPUT_TENSOR_DICT_KEY: modified_input.float_features, + self.POSTPTOC_TENSOR_DICT_KEY: outer_postproc_input, + } + + modified_input = self._postproc_module_nested( + modified_input, preproc_input_list, preproc_input_dict + ) + + modified_idlist_features = self.postproc_nonweighted( + modified_input.idlist_features + ) + modified_idscore_features = self.postproc_weighted( + modified_input.idscore_features + ) + ebc_out = self.ebc(modified_idlist_features[0]) + weighted_ebc_out = self.weighted_ebc(modified_idscore_features[0]) + + pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1) + return pred.sum(), pred + + +class TestNegSamplingModule(torch.nn.Module): + """ + Basic module to simulate feature augmentation postproc (e.g. neg sampling) for testing + + Args: + extra_input + has_params + + Example: + >>> postproc = TestNegSamplingModule(extra_input) + >>> out = postproc(in) + + Returns: + ModelInput + """ + + TEST_BUFFER_NAME = "test_buffer" + + def __init__( + self, + extra_input: ModelInput, + has_params: bool = False, + ) -> None: + super().__init__() + self._extra_input = extra_input + self.register_buffer(self.TEST_BUFFER_NAME, torch.zeros(1)) + if has_params: + self._linear: nn.Module = nn.Linear(30, 30) + + def forward(self, input: ModelInput) -> ModelInput: + """ + Appends extra features to model input + + Args: + input + Returns: + ModelInput + """ + + # merge extra input + modified_input = copy.deepcopy(input) + + # dim=0 (batch dimensions) increases by self._extra_input.float_features.shape[0] + modified_input.float_features = torch.concat( + (modified_input.float_features, self._extra_input.float_features), dim=0 + ) + + # stride will be same but features will be joined + assert isinstance(modified_input.idlist_features, KeyedJaggedTensor) + assert isinstance(self._extra_input.idlist_features, KeyedJaggedTensor) + modified_input.idlist_features = KeyedJaggedTensor.concat( + [modified_input.idlist_features, self._extra_input.idlist_features] + ) + if self._extra_input.idscore_features is not None: + # stride will be smae but features will be joined + modified_input.idscore_features = KeyedJaggedTensor.concat( + # pyre-ignore + [modified_input.idscore_features, self._extra_input.idscore_features] + ) + + # dim=0 (batch dimensions) increases by self._extra_input.input_label.shape[0] + modified_input.label = torch.concat( + (modified_input.label, self._extra_input.label), dim=0 + ) + + return modified_input + + +class TestPositionWeightedPreprocModule(torch.nn.Module): + """ + Basic module for testing + + Args: None + Example: + >>> postproc = TestPositionWeightedPreprocModule(max_feature_lengths, device) + >>> out = postproc(in) + Returns: + ModelInput + """ + + def __init__( + self, max_feature_lengths: Dict[str, int], device: torch.device + ) -> None: + super().__init__() + self.fp_proc = PositionWeightedProcessor( + max_feature_lengths=max_feature_lengths, + device=device, + ) + + def forward(self, input: ModelInput) -> ModelInput: + """ + Runs PositionWeightedProcessor + + Args: + input + Returns: + ModelInput + """ + modified_input = copy.deepcopy(input) + modified_input.idlist_features = self.fp_proc(modified_input.idlist_features) + return modified_input + + +class TestSparseArchZCH(nn.Module): + """ + Basic 
nn.Module for testing MCH EmbeddingBagCollection + + Args: + tables + weighted_tables + device + return_remapped + + Call Args: + features + weighted_features + batch_size + + Returns: + KeyedTensor + + Example:: + + TestSparseArch() + """ + + def __init__( + self, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + device: torch.device, + return_remapped: bool = False, + ) -> None: + super().__init__() + self._return_remapped = return_remapped + + mc_modules = {} + for table in tables: + mc_modules[table.name] = MCHManagedCollisionModule( + zch_size=table.num_embeddings, + input_hash_size=4000, + device=device, + # TODO: If eviction interval is set to + # a low number (e.g. 2), semi-sync pipeline test will + # fail with in-place modification error during + # loss.backward(). This is because during semi-sync training, + # we run embedding module forward after autograd graph + # is constructed, but if MCH eviction happens, the + # variable used in autograd will have been modified + eviction_interval=1000, + eviction_policy=DistanceLFU_EvictionPolicy(), + ) + + self.ebc: ManagedCollisionEmbeddingBagCollection = ( + ManagedCollisionEmbeddingBagCollection( + EmbeddingBagCollection( + tables=tables, + device=device, + ), + ManagedCollisionCollection( + managed_collision_modules=mc_modules, + embedding_configs=tables, + ), + return_remapped_features=self._return_remapped, + ) + ) + + self.weighted_ebc: Optional[ManagedCollisionEmbeddingBagCollection] = None + if weighted_tables: + weighted_mc_modules = {} + for table in weighted_tables: + weighted_mc_modules[table.name] = MCHManagedCollisionModule( + zch_size=table.num_embeddings, + input_hash_size=4000, + device=device, + # TODO: Support MCH evictions during semi-sync + eviction_interval=1000, + eviction_policy=DistanceLFU_EvictionPolicy(), + ) + self.weighted_ebc: ManagedCollisionEmbeddingBagCollection = ( + ManagedCollisionEmbeddingBagCollection( + EmbeddingBagCollection( + tables=weighted_tables, + device=device, + is_weighted=True, + ), + ManagedCollisionCollection( + managed_collision_modules=weighted_mc_modules, + embedding_configs=weighted_tables, + ), + return_remapped_features=self._return_remapped, + ) + ) + + def forward( + self, + features: KeyedJaggedTensor, + weighted_features: Optional[KeyedJaggedTensor] = None, + batch_size: Optional[int] = None, + ) -> KeyedTensor: + """ + Runs forward and MC EBC and optionally, weighted MC EBC, + then merges the results into one KeyedTensor + + Args: + features + weighted_features + batch_size + Returns: + KeyedTensor + """ + ebc, _ = self.ebc(features) + w_ebc, _ = ( + self.weighted_ebc(weighted_features) + if self.weighted_ebc is not None and weighted_features is not None + else None + ) + result = _post_sparsenn_forward(ebc, None, w_ebc, batch_size) + return result diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index ac0d68fd9..a6ac661a8 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -5,52 +5,128 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
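Note on the hunk below: the setUp changes introduce shared features, i.e. `num_shared_features` extra tables reuse the feature names of the first tables, and the single embedding group disambiguates them with a `feature@table_name` key. With the defaults in setUp (`num_features = 4`, `num_shared_features = 2`), the group produced by `_build_tables_and_groups` works out to the following; this is shown only to illustrate the naming convention, it is not code added by the diff:

    # Tables 0-3 own feature_0..feature_3; tables 4 and 5 re-expose feature_0
    # and feature_1, so the shared features are qualified by table name.
    embedding_groups = {
        "group_0": [
            "feature_0@table_0",
            "feature_1@table_1",
            "feature_2",
            "feature_3",
            "feature_0@table_4",
            "feature_1@table_5",
        ]
    }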
-from typing import Any, Dict, List, Optional, Tuple, Type +# pyre-strict -import torch +import unittest +from typing import Any, cast, Dict, List, Optional, Tuple, Type -import torch.distributed as dist # noqa +import torch import torch.nn as nn from fbgemm_gpu.split_embedding_configs import EmbOptimType -from torchrec.distributed.fbgemm_qcomm_codec import QCommsConfig +from hypothesis import assume, given, settings, strategies as st, Verbosity +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.fbgemm_qcomm_codec import CommType, QCommsConfig from torchrec.distributed.planner import ParameterConstraints from torchrec.distributed.test_utils.multi_process import MultiProcessTestBase from torchrec.distributed.test_utils.test_model import TestSparseNN, TestSparseNNBase -from torchrec.distributed.test_utils.test_sharding import sharding_single_rank_test -from torchrec.distributed.types import ModuleSharder -from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.test_utils import seed_and_log +from torchrec.distributed.test_utils.test_sharding import ( + create_test_sharder, + SharderType, + sharding_single_rank_test, +) +from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig, PoolingType +from torchrec.test_utils import seed_and_log, skip_if_asan_class +from torchrec.types import DataType class ModelParallelTestShared(MultiProcessTestBase): @seed_and_log - def setUp(self) -> None: + def setUp(self, backend: str = "nccl") -> None: super().setUp() - num_features = 4 - num_weighted_features = 2 + self.num_features = 4 + self.num_weighted_features = 2 + self.num_shared_features = 2 + + self.tables = [] + self.mean_tables = [] + self.weighted_tables = [] + self.embedding_groups = {} + self.shared_features = [] + + self.backend = backend + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + + if self.backend == "nccl" and self.device == torch.device("cpu"): + self.skipTest("NCCL not supported on CPUs.") + def _build_tables_and_groups( + self, + data_type: DataType = DataType.FP32, + ) -> None: self.tables = [ EmbeddingBagConfig( num_embeddings=(i + 1) * 10, - embedding_dim=(i + 2) * 4, + embedding_dim=(i + 2) * 8, name="table_" + str(i), feature_names=["feature_" + str(i)], + data_type=data_type, ) - for i in range(num_features) + for i in range(self.num_features) ] + shared_features_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 2) * 8, + name="table_" + str(i + self.num_features), + feature_names=["feature_" + str(i)], + data_type=data_type, + ) + for i in range(self.num_shared_features) + ] + self.tables += shared_features_tables + + self.mean_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 2) * 8, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + pooling=PoolingType.MEAN, + data_type=data_type, + ) + for i in range(self.num_features) + ] + + shared_features_tables_mean = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 2) * 8, + name="table_" + str(i + self.num_features), + feature_names=["feature_" + str(i)], + pooling=PoolingType.MEAN, + data_type=data_type, + ) + for i in range(self.num_shared_features) + ] + self.mean_tables += shared_features_tables_mean + self.weighted_tables = [ 
EmbeddingBagConfig( num_embeddings=(i + 1) * 10, embedding_dim=(i + 2) * 4, name="weighted_table_" + str(i), feature_names=["weighted_feature_" + str(i)], + data_type=data_type, ) - for i in range(num_weighted_features) + for i in range(self.num_weighted_features) ] - + self.shared_features = [f"feature_{i}" for i in range(self.num_shared_features)] self.embedding_groups = { - "group_0": ["feature_" + str(i) for i in range(num_features)] + "group_0": [ + ( + f"{feature}@{table.name}" + if feature in self.shared_features + else feature + ) + for table in self.tables + for feature in table.feature_names + ] } def _test_sharding( @@ -59,6 +135,8 @@ def _test_sharding( backend: str = "gloo", world_size: int = 2, local_size: Optional[int] = None, + world_size_2D: Optional[int] = None, + node_group_size: Optional[int] = None, constraints: Optional[Dict[str, ParameterConstraints]] = None, model_class: Type[TestSparseNNBase] = TestSparseNN, qcomms_config: Optional[QCommsConfig] = None, @@ -66,14 +144,29 @@ def _test_sharding( Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ] = None, variable_batch_size: bool = False, + variable_batch_per_feature: bool = False, + has_weighted_tables: bool = True, + global_constant_batch: bool = False, + pooling: PoolingType = PoolingType.SUM, + data_type: DataType = DataType.FP32, + use_inter_host_allreduce: bool = False, + allow_zero_batch_size: bool = False, + custom_all_reduce: bool = False, + use_offsets: bool = False, + indices_dtype: torch.dtype = torch.int64, + offsets_dtype: torch.dtype = torch.int64, + lengths_dtype: torch.dtype = torch.int64, ) -> None: + self._build_tables_and_groups(data_type=data_type) self._run_multi_process_test( callable=sharding_single_rank_test, world_size=world_size, local_size=local_size, + world_size_2D=world_size_2D, + node_group_size=node_group_size, model_class=model_class, - tables=self.tables, - weighted_tables=self.weighted_tables, + tables=self.tables if pooling == PoolingType.SUM else self.mean_tables, + weighted_tables=self.weighted_tables if has_weighted_tables else None, embedding_groups=self.embedding_groups, sharders=sharders, backend=backend, @@ -82,4 +175,792 @@ def _test_sharding( qcomms_config=qcomms_config, variable_batch_size=variable_batch_size, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_per_feature=variable_batch_per_feature, + global_constant_batch=global_constant_batch, + use_inter_host_allreduce=use_inter_host_allreduce, + allow_zero_batch_size=allow_zero_batch_size, + custom_all_reduce=custom_all_reduce, + use_offsets=use_offsets, + indices_dtype=indices_dtype, + offsets_dtype=offsets_dtype, + lengths_dtype=lengths_dtype, + ) + + +@skip_if_asan_class +class ModelParallelBase(ModelParallelTestShared): + def setUp(self, backend: str = "nccl") -> None: + super().setUp(backend=backend) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + # SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + 
"embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None) + def test_sharding_rw( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + pooling: PoolingType, + data_type: DataType, + ) -> None: + if self.backend == "gloo": + self.skipTest( + "Gloo reduce_scatter_base fallback not supported with async_op=True" + ) + + sharding_type = ShardingType.ROW_WISE.value + assume( + sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value + or not variable_batch_size + ) + + self._test_sharding( + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + pooling=pooling, + data_type=data_type, + ) + + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + # SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.just(EmbeddingComputeKernel.DENSE.value), + apply_optimizer_in_backward_config=st.just(None), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), + # TODO - need to enable optimizer overlapped behavior for data_parallel tables + ) + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + def test_sharding_dp( + self, + sharder_type: str, + kernel_type: str, + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + data_type: DataType, + ) -> None: + sharding_type = ShardingType.DATA_PARALLEL.value + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder(sharder_type, sharding_type, kernel_type), + ], + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + data_type=data_type, + ) + + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + # SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), + allow_zero_batch_size=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None) + def test_sharding_cw( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + data_type: DataType, + 
allow_zero_batch_size: bool, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.COLUMN_WISE.value + assume( + sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value + or not variable_batch_size + ) + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ], + backend=self.backend, + qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=4) + for table in self.tables + }, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + data_type=data_type, + allow_zero_batch_size=allow_zero_batch_size, + ) + + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + # SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_sharding_twcw( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + data_type: DataType, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.TABLE_COLUMN_WISE.value + assume( + sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value + or not variable_batch_size + ) + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ], + backend=self.backend, + qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=4) + for table in self.tables + }, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + data_type=data_type, + ) + + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + # SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + # None, + QCommsConfig( + forward_precision=CommType.FP16, + backward_precision=CommType.BF16, + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + 
data_type=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_sharding_tw( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + data_type: DataType, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.TABLE_WISE.value + assume( + sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value + or not variable_batch_size + ) + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ], + backend=self.backend, + qcomms_config=qcomms_config, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + data_type=data_type, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + # SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + # None, + QCommsConfig( + forward_precision=CommType.FP16, + backward_precision=CommType.BF16, + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_sharding_twrw( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + pooling: PoolingType, + data_type: DataType, + ) -> None: + if self.backend == "gloo": + self.skipTest( + "Gloo reduce_scatter_base fallback not supported with async_op=True" + ) + + sharding_type = ShardingType.TABLE_ROW_WISE.value + assume( + sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value + or not variable_batch_size + ) + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ], + backend=self.backend, + qcomms_config=qcomms_config, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + pooling=pooling, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.DATA_PARALLEL.value, + ] + ), + global_constant_batch=st.booleans(), + pooling=st.sampled_from([PoolingType.SUM, 
PoolingType.MEAN]), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_sharding_variable_batch( + self, + sharding_type: str, + global_constant_batch: bool, + pooling: PoolingType, + data_type: DataType, + ) -> None: + if self.backend == "gloo": + # error is from FBGEMM, it says CPU even if we are on GPU. + self.skipTest( + "bounds_check_indices on CPU does not support variable length (batch size)" + ) + kernel = ( + EmbeddingComputeKernel.DENSE.value + if sharding_type == ShardingType.DATA_PARALLEL.value + else EmbeddingComputeKernel.FUSED.value + ) + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + sharder_type=SharderType.EMBEDDING_BAG_COLLECTION.value, + sharding_type=sharding_type, + kernel_type=kernel, + device=self.device, + ), + ], + backend=self.backend, + constraints={ + table.name: ParameterConstraints(min_partition=4) + for table in self.tables + }, + variable_batch_per_feature=True, + has_weighted_tables=False, + global_constant_batch=global_constant_batch, + pooling=pooling, + data_type=data_type, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.just(ShardingType.COLUMN_WISE.value), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_multiple_kernels( + self, sharding_type: str, data_type: DataType + ) -> None: + if self.backend == "gloo": + self.skipTest("ProcessGroupGloo does not support reduce_scatter") + constraints = { + table.name: ParameterConstraints( + min_partition=4, + compute_kernels=( + [EmbeddingComputeKernel.FUSED.value] + if i % 2 == 0 + else [EmbeddingComputeKernel.FUSED_UVM_CACHING.value] + ), + ) + for i, table in enumerate(self.tables) + } + fused_params = {"prefetch_pipeline": True} + self._test_sharding( + # pyre-ignore[6] + sharders=[EmbeddingBagCollectionSharder(fused_params=fused_params)], + backend=self.backend, + constraints=constraints, + variable_batch_per_feature=True, + has_weighted_tables=False, + data_type=data_type, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_grid( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + sharder_type, + ShardingType.GRID_SHARD.value, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ], + world_size=4, + local_size=2, + 
backend=self.backend, + qcomms_config=qcomms_config, + constraints={ + "table_0": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_1": ParameterConstraints( + min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_2": ParameterConstraints( + min_partition=16, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_3": ParameterConstraints( + min_partition=20, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_4": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_5": ParameterConstraints( + min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_0": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_1": ParameterConstraints( + min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value] + ), + }, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least eight GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_grid_8gpu( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + sharder_type, + ShardingType.GRID_SHARD.value, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ], + world_size=8, + local_size=2, + backend=self.backend, + qcomms_config=qcomms_config, + constraints={ + "table_0": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_1": ParameterConstraints( + min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_2": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_3": ParameterConstraints( + min_partition=10, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_4": ParameterConstraints( + min_partition=4, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_5": ParameterConstraints( + min_partition=6, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_0": ParameterConstraints( + min_partition=2, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_1": ParameterConstraints( + min_partition=3, sharding_types=[ShardingType.GRID_SHARD.value] + ), + }, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + 
dtype=st.sampled_from([torch.int32, torch.int64]), + use_offsets=st.booleans(), + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + ], + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + def test_sharding_diff_table_index_type( + self, + dtype: torch.dtype, + use_offsets: bool, + sharder_type: str, + kernel_type: str, + ) -> None: + """ + Test that the model correctly handles input indices and offsets + with both int32 and int64 data types. + """ + sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type=sharder_type, + sharding_type=ShardingType.ROW_WISE.value, # or any other relevant sharding type + kernel_type=kernel_type, + device=self.device, + ), + ), + ] + # TODO - how to pass dtype so that sampled data uses different type indices/offsets? + self._test_sharding( + sharders=sharders, + backend=self.backend, + apply_optimizer_in_backward_config=None, + variable_batch_size=False, + pooling=PoolingType.SUM, + use_offsets=use_offsets, + indices_dtype=dtype, + offsets_dtype=dtype, + lengths_dtype=dtype, ) diff --git a/torchrec/distributed/test_utils/test_model_parallel_base.py b/torchrec/distributed/test_utils/test_model_parallel_base.py index b3281ace6..bd918d97e 100644 --- a/torchrec/distributed/test_utils/test_model_parallel_base.py +++ b/torchrec/distributed/test_utils/test_model_parallel_base.py @@ -5,32 +5,64 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import os import unittest -from typing import Callable, cast, Dict, List, Optional +from collections import defaultdict +from typing import Any, Callable, cast, Dict, List, Optional, OrderedDict, Tuple +import numpy as np import torch import torch.nn as nn -from torchrec.distributed.embedding_types import EmbeddingTableConfig +from fbgemm_gpu.split_embedding_configs import EmbOptimType +from hypothesis import given, settings, strategies as st, Verbosity +from torch import distributed as dist +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._tensor import DTensor +from torchrec import distributed as trec_dist +from torchrec.distributed.embedding_types import ( + EmbeddingComputeKernel, + EmbeddingTableConfig, +) +from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection +from torchrec.distributed.fused_embeddingbag import ShardedFusedEmbeddingBagCollection from torchrec.distributed.model_parallel import DistributedModelParallel from torchrec.distributed.planner import ( EmbeddingShardingPlanner, ParameterConstraints, Topology, ) +from torchrec.distributed.sharding_plan import get_default_sharders from torchrec.distributed.test_utils.test_model import ( _get_default_rtol_and_atol, ModelInput, + TestSparseNN, TestSparseNNBase, ) from torchrec.distributed.test_utils.test_sharding import ( copy_state_dict, + create_test_sharder, gen_model_and_input, ModelInputCallable, + SharderType, +) +from torchrec.distributed.types import ( + ModuleSharder, + ShardingEnv, + ShardingPlan, + ShardingType, ) -from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan -from torchrec.modules.embedding_configs import BaseEmbeddingConfig -from torchrec.test_utils import seed_and_log +from torchrec.modules.embedding_configs import ( + BaseEmbeddingConfig, + DataType, + EmbeddingBagConfig, + PoolingType, 
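# --- Illustrative sketch (not part of this diff) ----------------------------
# `test_sharding_diff_table_index_type` above threads `indices_dtype`,
# `offsets_dtype` and `lengths_dtype` through `_test_sharding`. At the input
# level this corresponds to building the sparse features with int32 (or
# int64) tensors, roughly as below; the feature name is hypothetical.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

features = KeyedJaggedTensor(
    keys=["feature_0"],
    values=torch.tensor([1, 5, 7, 2], dtype=torch.int32),  # embedding indices
    lengths=torch.tensor([2, 1, 1], dtype=torch.int32),  # per-sample lengths
)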
+) +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.fused_embedding_modules import FusedEmbeddingBagCollection +from torchrec.optim.rowwise_adagrad import RowWiseAdagrad +from torchrec.test_utils import get_free_port, seed_and_log class InferenceModelParallelTestBase(unittest.TestCase): @@ -56,10 +88,12 @@ def _test_sharded_forward( tables: List[EmbeddingTableConfig], sharders: List[ModuleSharder[nn.Module]], quantize_callable: Callable[[nn.Module], nn.Module], + quantize_callable_kwargs: Dict[str, Any], dedup_features_names: Optional[List[str]] = None, dedup_tables: Optional[List[EmbeddingTableConfig]] = None, weighted_tables: Optional[List[EmbeddingTableConfig]] = None, constraints: Optional[Dict[str, ParameterConstraints]] = None, + # pyre-ignore [9] generate: ModelInputCallable = ModelInput.generate, ) -> None: default_rank = 0 @@ -79,8 +113,10 @@ def _test_sharded_forward( dense_device=cuda_device, sparse_device=cuda_device, generate=generate, + indices_dtype=torch.int32, + lengths_dtype=torch.int32, ) - global_model = quantize_callable(global_model) + global_model = quantize_callable(global_model, **quantize_callable_kwargs) local_input = _inputs[0][1][default_rank].to(cuda_device) # Shard model. @@ -109,7 +145,7 @@ def _test_sharded_forward( sparse_device=torch.device("meta"), num_float_features=16, ) - local_model = quantize_callable(local_model) + local_model = quantize_callable(local_model, **quantize_callable_kwargs) planner = EmbeddingShardingPlanner( topology=Topology(world_size, "cuda"), @@ -119,8 +155,9 @@ def _test_sharded_forward( # Generate a sharded model on a default rank. local_model = DistributedModelParallel( - local_model, + module=local_model, env=ShardingEnv.from_local(world_size, default_rank), + device=cuda_device, plan=plan, sharders=sharders, init_data_parallel=False, @@ -129,7 +166,6 @@ def _test_sharded_forward( # materialize inference sharded model on one device for dense part local_model = local_model.copy(cuda_device) - # Load model state from the global model. copy_state_dict(local_model.state_dict(), global_model.state_dict()) # Run a single training step of the sharded model. @@ -143,3 +179,1001 @@ def _test_sharded_forward( # Compare predictions of sharded vs unsharded models. 
rtol, atol = _get_default_rtol_and_atol(global_pred, shard_pred) torch.testing.assert_close(global_pred, shard_pred, rtol=rtol, atol=atol) + + +class ModelParallelSparseOnlyBase(unittest.TestCase): + def setUp(self, backend: str = "nccl") -> None: + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["LOCAL_WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + os.environ["NCCL_SOCKET_IFNAME"] = "lo" + + self.backend = backend + if torch.cuda.is_available(): + self.device = torch.device("cuda:0") + torch.cuda.set_device(self.device) + else: + self.device = torch.device("cpu") + + if self.backend == "nccl" and self.device == torch.device("cpu"): + self.skipTest("NCCL not supported on CPUs.") + + dist.init_process_group(backend=self.backend) + + def tearDown(self) -> None: + dist.destroy_process_group() + + def test_sharding_ebc_as_top_level(self) -> None: + embedding_dim = 128 + num_embeddings = 256 + ebc = EmbeddingBagCollection( + device=torch.device("meta"), + tables=[ + EmbeddingBagConfig( + name="large_table", + embedding_dim=embedding_dim, + num_embeddings=num_embeddings, + feature_names=["my_feature"], + pooling=PoolingType.SUM, + ), + ], + ) + + model = DistributedModelParallel(ebc, device=self.device) + + self.assertTrue(isinstance(model.module, ShardedEmbeddingBagCollection)) + + def test_sharding_fused_ebc_as_top_level(self) -> None: + embedding_dim = 128 + num_embeddings = 256 + ebc = FusedEmbeddingBagCollection( + device=torch.device("meta"), + tables=[ + EmbeddingBagConfig( + name="large_table", + embedding_dim=embedding_dim, + num_embeddings=num_embeddings, + feature_names=["my_feature"], + pooling=PoolingType.SUM, + ), + ], + optimizer_type=torch.optim.SGD, + optimizer_kwargs={"lr": 0.02}, + ) + + model = DistributedModelParallel(ebc, device=self.device) + + self.assertTrue(isinstance(model.module, ShardedFusedEmbeddingBagCollection)) + + +class ModelParallelSingleRankBase(unittest.TestCase): + def setUp(self, backend: str = "nccl") -> None: + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["LOCAL_WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + os.environ["NCCL_SOCKET_IFNAME"] = "lo" + + self.backend = backend + if torch.cuda.is_available(): + self.device = torch.device("cuda:0") + torch.cuda.set_device(self.device) + else: + self.device = torch.device("cpu") + + if self.backend == "nccl" and self.device == torch.device("cpu"): + self.skipTest("NCCL not supported on CPUs.") + + dist.init_process_group(backend=backend) + + self.batch_size = 20 + self.num_float_features = 10 + self.tables = [] + self.weighted_tables = [] + + self._create_tables() + + def tearDown(self) -> None: + dist.destroy_process_group() + del os.environ["NCCL_SOCKET_IFNAME"] + super().tearDown() + + def _create_tables(self) -> None: + pass + + def _set_table_weights_precision(self, dtype: DataType) -> None: + for table in self.tables: + table.data_type = dtype + + for weighted_table in self.weighted_tables: + weighted_table.data_type = dtype + + def _create_model(self) -> nn.Module: + return TestSparseNN( + tables=self.tables, + num_float_features=self.num_float_features, + weighted_tables=self.weighted_tables, + dense_device=self.device, + sparse_device=torch.device("meta"), + ) + + def _generate_batch(self) -> ModelInput: + _, local_batch = ModelInput.generate( + batch_size=self.batch_size, + world_size=1, + 
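# --- Illustrative sketch (not part of this diff) ----------------------------
# `test_sharding_ebc_as_top_level` above builds the embedding module on
# `torch.device("meta")` and lets `DistributedModelParallel` shard and
# materialize it on the real device. A condensed version of that pattern
# (assuming a single-rank process group is already initialized, as in
# `setUp` above):
import torch
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.modules.embedding_configs import EmbeddingBagConfig, PoolingType
from torchrec.modules.embedding_modules import EmbeddingBagCollection

ebc = EmbeddingBagCollection(
    device=torch.device("meta"),
    tables=[
        EmbeddingBagConfig(
            name="large_table",
            embedding_dim=128,
            num_embeddings=256,
            feature_names=["my_feature"],
            pooling=PoolingType.SUM,
        )
    ],
)
model = DistributedModelParallel(ebc, device=torch.device("cuda:0"))
# Parameters should now be sharded and materialized on cuda:0, not "meta".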
num_float_features=self.num_float_features, + tables=self.tables, + weighted_tables=self.weighted_tables, + ) + batch = local_batch[0].to(self.device) + return batch + + def _generate_dmps_and_batch( + self, + sharders: Optional[List[ModuleSharder[nn.Module]]] = None, + constraints: Optional[Dict[str, trec_dist.planner.ParameterConstraints]] = None, + ) -> Tuple[List[DistributedModelParallel], ModelInput]: + + if constraints is None: + constraints = {} + if sharders is None: + sharders = get_default_sharders() + + batch = self._generate_batch() + + dmps = [] + pg = dist.GroupMember.WORLD + assert pg is not None, "Process group is not initialized" + env = ShardingEnv.from_process_group(pg) + + planner = EmbeddingShardingPlanner( + topology=Topology( + local_world_size=trec_dist.comm.get_local_size(env.world_size), + world_size=env.world_size, + compute_device=self.device.type, + ), + constraints=constraints, + ) + + for _ in range(2): + # Create two identical models, wrap both in DMP + m = self._create_model() + if pg is not None: + plan = planner.collective_plan(m, sharders, pg) + else: + plan = planner.plan(m, sharders) + + dmp = DistributedModelParallel( + module=m, + init_data_parallel=False, + device=self.device, + sharders=sharders, + plan=plan, + ) + + with torch.no_grad(): + dmp(batch) + dmp.init_data_parallel() + dmps.append(dmp) + return (dmps, batch) + + def _train_models( + self, + m1: DistributedModelParallel, + m2: DistributedModelParallel, + batch: ModelInput, + ) -> None: + loss1, pred1 = m1(batch) + loss2, pred2 = m2(batch) + loss1.backward() + loss2.backward() + + def _eval_models( + self, + m1: DistributedModelParallel, + m2: DistributedModelParallel, + batch: ModelInput, + is_deterministic: bool = True, + ) -> None: + with torch.no_grad(): + loss1, pred1 = m1(batch) + loss2, pred2 = m2(batch) + + if is_deterministic: + self.assertTrue(torch.equal(loss1, loss2)) + self.assertTrue(torch.equal(pred1, pred2)) + else: + rtol, atol = _get_default_rtol_and_atol(loss1, loss2) + torch.testing.assert_close(loss1, loss2, rtol=rtol, atol=atol) + rtol, atol = _get_default_rtol_and_atol(pred1, pred2) + torch.testing.assert_close(pred1, pred2, rtol=rtol, atol=atol) + + def _compare_models( + self, + m1: DistributedModelParallel, + m2: DistributedModelParallel, + is_deterministic: bool = True, + ) -> None: + sd1 = m1.state_dict() + for key, value in m2.state_dict().items(): + v2 = sd1[key] + if isinstance(value, ShardedTensor): + assert isinstance(v2, ShardedTensor) + self.assertEqual(len(value.local_shards()), len(v2.local_shards())) + for dst, src in zip(value.local_shards(), v2.local_shards()): + if is_deterministic: + self.assertTrue(torch.equal(src.tensor, dst.tensor)) + else: + rtol, atol = _get_default_rtol_and_atol(src.tensor, dst.tensor) + torch.testing.assert_close( + src.tensor, dst.tensor, rtol=rtol, atol=atol + ) + elif isinstance(value, DTensor): + assert isinstance(v2, DTensor) + self.assertEqual( + len(value._local_tensor.local_shards()), # pyre-ignore[16] + len(v2._local_tensor.local_shards()), + ) + for dst, src in zip( + value._local_tensor.local_shards(), v2._local_tensor.local_shards() + ): + if is_deterministic: + self.assertTrue(torch.equal(src, dst)) + else: + rtol, atol = _get_default_rtol_and_atol(src, dst) + torch.testing.assert_close( + src._local_tensor, dst._local_tensor, rtol=rtol, atol=atol + ) + else: + dst = value + src = v2 + if is_deterministic: + self.assertTrue(torch.equal(src, dst)) + else: + rtol, atol = _get_default_rtol_and_atol(src, dst) + 
torch.testing.assert_close(src, dst, rtol=rtol, atol=atol) + + +class ModelParallelStateDictBase(ModelParallelSingleRankBase): + def _create_tables(self) -> None: + num_features = 4 + num_weighted_features = 2 + + self.tables += [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + self.weighted_tables += [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(num_weighted_features) + ] + + def test_parameter_init(self) -> None: + class MyModel(nn.Module): + def __init__(self, device: str, val: float) -> None: + super().__init__() + self.p = nn.Parameter( + torch.empty(3, dtype=torch.float32, device=device) + ) + self.val = val + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.constant_(self.p, self.val) + + # Check that already allocated parameters are left 'as is'. + unsharded_model = MyModel(device=self.device, val=3.2) + sharded_model = DistributedModelParallel( + unsharded_model, + device=self.device, + ) + sharded_param = next(sharded_model.parameters()) + np.testing.assert_array_equal( + np.array([3.2, 3.2, 3.2], dtype=np.float32), + sharded_param.detach().cpu().numpy(), + ) + + # Check that parameters over 'meta' device are allocated and initialized. + meta_model = MyModel(device="meta", val=7.5) + sharded_model = DistributedModelParallel( + meta_model, + device=self.device, + ) + sharded_param = next(sharded_model.parameters()) + np.testing.assert_array_equal( + np.array([7.5, 7.5, 7.5], dtype=np.float32), + sharded_param.detach().cpu().numpy(), + ) + + def test_meta_device_dmp_state_dict(self) -> None: + # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. 
+ env = ShardingEnv.from_process_group(dist.GroupMember.WORLD) + + m1 = self._create_model() + # dmp with real device + dmp1 = DistributedModelParallel( + module=m1, + init_data_parallel=False, + init_parameters=False, + sharders=get_default_sharders(), + device=self.device, + env=env, + plan=EmbeddingShardingPlanner( + topology=Topology( + world_size=env.world_size, compute_device=self.device.type + ) + ).plan(m1, get_default_sharders()), + ) + + m2 = self._create_model() + # dmp with meta device + dmp2 = DistributedModelParallel( + module=m2, + init_data_parallel=False, + init_parameters=False, + sharders=get_default_sharders(), + device=torch.device("meta"), + env=env, + plan=EmbeddingShardingPlanner( + topology=Topology( + world_size=env.world_size, compute_device=self.device.type + ) + ).plan(m2, get_default_sharders()), + ) + + sd1 = dmp1.state_dict() + for key, v2 in dmp2.state_dict().items(): + v1 = sd1[key] + if isinstance(v2, nn.parameter.UninitializedParameter) and isinstance( + v1, nn.parameter.UninitializedParameter + ): + continue + if isinstance(v2, ShardedTensor): + self.assertTrue(isinstance(v1, ShardedTensor)) + assert len(v2.local_shards()) == 1 + dst = v2.local_shards()[0].tensor + elif isinstance(v2, DTensor): + self.assertTrue(isinstance(v1, DTensor)) + assert len(v2._local_tensor.local_shards()) == 1 # pyre-ignore[16] + dst = v2._local_tensor.local_shards()[0] + else: + dst = v2 + if isinstance(v1, ShardedTensor): + assert len(v1.local_shards()) == 1 + src = v1.local_shards()[0].tensor + elif isinstance(v1, DTensor): + assert len(v1._local_tensor.local_shards()) == 1 + src = v1._local_tensor.local_shards()[0] + else: + src = v1 + self.assertEqual(src.size(), dst.size()) + + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ] + ), + is_training=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_load_state_dict( + self, + sharder_type: str, + sharding_type: str, + kernel_type: str, + is_training: bool, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + ), + ), + ] + models, batch = self._generate_dmps_and_batch(sharders) + m1, m2 = models + + # load the second's (m2's) with the first (m1's) state_dict + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + + # validate the models are equivalent + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch) + self._compare_models(m1, m2) + + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + 
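# --- Illustrative sketch (not part of this diff) ----------------------------
# The state-dict checks above (`_compare_models`,
# `test_meta_device_dmp_state_dict`) normalize each entry to its local
# shard(s) before comparing. A small helper in the same spirit; it assumes
# the layouts the tests rely on: ShardedTensor shards expose `.tensor`, and
# the DTensor local tensor is torchrec's shard wrapper with `local_shards()`.
from typing import List

import torch
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor


def local_tensors(value: object) -> List[torch.Tensor]:
    if isinstance(value, ShardedTensor):
        return [shard.tensor for shard in value.local_shards()]
    if isinstance(value, DTensor):
        return list(value.to_local().local_shards())  # LocalShardsWrapper
    assert isinstance(value, torch.Tensor)
    return [value]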
EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_optimizer_load_state_dict( + self, + sharder_type: str, + sharding_type: str, + kernel_type: str, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params={ + "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD, + }, + ), + ), + ] + models, batch = self._generate_dmps_and_batch(sharders) + m1, m2 = models + + # train m1 a bit, to make sure the optimizer state is not zero + self._train_models(m1, m1, batch) + # sync the state dict + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + # train both models, so they should diverage + self._train_models(m1, m2, batch) + # expect eval models to fail, since one model starts with non-zero optimizer state + with self.assertRaises(AssertionError): + self._eval_models(m1, m2, batch) + + # sync state dict again + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + # load state dict for optimizer as well + opt1 = m1.fused_optimizer + opt2 = m2.fused_optimizer + opt1.load_state_dict(opt2.state_dict()) + + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch) + self._compare_models(m1, m2) + + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.DATA_PARALLEL.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.DENSE.value, + ] + ), + is_training=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_load_state_dict_dp( + self, sharder_type: str, sharding_type: str, kernel_type: str, is_training: bool + ) -> None: + sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + ), + ), + ] + models, batch = self._generate_dmps_and_batch(sharders) + m1, m2 = models + + # load the second's (m2's) with the first (m1's) state_dict + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + + # validate the models are equivalent + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch) + self._compare_models(m1, m2) + + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ] + ), + is_training=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_load_state_dict_prefix( + self, sharder_type: str, sharding_type: str, kernel_type: str, is_training: bool + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharders = [ + cast( + 
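# --- Illustrative sketch (not part of this diff) ----------------------------
# `test_optimizer_load_state_dict` above shows that copying the module
# state_dict alone is not enough once the row-wise Adagrad state is non-zero:
# the fused (in-backward) optimizer state carried by the sharded modules must
# be loaded separately. A helper capturing both steps, assuming `src` and
# `dst` are two DistributedModelParallel replicas built from the same plan:
from torchrec.distributed.model_parallel import DistributedModelParallel


def sync_replica(src: DistributedModelParallel, dst: DistributedModelParallel) -> None:
    dst.load_state_dict(src.state_dict())  # embedding weights
    dst.fused_optimizer.load_state_dict(src.fused_optimizer.state_dict())  # optimizer state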
ModuleSharder[nn.Module], + create_test_sharder(sharder_type, sharding_type, kernel_type), + ), + ] + (m1, m2), batch = self._generate_dmps_and_batch(sharders) + + # load the second's (m2's) with the first (m1's) state_dict + m2.load_state_dict( + cast("OrderedDict[str, torch.Tensor]", m1.state_dict(prefix="alpha")), + prefix="alpha", + ) + + # validate the models are equivalent + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch) + self._compare_models(m1, m2) + + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + # EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_params_and_buffers( + self, sharder_type: str, sharding_type: str, kernel_type: str + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharders = [ + create_test_sharder(sharder_type, sharding_type, kernel_type), + ] + # pyre-ignore[6] + (m, _), batch = self._generate_dmps_and_batch(sharders=sharders) + print(f"Sharding Plan: {m._plan}") + state_dict_keys = set(m.state_dict().keys()) + param_keys = set(dict(m.named_parameters()).keys()) + buffer_keys = set(dict(m.named_buffers()).keys()) + self.assertEqual(state_dict_keys, {*param_keys, *buffer_keys}) + + # pyre-ignore + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.COLUMN_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ] + ), + is_training=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_load_state_dict_cw_multiple_shards( + self, sharder_type: str, sharding_type: str, kernel_type: str, is_training: bool + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params={ + "learning_rate": 0.2, + "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD, + }, + ), + ), + ] + + constraints = defaultdict(lambda: trec_dist.planner.ParameterConstraints()) + num_cw_shards_per_table = {} + for table in self.tables + self.weighted_tables: + constraints[table.name].min_partition = 4 + num_cw_shards_per_table[table.name] = table.embedding_dim // 4 + + (m1, m2), batch = self._generate_dmps_and_batch( + sharders, constraints=constraints + ) + + # load the second's (m2's) with the first (m1's) state_dict + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + + # load optimizer state dict + + # Check to see that we can load optimizer state + src_optimizer = m1.fused_optimizer + dst_optimizer = m2.fused_optimizer + + src_optimizer_state_dict = src_optimizer.state_dict() + 
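# --- Illustrative sketch (not part of this diff) ----------------------------
# `test_load_state_dict_cw_multiple_shards` above forces `min_partition=4`, so
# a column-wise sharded table on this single rank splits into
# `embedding_dim // 4` local shards. For the tables from `_create_tables`
# (embedding_dim = 4, 8, 12, 16) that is 1, 2, 3 and 4 shards respectively:
for i in range(4):
    embedding_dim = (i + 1) * 4
    assert embedding_dim // 4 == i + 1  # expected number of CW shards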
dst_optimizer_state_dict = dst_optimizer.state_dict() + m2.fused_optimizer.load_state_dict(src_optimizer_state_dict) + + # validate the models are equivalent + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch) + + sd1 = m1.state_dict() + for key, value in m2.state_dict().items(): + if "." in key: + table_name = key.split(".")[-2] + v2 = sd1[key] + if isinstance(value, ShardedTensor): + self.assertEqual( + len(value.local_shards()), num_cw_shards_per_table[table_name] + ) + dst = value.local_shards()[0].tensor + elif isinstance(value, DTensor): + self.assertEqual( + len(value._local_tensor.local_shards()), # pyre-ignore[16] + num_cw_shards_per_table[table_name], + ) + dst = value._local_tensor.local_shards()[0] + else: + dst = value + + if isinstance(v2, ShardedTensor): + self.assertEqual( + len(value.local_shards()), num_cw_shards_per_table[table_name] + ) + + for src_local_shard, dst_local_shard in zip( + value.local_shards(), v2.local_shards() + ): + self.assertTrue( + torch.equal(src_local_shard.tensor, dst_local_shard.tensor) + ) + elif isinstance(v2, DTensor): + self.assertEqual( + len(value._local_tensor.local_shards()), + num_cw_shards_per_table[table_name], + ) + + for src_local_shard, dst_local_shard in zip( + value._local_tensor.local_shards(), + v2._local_tensor.local_shards(), + ): + self.assertTrue(torch.equal(src_local_shard, dst_local_shard)) + else: + src = v2 + self.assertTrue(torch.equal(src, dst)) + + for param_name, dst_param_group in dst_optimizer_state_dict.items(): + src_param_group = src_optimizer_state_dict[param_name] + + for state_key, dst_opt_state in dst_param_group.items(): + table_name = state_key.split(".")[-2] + src_opt_state = src_param_group[state_key] + if isinstance(dst_opt_state, ShardedTensor): + self.assertIsInstance(src_param_group[state_key], ShardedTensor) + + self.assertEqual( + len(dst_opt_state.local_shards()), + num_cw_shards_per_table[table_name], + ) + + self.assertEqual( + len(src_opt_state.local_shards()), + num_cw_shards_per_table[table_name], + ) + + for src_local_shard, dst_local_shard in zip( + src_opt_state.local_shards(), dst_opt_state.local_shards() + ): + self.assertTrue( + torch.equal(src_local_shard.tensor, dst_local_shard.tensor) + ) + elif isinstance(dst_opt_state, DTensor): + self.assertIsInstance(src_opt_state, DTensor) + + self.assertEqual( + len(dst_opt_state._local_tensor.local_shards()), + num_cw_shards_per_table[table_name], + ) + + self.assertEqual( + len(src_opt_state._local_tensor.local_shards()), + num_cw_shards_per_table[table_name], + ) + + for src_local_shard, dst_local_shard in zip( + src_opt_state._local_tensor.local_shards(), + dst_opt_state._local_tensor.local_shards(), + ): + self.assertTrue(torch.equal(src_local_shard, dst_local_shard)) + elif isinstance(dst_opt_state, torch.Tensor): + self.assertIsInstance(src_opt_state, torch.Tensor) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ] + ), + is_training=st.booleans(), + 
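# --- Illustrative sketch (not part of this diff) ----------------------------
# The kernel-equivalence test below samples `stochastic_rounding`. With it
# enabled, low-precision (FP16) weight updates are not bit-wise reproducible
# across compute kernels, so `_eval_models` / `_compare_models` above fall
# back to tolerance-based comparison instead of exact equality. In miniature:
import torch

a = torch.tensor([1.0000, 2.0000])
b = torch.tensor([1.0001, 1.9999])  # e.g. rounding noise
torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-3)  # passes
assert not torch.equal(a, b)  # exact comparison would fail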
stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_numerical_equivalence_between_kernel_types( + self, + sharder_type: str, + sharding_type: str, + kernel_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + self._set_table_weights_precision(dtype) + fused_params = { + "stochastic_rounding": stochastic_rounding, + "cache_precision": dtype, + } + + fused_sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + EmbeddingComputeKernel.FUSED.value, + fused_params=fused_params, + ), + ), + ] + sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params=fused_params, + ), + ), + ] + (fused_model, _), _ = self._generate_dmps_and_batch(fused_sharders) + (model, _), batch = self._generate_dmps_and_batch(sharders) + + # load the baseline model's state_dict onto the new model + model.load_state_dict( + cast("OrderedDict[str, torch.Tensor]", fused_model.state_dict()) + ) + + if is_training: + for _ in range(4): + self._train_models(fused_model, model, batch) + self._eval_models( + fused_model, model, batch, is_deterministic=not stochastic_rounding + ) + self._compare_models( + fused_model, model, is_deterministic=not stochastic_rounding + ) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_rowwise_adagrad_numerical_equivalence( + self, + sharder_type: str, + sharding_type: str, + kernel_type: str, + ) -> None: + learning_rate = 0.1 + fused_params = { + "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD, + "learning_rate": learning_rate, + } + + fused_sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + EmbeddingComputeKernel.FUSED.value, + fused_params=fused_params, + ), + ), + ] + dense_sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + EmbeddingComputeKernel.DENSE.value, + fused_params=fused_params, + ), + ), + ] + (fused_model, _), _ = self._generate_dmps_and_batch(fused_sharders) + (dense_model, _), batch = self._generate_dmps_and_batch(dense_sharders) + + dense_opt = RowWiseAdagrad( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `parameters`. 
+ dense_model.module.sparse.parameters(), + lr=learning_rate, + eps=1e-8, # TBE has default eps 1e-8 + ) + + # load the baseline model's state_dict onto the new model + dense_model.load_state_dict( + cast("OrderedDict[str, torch.Tensor]", fused_model.state_dict()) + ) + + for _ in range(4): + dense_opt.zero_grad() + loss1, pred1 = fused_model(batch) + loss2, pred2 = dense_model(batch) + loss1.backward() + loss2.backward() + dense_opt.step() + + self._eval_models(fused_model, dense_model, batch, is_deterministic=False) + self._compare_models(fused_model, dense_model, is_deterministic=False) diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index 61e031d57..e4a8469c2 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -5,20 +5,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + +import random from enum import Enum -from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union +from typing import Any, cast, Dict, List, Optional, Protocol, Tuple, Type, Union import torch import torch.distributed as dist import torch.nn as nn from fbgemm_gpu.split_embedding_configs import EmbOptimType +from torch.distributed._tensor import DeviceMesh, DTensor +from torch.distributed.optim import ( + _apply_optimizer_in_backward as apply_optimizer_in_backward, +) from torchrec.distributed.embedding_types import EmbeddingTableConfig from torchrec.distributed.fbgemm_qcomm_codec import ( CommType, get_qcomm_codecs_registry, QCommsConfig, ) -from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.model_parallel import DistributedModelParallel, DMPCollection from torchrec.distributed.planner import ( EmbeddingShardingPlanner, ParameterConstraints, @@ -34,16 +41,21 @@ TestSparseNNBase, ) from torchrec.distributed.types import ( + EmbeddingModuleShardingPlan, ModuleSharder, ShardedTensor, ShardingEnv, ShardingPlan, ShardingType, ) -from torchrec.modules.embedding_configs import BaseEmbeddingConfig, EmbeddingBagConfig -from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward +from torchrec.modules.embedding_configs import ( + BaseEmbeddingConfig, + DataType, + EmbeddingBagConfig, +) +from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper -from typing_extensions import Protocol +from torchrec.optim.optimizers import in_backward_optimizer_filter class SharderType(Enum): @@ -60,7 +72,6 @@ def create_test_sharder( fused_params: Optional[Dict[str, Any]] = None, qcomms_config: Optional[QCommsConfig] = None, device: Optional[torch.device] = None, - variable_batch_size: bool = False, ) -> Union[TestEBSharder, TestEBCSharder, TestETSharder, TestETCSharder]: if fused_params is None: fused_params = {} @@ -79,7 +90,6 @@ def create_test_sharder( kernel_type, fused_params, qcomm_codecs_registry, - variable_batch_size, ) elif sharder_type == SharderType.EMBEDDING_TOWER.value: return TestETSharder( @@ -106,29 +116,28 @@ def __call__( Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]] ] = None, variable_batch_size: bool = False, - ) -> Tuple["ModelInput", List["ModelInput"]]: - ... 
+ use_offsets: bool = False, + indices_dtype: torch.dtype = torch.int64, + offsets_dtype: torch.dtype = torch.int64, + lengths_dtype: torch.dtype = torch.int64, + ) -> Tuple["ModelInput", List["ModelInput"]]: ... -def generate_inputs( - world_size: int, - tables: List[EmbeddingTableConfig], - dedup_tables: List[EmbeddingTableConfig], - generate: ModelInputCallable = ModelInput.generate, - weighted_tables: Optional[List[EmbeddingTableConfig]] = None, - batch_size: int = 4, - num_float_features: int = 16, - variable_batch_size: bool = False, -) -> Tuple[ModelInput, List[ModelInput]]: - return generate( - batch_size=batch_size, - world_size=world_size, - num_float_features=num_float_features, - tables=tables, - dedup_tables=dedup_tables, - weighted_tables=weighted_tables or [], - variable_batch_size=variable_batch_size, - ) +class VariableBatchModelInputCallable(Protocol): + def __call__( + self, + average_batch_size: int, + world_size: int, + num_float_features: int, + tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]], + weighted_tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]], + pooling_avg: int = 10, + global_constant_batch: bool = False, + use_offsets: bool = False, + indices_dtype: torch.dtype = torch.int64, + offsets_dtype: torch.dtype = torch.int64, + lengths_dtype: torch.dtype = torch.int64, + ) -> Tuple["ModelInput", List["ModelInput"]]: ... def gen_model_and_input( @@ -136,7 +145,10 @@ def gen_model_and_input( tables: List[EmbeddingTableConfig], embedding_groups: Dict[str, List[str]], world_size: int, - generate: ModelInputCallable = ModelInput.generate, + # pyre-ignore [9] + generate: Union[ + ModelInputCallable, VariableBatchModelInputCallable + ] = ModelInput.generate, weighted_tables: Optional[List[EmbeddingTableConfig]] = None, num_float_features: int = 16, dense_device: Optional[torch.device] = None, @@ -145,6 +157,14 @@ def gen_model_and_input( dedup_tables: Optional[List[EmbeddingTableConfig]] = None, variable_batch_size: bool = False, batch_size: int = 4, + feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, + use_offsets: bool = False, + indices_dtype: torch.dtype = torch.int64, + offsets_dtype: torch.dtype = torch.int64, + lengths_dtype: torch.dtype = torch.int64, + global_constant_batch: bool = False, + num_inputs: int = 1, + input_type: str = "kjt", # "kjt" or "td" ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]: torch.manual_seed(0) if dedup_feature_names: @@ -159,6 +179,7 @@ def gen_model_and_input( embedding_groups=embedding_groups, dense_device=dense_device, sparse_device=sparse_device, + feature_processor_modules=feature_processor_modules, ) else: model = model_class( @@ -171,28 +192,71 @@ def gen_model_and_input( embedding_groups=embedding_groups, dense_device=dense_device, sparse_device=sparse_device, + feature_processor_modules=feature_processor_modules, ) - inputs = [ - generate_inputs( - world_size=world_size, - tables=tables, - # pyre-ignore [6] - dedup_tables=dedup_tables, - generate=generate, - weighted_tables=weighted_tables, - num_float_features=num_float_features, - variable_batch_size=variable_batch_size, - batch_size=batch_size, - ) - ] + inputs = [] + if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input: + for _ in range(num_inputs): + inputs.append( + cast(VariableBatchModelInputCallable, generate)( + average_batch_size=batch_size, + world_size=world_size, + num_float_features=num_float_features, + tables=tables, + weighted_tables=weighted_tables or 
[], + global_constant_batch=global_constant_batch, + use_offsets=use_offsets, + indices_dtype=indices_dtype, + offsets_dtype=offsets_dtype, + lengths_dtype=lengths_dtype, + ) + ) + elif generate == ModelInput.generate: + for _ in range(num_inputs): + inputs.append( + ModelInput.generate( + world_size=world_size, + tables=tables, + dedup_tables=dedup_tables, + weighted_tables=weighted_tables or [], + num_float_features=num_float_features, + variable_batch_size=variable_batch_size, + batch_size=batch_size, + input_type=input_type, + use_offsets=use_offsets, + indices_dtype=indices_dtype, + offsets_dtype=offsets_dtype, + lengths_dtype=lengths_dtype, + ) + ) + else: + for _ in range(num_inputs): + inputs.append( + cast(ModelInputCallable, generate)( + world_size=world_size, + tables=tables, + dedup_tables=dedup_tables, + weighted_tables=weighted_tables or [], + num_float_features=num_float_features, + variable_batch_size=variable_batch_size, + batch_size=batch_size, + use_offsets=use_offsets, + indices_dtype=indices_dtype, + offsets_dtype=offsets_dtype, + lengths_dtype=lengths_dtype, + ) + ) return (model, inputs) def copy_state_dict( - loc: Dict[str, Union[torch.Tensor, ShardedTensor]], + loc: Dict[str, Union[torch.Tensor, ShardedTensor, DTensor]], glob: Dict[str, torch.Tensor], exclude_predfix: Optional[str] = None, ) -> None: + """ + Copies the contents of the global tensors in glob to the local tensors in loc. + """ for name, tensor in loc.items(): if exclude_predfix is not None and name.startswith(exclude_predfix): continue @@ -201,9 +265,18 @@ def copy_state_dict( global_tensor = glob[name] if isinstance(global_tensor, ShardedTensor): global_tensor = global_tensor.local_shards()[0].tensor + if isinstance(global_tensor, DTensor): + # pyre-ignore[16] + global_tensor = global_tensor.to_local().local_shards()[0] + if isinstance(tensor, ShardedTensor): for local_shard in tensor.local_shards(): - assert global_tensor.ndim == local_shard.tensor.ndim + assert ( + global_tensor.ndim == local_shard.tensor.ndim + ), f"global_tensor.ndim: {global_tensor.ndim}, local_shard.tensor.ndim: {local_shard.tensor.ndim}" + assert ( + global_tensor.dtype == local_shard.tensor.dtype + ), f"global tensor dtype: {global_tensor.dtype}, local tensor dtype: {local_shard.tensor.dtype}" shard_meta = local_shard.metadata t = global_tensor.detach() if t.ndim == 1: @@ -221,10 +294,43 @@ def copy_state_dict( else: raise ValueError("Tensors with ndim > 2 are not supported") local_shard.tensor.copy_(t) + elif isinstance(tensor, DTensor): + for local_shard, global_offset in zip( + tensor.to_local().local_shards(), + tensor.to_local().local_offsets(), # pyre-ignore[16] + ): + assert ( + global_tensor.ndim == local_shard.ndim + ), f"global_tensor.ndim: {global_tensor.ndim}, local_shard.ndim: {local_shard.ndim}" + assert ( + global_tensor.dtype == local_shard.dtype + ), f"global_tensor.dtype: {global_tensor.dtype}, local_shard.dtype: {local_shard.tensor.dtype}" + + t = global_tensor.detach() + local_shape = local_shard.shape + if t.ndim == 1: + t = t[global_offset[0] : global_offset[0] + local_shape[0]] + elif t.ndim == 2: + t = t[ + global_offset[0] : global_offset[0] + local_shape[0], + global_offset[1] : global_offset[1] + local_shape[1], + ] + else: + raise ValueError("Tensors with ndim > 2 are not supported") + local_shard.copy_(t) else: tensor.copy_(global_tensor) +# alter the ebc dtype to float32 in-place. 
+def alter_global_ebc_dtype(model: nn.Module) -> None: + for _name, ebc in model.named_modules(): + if isinstance(ebc, EmbeddingBagCollection) and ebc._is_weighted: + with torch.no_grad(): + for bag in ebc.embedding_bags.values(): + bag.weight = torch.nn.Parameter(bag.weight.float()) + + def sharding_single_rank_test( rank: int, world_size: int, @@ -241,21 +347,52 @@ def sharding_single_rank_test( apply_optimizer_in_backward_config: Optional[ Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ] = None, - variable_batch_size: bool = False, + variable_batch_size: bool = False, # variable batch per rank batch_size: int = 4, + feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, + variable_batch_per_feature: bool = False, # VBE + global_constant_batch: bool = False, + world_size_2D: Optional[int] = None, # 2D parallel + node_group_size: Optional[int] = None, # 2D parallel + use_inter_host_allreduce: bool = False, # 2D parallel + input_type: str = "kjt", # "kjt" or "td" + allow_zero_batch_size: bool = False, + custom_all_reduce: bool = False, # 2D parallel + use_offsets: bool = False, + indices_dtype: torch.dtype = torch.int64, + offsets_dtype: torch.dtype = torch.int64, + lengths_dtype: torch.dtype = torch.int64, ) -> None: - with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + batch_size = ( + random.randint(0, batch_size) if allow_zero_batch_size else batch_size + ) # Generate model & inputs. (global_model, inputs) = gen_model_and_input( model_class=model_class, tables=tables, + # pyre-ignore [6] + generate=( + cast( + VariableBatchModelInputCallable, + ModelInput.generate_variable_batch_input, + ) + if variable_batch_per_feature + else ModelInput.generate + ), weighted_tables=weighted_tables, embedding_groups=embedding_groups, world_size=world_size, num_float_features=16, variable_batch_size=variable_batch_size, batch_size=batch_size, + feature_processor_modules=feature_processor_modules, + global_constant_batch=global_constant_batch, + input_type=input_type, + use_offsets=use_offsets, + indices_dtype=indices_dtype, + offsets_dtype=offsets_dtype, + lengths_dtype=lengths_dtype, ) global_model = global_model.to(ctx.device) global_input = inputs[0][0].to(ctx.device) @@ -269,6 +406,7 @@ def sharding_single_rank_test( dense_device=ctx.device, sparse_device=torch.device("meta"), num_float_features=16, + feature_processor_modules=feature_processor_modules, ) global_model_named_params_as_dict = dict(global_model.named_parameters()) @@ -280,20 +418,25 @@ def sharding_single_rank_test( optimizer_kwargs, ) in apply_optimizer_in_backward_config.items(): for name, param in global_model_named_params_as_dict.items(): - if name not in apply_optim_name: + if apply_optim_name not in name: continue assert name in local_model_named_params_as_dict local_param = local_model_named_params_as_dict[name] apply_optimizer_in_backward( - optimizer_type, [param], optimizer_kwargs + optimizer_type, + [param], + optimizer_kwargs, ) apply_optimizer_in_backward( optimizer_type, [local_param], optimizer_kwargs ) + # For 2D parallelism, we use single group world size and local world size planner = EmbeddingShardingPlanner( topology=Topology( - world_size, ctx.device.type, local_world_size=ctx.local_size + world_size=world_size_2D if world_size_2D else world_size, + compute_device=ctx.device.type, + local_world_size=node_group_size if node_group_size else ctx.local_size, ), constraints=constraints, ) @@ -308,14 +451,16 @@ def sharding_single_rank_test( TODO: may need to add 
some checks that only does this if we're running on a single GPU (which should be most cases). """ - for group in plan.plan: - for _, parameter_sharding in plan.plan[group].items(): + for _, parameter_sharding in cast( + EmbeddingModuleShardingPlan, plan.plan[group] + ).items(): if ( parameter_sharding.sharding_type in { ShardingType.TABLE_ROW_WISE.value, ShardingType.TABLE_COLUMN_WISE.value, + ShardingType.GRID_SHARD.value, } and ctx.device.type != "cpu" ): @@ -330,55 +475,142 @@ def sharding_single_rank_test( f"rank:{rank}/cuda:{rank}" ) - local_model = DistributedModelParallel( - local_model, - env=ShardingEnv.from_process_group(ctx.pg), - plan=plan, - sharders=sharders, - device=ctx.device, - ) + assert ctx.pg is not None + hook_called: bool = False + if world_size_2D is not None: + all_reduce_func = None + if custom_all_reduce: + all_reduce_pg: dist.ProcessGroup = create_device_mesh_for_2D( + use_inter_host_allreduce, + world_size=ctx.world_size, + local_size=world_size_2D, + ).get_group(mesh_dim="replicate") + + def _custom_hook(input: List[torch.Tensor]) -> None: + nonlocal hook_called + opts = dist.AllreduceCoalescedOptions() + opts.reduceOp = dist.ReduceOp.AVG + handle = all_reduce_pg.allreduce_coalesced(input, opts=opts) + handle.wait() + hook_called = True + + all_reduce_func = _custom_hook + + local_model = DMPCollection( + module=local_model, + sharding_group_size=world_size_2D, + world_size=ctx.world_size, + global_pg=ctx.pg, # pyre-ignore[6] + node_group_size=node_group_size, + plan=plan, + sharders=sharders, + device=ctx.device, + use_inter_host_allreduce=use_inter_host_allreduce, + custom_all_reduce=all_reduce_func, + ) + else: + local_model = DistributedModelParallel( + local_model, + env=ShardingEnv.from_process_group(ctx.pg), + plan=plan, + sharders=sharders, + device=ctx.device, + ) dense_optim = KeyedOptimizerWrapper( - dict(local_model.named_parameters()), + dict(in_backward_optimizer_filter(local_model.named_parameters())), lambda params: torch.optim.SGD(params, lr=0.1), ) local_opt = CombinedOptimizer([local_model.fused_optimizer, dense_optim]) # Load model state from the global model. - copy_state_dict(local_model.state_dict(), global_model.state_dict()) + copy_state_dict( + local_model.state_dict(), + global_model.state_dict(), + exclude_predfix="sparse.pooled_embedding_arch.embedding_modules._itp_iter", + ) + alter_global_ebc_dtype(global_model) # Run a single training step of the sharded model. - local_pred = gen_full_pred_after_one_step(local_model, local_opt, local_input) + local_pred = gen_full_pred_after_one_step( + local_model, + local_opt, + local_input, + ) - all_local_pred = [] - for _ in range(world_size): - all_local_pred.append(torch.empty_like(local_pred)) - dist.all_gather(all_local_pred, local_pred, group=ctx.pg) + if world_size_2D is not None and custom_all_reduce: + assert hook_called, "custom all reduce hook was not called" - # Run second training step of the unsharded model. - assert optim == EmbOptimType.EXACT_SGD - global_opt = torch.optim.SGD(global_model.parameters(), lr=0.1) + # TODO: support non-sharded forward with zero batch size KJT + if not allow_zero_batch_size: + all_local_pred = [] + for _ in range(world_size): + all_local_pred.append(torch.empty_like(local_pred)) + dist.all_gather(all_local_pred, local_pred, group=ctx.pg) - global_pred = gen_full_pred_after_one_step( - global_model, global_opt, global_input - ) + # Run second training step of the unsharded model. 
+ assert optim == EmbOptimType.EXACT_SGD + global_opt = torch.optim.SGD(global_model.parameters(), lr=0.1) - # Compare predictions of sharded vs unsharded models. - if qcomms_config is None: - torch.testing.assert_close(global_pred, torch.cat(all_local_pred)) - else: - # With quantized comms, we can relax constraints a bit - rtol = 0.003 - if CommType.FP8 in [ - qcomms_config.forward_precision, - qcomms_config.backward_precision, - ]: - rtol = 0.05 - atol = global_pred.max().item() * rtol - torch.testing.assert_close( - global_pred, torch.cat(all_local_pred), rtol=rtol, atol=atol + global_pred = gen_full_pred_after_one_step( + global_model, global_opt, global_input ) + # Compare predictions of sharded vs unsharded models. + if qcomms_config is not None: + # With quantized comms, we can relax constraints a bit + rtol = 0.003 + if CommType.FP8 in [ + qcomms_config.forward_precision, + qcomms_config.backward_precision, + ]: + rtol = 0.05 + atol = global_pred.max().item() * rtol + torch.testing.assert_close( + global_pred, torch.cat(all_local_pred), rtol=rtol, atol=atol + ) + elif ( + weighted_tables is not None + and weighted_tables[0].data_type == DataType.FP16 + ): + # we relax this accuracy test because when the embedding table weights is FP16, + # the sharded EBC would upscale the precision to FP32 for the returned embedding + # KJT.weights (FP32) + sharded_EBC (FP16) ==> embeddings (FP32) + # the test uses the unsharded EBC for reference to compare the results, but the unsharded EBC + # uses EmbeddingBags can only handle same precision, i.e., + # KJT.weights (FP32) + unsharded_EBC (FP32) ==> embeddings (FP32) + # therefore, the discrepancy leads to a relaxed tol level. + torch.testing.assert_close( + global_pred, + torch.cat(all_local_pred), + atol=1e-4, # relaxed atol due to FP16 in weights + rtol=1e-4, # relaxed rtol due to FP16 in weights + ) + else: + torch.testing.assert_close(global_pred, torch.cat(all_local_pred)) + + +def create_device_mesh_for_2D( + use_inter_host_allreduce: bool, world_size: int, local_size: int +) -> DeviceMesh: + if use_inter_host_allreduce: + peer_matrix = [ + list(range(i, i + local_size)) for i in range(0, world_size, local_size) + ] + else: + peer_matrix = [] + step = world_size // local_size + for group_rank in range(world_size // local_size): + peer_matrix.append([step * r + group_rank for r in range(local_size)]) + + mesh = DeviceMesh( + device_type="cuda", + mesh=peer_matrix, + mesh_dim_names=("replicate", "shard"), + ) + + return mesh + def gen_full_pred_after_one_step( model: nn.Module, @@ -392,6 +624,10 @@ def gen_full_pred_after_one_step( loss.backward() opt.step() + # Sync embedding weights if 2D paralleism is used. + if isinstance(model, DMPCollection): + model.sync() + # Run a forward pass of the global model. with torch.no_grad(): model.train(False) diff --git a/torchrec/distributed/tests/collective_utils_test.py b/torchrec/distributed/tests/collective_utils_test.py index 679622a1f..2d49093a1 100644 --- a/torchrec/distributed/tests/collective_utils_test.py +++ b/torchrec/distributed/tests/collective_utils_test.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import multiprocessing import os import unittest diff --git a/torchrec/distributed/tests/test_2d_sharding.py b/torchrec/distributed/tests/test_2d_sharding.py new file mode 100644 index 000000000..6af13a4ad --- /dev/null +++ b/torchrec/distributed/tests/test_2d_sharding.py @@ -0,0 +1,813 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import Any, cast, Dict, List, Optional, Tuple, Type + +import torch +import torch.nn as nn +from fbgemm_gpu.split_embedding_configs import EmbOptimType +from hypothesis import assume, given, settings, strategies as st, Verbosity +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.fbgemm_qcomm_codec import CommType, QCommsConfig +from torchrec.distributed.planner import ParameterConstraints +from torchrec.distributed.test_utils.multi_process import MultiProcessTestBase +from torchrec.distributed.test_utils.test_model import TestSparseNNBase +from torchrec.distributed.test_utils.test_model_parallel import ModelParallelTestShared +from torchrec.distributed.test_utils.test_sharding import ( + create_test_sharder, + SharderType, + sharding_single_rank_test, +) +from torchrec.distributed.tests.test_sequence_model import ( + TestEmbeddingCollectionSharder, + TestSequenceSparseNN, +) +from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.modules.embedding_configs import EmbeddingConfig, PoolingType +from torchrec.test_utils import skip_if_asan_class + + +@skip_if_asan_class +class TestEmbeddingBagCollection2DParallel(ModelParallelTestShared): + """ + Tests for 2D parallelism of embeddingbagcollection tables + """ + + WORLD_SIZE = 8 + WORLD_SIZE_2D = 4 + + def setUp(self, backend: str = "nccl") -> None: + super().setUp(backend=backend) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + # QCommsConfig( + # forward_precision=CommType.FP16, backward_precision=CommType.BF16 + # ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + # None, + { + "embedding_bags": ( + torch.optim.SGD, + { + "lr": 0.01, + }, + ), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM]), + use_inter_host_allreduce=st.booleans(), + custom_all_reduce=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_cw_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + use_inter_host_allreduce: bool, + custom_all_reduce: bool, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.COLUMN_WISE.value + assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value) + + 
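# --- Illustrative sketch (not part of this diff) ----------------------------
# These 2D-parallel tests run WORLD_SIZE=8 ranks with a sharding group size of
# WORLD_SIZE_2D=4. Reproducing the peer-matrix arithmetic from
# `create_device_mesh_for_2D` (test_utils/test_sharding.py above) for that
# configuration:
world_size, local_size = 8, 4

# use_inter_host_allreduce=True: contiguous sharding groups.
inter = [list(range(i, i + local_size)) for i in range(0, world_size, local_size)]
assert inter == [[0, 1, 2, 3], [4, 5, 6, 7]]

# use_inter_host_allreduce=False (default): strided sharding groups.
step = world_size // local_size
strided = [[step * r + g for r in range(local_size)] for g in range(step)]
assert strided == [[0, 2, 4, 6], [1, 3, 5, 7]]
# Each row is a "shard" group; ranks sharing a column form a "replicate"
# group used for the weight all-reduce.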
self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=4) + for table in self.tables + }, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + use_inter_host_allreduce=use_inter_host_allreduce, + custom_all_reduce=custom_all_reduce, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + # None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": ( + torch.optim.SGD, + { + "lr": 0.01, + }, + ), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM]), + use_inter_host_allreduce=st.booleans(), + custom_all_reduce=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_tw_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + use_inter_host_allreduce: bool, + custom_all_reduce: bool, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.TABLE_WISE.value + assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value) + + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + node_group_size=self.WORLD_SIZE_2D // 2, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=2) + for table in self.tables + }, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + use_inter_host_allreduce=use_inter_host_allreduce, + custom_all_reduce=custom_all_reduce, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + # QCommsConfig( + # forward_precision=CommType.FP16, backward_precision=CommType.BF16 + # ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + # None, + { + "embedding_bags": ( + torch.optim.SGD, + { + "lr": 0.01, + }, + ), + }, + ] + ), + 
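# --- Illustrative sketch (not part of this diff) ----------------------------
# When `custom_all_reduce=True`, `sharding_single_rank_test`
# (test_utils/test_sharding.py above) hands `DMPCollection` a callable that
# averages the replicated tensors itself over the "replicate" process group.
# The minimal shape of such a hook, assuming `replicate_pg` was obtained from
# the 2D device mesh as in that test:
from typing import Callable, List

import torch
import torch.distributed as dist


def make_avg_all_reduce(
    replicate_pg: dist.ProcessGroup,
) -> Callable[[List[torch.Tensor]], None]:
    def _hook(tensors: List[torch.Tensor]) -> None:
        opts = dist.AllreduceCoalescedOptions()
        opts.reduceOp = dist.ReduceOp.AVG
        replicate_pg.allreduce_coalesced(tensors, opts=opts).wait()

    return _hook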
pooling=st.sampled_from([PoolingType.SUM]), + use_inter_host_allreduce=st.booleans(), + custom_all_reduce=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_grid_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + use_inter_host_allreduce: bool, + custom_all_reduce: bool, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.GRID_SHARD.value + assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value) + + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + node_group_size=self.WORLD_SIZE // 4, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + constraints={ + "table_0": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_1": ParameterConstraints( + min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_2": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_3": ParameterConstraints( + min_partition=10, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_4": ParameterConstraints( + min_partition=4, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_5": ParameterConstraints( + min_partition=6, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_0": ParameterConstraints( + min_partition=2, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_1": ParameterConstraints( + min_partition=3, sharding_types=[ShardingType.GRID_SHARD.value] + ), + }, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + use_inter_host_allreduce=use_inter_host_allreduce, + custom_all_reduce=custom_all_reduce, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least eight GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + # None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + pooling=st.sampled_from([PoolingType.SUM]), + use_inter_host_allreduce=st.booleans(), + custom_all_reduce=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_rw_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + pooling: PoolingType, + use_inter_host_allreduce: bool, 
+ custom_all_reduce: bool, + ) -> None: + if self.backend == "gloo": + self.skipTest( + "Gloo reduce_scatter_base fallback not supported with async_op=True" + ) + + sharding_type = ShardingType.ROW_WISE.value + assume( + sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value + or not variable_batch_size + ) + + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + pooling=pooling, + use_inter_host_allreduce=use_inter_host_allreduce, + custom_all_reduce=custom_all_reduce, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + # None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": ( + torch.optim.SGD, + { + "lr": 0.01, + }, + ), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM]), + use_inter_host_allreduce=st.booleans(), + custom_all_reduce=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_twrw_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + use_inter_host_allreduce: bool, + custom_all_reduce: bool, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.TABLE_ROW_WISE.value + assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value) + + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + node_group_size=self.WORLD_SIZE // 4, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=2) + for table in self.tables + }, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + use_inter_host_allreduce=use_inter_host_allreduce, + custom_all_reduce=custom_all_reduce, + ) + + +@skip_if_asan_class +class TestEmbeddingCollection2DParallel(MultiProcessTestBase): + """ + Tests for 2D parallelism of embeddingcollection tables + """ + + WORLD_SIZE = 8 + WORLD_SIZE_2D = 4 + + def setUp(self) -> None: + super().setUp() + + num_features = 4 + shared_features = 2 + + initial_tables = [ + EmbeddingConfig( + num_embeddings=(i + 1) * 11, + embedding_dim=16, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + + 
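+        # The tables below intentionally reuse the first `shared_features`
+        # feature names, so a single feature can feed more than one table;
+        # such features are addressed as "feature@table" in the embedding
+        # group definition that follows.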
shared_features_tables = [ + EmbeddingConfig( + num_embeddings=(i + 1) * 11, + embedding_dim=16, + name="table_" + str(i + num_features), + feature_names=["feature_" + str(i)], + ) + for i in range(shared_features) + ] + + self.tables = initial_tables + shared_features_tables + self.shared_features = [f"feature_{i}" for i in range(shared_features)] + + self.embedding_groups = { + "group_0": [ + ( + f"{feature}@{table.name}" + if feature in self.shared_features + else feature + ) + for table in self.tables + for feature in table.feature_names + ] + } + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least eight GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.just(ShardingType.ROW_WISE.value), + kernel_type=st.sampled_from( + [ + # EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.FUSED.value, + ] + ), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_ec_rw_2D( + self, + sharding_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + ) -> None: + assume( + apply_optimizer_in_backward_config is None + or kernel_type != EmbeddingComputeKernel.DENSE.value + ) + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + sharders=[ + TestEmbeddingCollectionSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + qcomms_config=qcomms_config, + ) + ], + backend="nccl", + qcomms_config=qcomms_config, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least eight GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.just(ShardingType.COLUMN_WISE.value), + kernel_type=st.sampled_from( + [ + # EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.FUSED.value, + ] + ), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_ec_cw_2D( + self, + sharding_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + ) -> None: + assume( + apply_optimizer_in_backward_config is None + or kernel_type != EmbeddingComputeKernel.DENSE.value + ) + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + sharders=[ + TestEmbeddingCollectionSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + qcomms_config=qcomms_config, + ) + ], + backend="nccl", + 
qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=4) + for table in self.tables + }, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least eight GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.just(ShardingType.TABLE_WISE.value), + kernel_type=st.sampled_from( + [ + # EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.FUSED.value, + ] + ), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_ec_tw_2D( + self, + sharding_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + ) -> None: + assume( + apply_optimizer_in_backward_config is None + or kernel_type != EmbeddingComputeKernel.DENSE.value + ) + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + sharders=[ + TestEmbeddingCollectionSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + qcomms_config=qcomms_config, + ) + ], + backend="nccl", + qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=2) + for table in self.tables + }, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + ) + + def _test_sharding( + self, + sharders: List[TestEmbeddingCollectionSharder], + backend: str = "gloo", + world_size: int = 2, + world_size_2D: int = 1, + local_size: Optional[int] = None, + node_group_size: Optional[int] = None, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + model_class: Type[TestSparseNNBase] = TestSequenceSparseNN, + qcomms_config: Optional[QCommsConfig] = None, + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ] = None, + variable_batch_size: bool = False, + variable_batch_per_feature: bool = False, + ) -> None: + self._run_multi_process_test( + callable=sharding_single_rank_test, + world_size=world_size, + world_size_2D=world_size_2D, + local_size=local_size, + model_class=model_class, + tables=self.tables, + embedding_groups=self.embedding_groups, + sharders=sharders, + optim=EmbOptimType.EXACT_SGD, + backend=backend, + constraints=constraints, + qcomms_config=qcomms_config, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + variable_batch_per_feature=variable_batch_per_feature, + global_constant_batch=True, + ) diff --git a/torchrec/distributed/tests/test_apply_optim_per_param.py b/torchrec/distributed/tests/test_apply_optim_per_param.py index e28f606ff..ec388ab9f 100644 --- a/torchrec/distributed/tests/test_apply_optim_per_param.py +++ b/torchrec/distributed/tests/test_apply_optim_per_param.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import copy import unittest from typing import Any, Dict, List, Optional @@ -93,6 +95,8 @@ def _test_sharding( plan: ShardingPlan = planner.collective_plan(model, [sharder], ctx.pg) sharded_model = DistributedModelParallel( module=model, + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. env=ShardingEnv.from_process_group(ctx.pg), plan=plan, sharders=[sharder], @@ -200,7 +204,7 @@ class ShardedEmbeddingBagCollectionApplyOptimPerParamTest(MultiProcessTestBase): ] ), ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) def test_sharding_ebc_per_parameter_optimizer( self, sharding_type: str, @@ -321,6 +325,8 @@ def _test_sharding_ec( plan: ShardingPlan = planner.collective_plan(model, [sharder], ctx.pg) sharded_model = DistributedModelParallel( module=model, + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. env=ShardingEnv.from_process_group(ctx.pg), plan=plan, sharders=[sharder], @@ -416,7 +422,7 @@ class ShardedEmbeddingCollectionApplyOptimPerParamTest(MultiProcessTestBase): ] ), ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) def test_sharding_ec_per_parameter_optimizer( self, sharding_type: str, diff --git a/torchrec/distributed/tests/test_apply_optimizer_to_dense_tbe.py b/torchrec/distributed/tests/test_apply_optimizer_to_dense_tbe.py new file mode 100644 index 000000000..3f4b468d5 --- /dev/null +++ b/torchrec/distributed/tests/test_apply_optimizer_to_dense_tbe.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
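+
+# This test checks that fused (TBE), non-fused (data-parallel) and dense
+# parameters can each be driven by an independently configured optimizer:
+# a parameter's weights should only change after a step when its optimizer
+# has a non-zero learning rate.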
+ +# pyre-strict + +import logging +import unittest + +import torch +from hypothesis import given, settings, strategies as st, Verbosity +from torch import distributed as dist +from torch.distributed.optim import ( + _apply_optimizer_in_backward as apply_optimizer_in_backward, +) +from torchrec import distributed as trec_dist +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.planner import ( + EmbeddingShardingPlanner, + ParameterConstraints, + Topology, +) +from torchrec.distributed.sharding_plan import get_default_sharders +from torchrec.distributed.test_utils.test_model import ModelInput +from torchrec.distributed.test_utils.test_model_parallel_base import ( + ModelParallelSingleRankBase, +) +from torchrec.distributed.types import ShardingEnv, ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.optim.keyed import KeyedOptimizerWrapper +from torchrec.optim.optimizers import in_backward_optimizer_filter +from torchrec.optim.rowwise_adagrad import RowWiseAdagrad + +logger: logging.Logger = logging.getLogger(__name__) + + +class ApplyOptmizerDenseTBETest(ModelParallelSingleRankBase): + def _create_tables(self) -> None: + num_features = 4 + + self.tables += [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + fused_learning_rate=st.sampled_from([0, 0.1]), + non_fused_learning_rate=st.sampled_from([0, 0.1]), + dense_learning_rate=st.sampled_from([0, 0.1]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_apply_optimizer_to_dense_tbe( + self, + fused_learning_rate: float, + non_fused_learning_rate: float, + dense_learning_rate: float, + ) -> None: + unsharded_model = self._create_model() + + # torchrec sharding + constraints = { + table.name: ParameterConstraints( + sharding_types=[ + ( + ShardingType.TABLE_WISE.value + if i % 2 + else ShardingType.DATA_PARALLEL.value + ) + ], + compute_kernels=[ + ( + EmbeddingComputeKernel.FUSED.value + if i % 2 + else EmbeddingComputeKernel.DENSE.value + ) + ], + ) + for i, table in enumerate(self.tables) + } + sharders = get_default_sharders() + pg = dist.GroupMember.WORLD + assert pg is not None, "Process group is not initialized" + env = ShardingEnv.from_process_group(pg) + planner = EmbeddingShardingPlanner( + topology=Topology( + local_world_size=trec_dist.comm.get_local_size(env.world_size), + world_size=env.world_size, + compute_device=self.device.type, + ), + constraints=constraints, + ) + plan = planner.plan(unsharded_model, sharders) + + ### apply Rowwise Adagrad optimizer to fused TBEs ### + # fused TBE optimizer needs to be configured before initializing + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `named_parameters`. 
+ for _, param in unsharded_model.sparse.named_parameters(): + apply_optimizer_in_backward( + RowWiseAdagrad, + [param], + {"lr": fused_learning_rate}, + ) + + # shard model + sharded_model = DistributedModelParallel( + module=unsharded_model, + init_data_parallel=True, + device=self.device, + sharders=sharders, + plan=plan, + ) + + ### apply Rowwise Adagrad optimizer to Data Parallel tables ### + # Optimizer for non fused tables need to be configured after initializing + non_fused_tables_optimizer = KeyedOptimizerWrapper( + dict( + in_backward_optimizer_filter( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `named_parameters`. + sharded_model.module.sparse.named_parameters() + ) + ), + lambda params: RowWiseAdagrad( + params, + lr=non_fused_learning_rate, + eps=1e-8, # to match with FBGEMM + ), + ) + + ### apply SGD to dense arch + over arch ### + dense_params = [ + param + for name, param in sharded_model.named_parameters() + if "sparse" not in name + ] + dense_opt = torch.optim.Adagrad( + dense_params, + lr=dense_learning_rate, + ) + + # create input + _, local_batch = ModelInput.generate( + batch_size=self.batch_size, + world_size=env.world_size, + num_float_features=self.num_float_features, + tables=self.tables, + weighted_tables=[], + ) + batch = local_batch[0].to(self.device) + + # record signatures + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `dhn_arch`. + dense_signature = sharded_model.module.over.dhn_arch.linear0.weight.sum().item() + non_fused_table_signature = ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + sharded_model.module.sparse.ebc.embedding_bags.table_0.weight.sum().item() + ) + fused_table_signature = ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + sharded_model.module.sparse.ebc.embedding_bags.table_1.weight.sum().item() + ) + + ### training ### + # zero grad + non_fused_tables_optimizer.zero_grad() + dense_opt.zero_grad() + + # forward and backward + loss, pred = sharded_model(batch) + loss.backward() + + # apply gradients + non_fused_tables_optimizer.step() + dense_opt.step() + + self.assertEqual( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + sharded_model.module.sparse.ebc.embedding_bags.table_1.weight.sum().item() + != fused_table_signature, + bool(fused_learning_rate), + ) + self.assertEqual( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + sharded_model.module.sparse.ebc.embedding_bags.table_0.weight.sum().item() + != non_fused_table_signature, + bool(non_fused_learning_rate), + ) + self.assertEqual( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `dhn_arch`. + sharded_model.module.over.dhn_arch.linear0.weight.sum().item() + != dense_signature, + bool(dense_learning_rate), + ) diff --git a/torchrec/distributed/tests/test_awaitable.py b/torchrec/distributed/tests/test_awaitable.py index f86da0cf5..790a56b12 100644 --- a/torchrec/distributed/tests/test_awaitable.py +++ b/torchrec/distributed/tests/test_awaitable.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import unittest import torch @@ -22,6 +24,8 @@ def _wait_impl(self) -> torch.Tensor: class AwaitableTests(unittest.TestCase): def test_callback(self) -> None: awaitable = AwaitableInstance() + # pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got + # `(ret: Any) -> int`. awaitable.callbacks.append(lambda ret: 2 * ret) self.assertTrue( torch.allclose(awaitable.wait(), torch.FloatTensor([2.0, 4.0, 6.0])) @@ -29,6 +33,8 @@ def test_callback(self) -> None: def test_callback_chained(self) -> None: awaitable = AwaitableInstance() + # pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got + # `(ret: Any) -> int`. awaitable.callbacks.append(lambda ret: 2 * ret) awaitable.callbacks.append(lambda ret: ret**2) self.assertTrue( diff --git a/torchrec/distributed/tests/test_cache_prefetch.py b/torchrec/distributed/tests/test_cache_prefetch.py new file mode 100644 index 000000000..e309f29c5 --- /dev/null +++ b/torchrec/distributed/tests/test_cache_prefetch.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import os +import unittest +from typing import cast, List + +import hypothesis.strategies as st + +import torch +import torch.nn as nn +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + SplitTableBatchedEmbeddingBagsCodegen, +) +from hypothesis import given, settings, Verbosity +from torchrec.distributed import DistributedModelParallel +from torchrec.distributed.embedding_lookup import EmbeddingComputeKernel +from torchrec.distributed.embedding_types import KJTList +from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection + +from torchrec.distributed.test_utils.test_model import TestEBCSharder +from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.test_utils import get_free_port, init_distributed_single_host + +SHARDING_TYPES: List[str] = [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, +] + + +class ShardedEmbeddingModuleCachePrefetchTest(unittest.TestCase): + def setUp(self) -> None: + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + self.backend = "nccl" + if torch.cuda.is_available(): + self.device = torch.device("cuda:0") + torch.cuda.set_device(self.device) + else: + self.device = torch.device("cpu") + + self.pg = init_distributed_single_host( + backend=self.backend, rank=0, world_size=1 + ) + + self.embedding_config = [ + EmbeddingBagConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=4, + num_embeddings=40, + ), + ] + self.model = EmbeddingBagCollection( + tables=self.embedding_config, + device=self.device, + ) + + def get_cache_unique_misses( + self, emb_module: SplitTableBatchedEmbeddingBagsCodegen + ) -> int: + (_, _, _, num_unique_misses, _, _) = ( + emb_module.get_uvm_cache_stats(use_local_cache=True).detach().cpu().tolist() + ) + return num_unique_misses + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + 
sharding_type=st.sampled_from(SHARDING_TYPES), + cache_load_factor=st.sampled_from([0.5, 0.8]), + ) + @settings(verbosity=Verbosity.verbose, deadline=None) + def test_sharded_ebc_cache_prefetch( + self, + sharding_type: str, + cache_load_factor: float, + ) -> None: + batch_0_kjt = KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0"], + values=torch.LongTensor([1, 2, 3]), + lengths=torch.LongTensor([2, 0, 1]), + ).to(self.device) + + batch_1_kjt = KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0"], + values=torch.LongTensor([10, 11, 12]), + lengths=torch.LongTensor([2, 0, 1]), + ).to(self.device) + + batch_2_kjt = KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0"], + values=torch.LongTensor([30, 31, 33]), + lengths=torch.LongTensor([1, 1, 1]), + ).to(self.device) + + fused_params = { + "prefetch_pipeline": True, + "gather_uvm_cache_stats": True, + "cache_load_factor": cache_load_factor, + } + + sharder = TestEBCSharder( + sharding_type=sharding_type, + kernel_type=EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + fused_params=fused_params, + ) + sharded_model = DistributedModelParallel( + module=self.model, + env=ShardingEnv.from_process_group(self.pg), + init_data_parallel=False, + device=self.device, + sharders=[ + cast( + ModuleSharder[nn.Module], + sharder, + ) + ], + ) + sharded_ebc = sharded_model.module + self.assertIsInstance(sharded_ebc, ShardedEmbeddingBagCollection) + lookups = sharded_ebc._lookups + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... + emb_module = lookups[0]._emb_modules[0]._emb_module + self.assertIsInstance(emb_module, SplitTableBatchedEmbeddingBagsCodegen) + + # Embedding lookup without prior prefetch + sharded_ebc(batch_0_kjt) + + # We should have 3 unique misses since nothing was stored in cache yet + self.assertEqual(self.get_cache_unique_misses(emb_module), 3) + + kjt_list = KJTList([batch_1_kjt]) + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. 
+ sharded_ebc.prefetch(kjt_list) + + # Reset cache stats so that local uvm cache stats are reset + # otherwise, the number of cache misses will be non-zero after the forward pass + emb_module.reset_uvm_cache_stats() + + # Embedding lookup will not prefetch here because prefetch() was called previously + sharded_ebc(batch_1_kjt) + + # No unique misses expected since we prefetched all indices for batch 1 + self.assertEqual(self.get_cache_unique_misses(emb_module), 0) + + # Do forward pass w/ indices that were prefetched previously + sharded_ebc(batch_1_kjt) + + # No unique misses expected since indices have been prefetched + self.assertEqual(self.get_cache_unique_misses(emb_module), 0) + + # Try fetching indices different than the ones prefetched previously + sharded_ebc(batch_2_kjt) + + # Unique cache misses should be 3 since indices requested by batch 2 are not presented in the cache + self.assertEqual(self.get_cache_unique_misses(emb_module), 3) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharding_type=st.sampled_from(SHARDING_TYPES), + cache_load_factor=st.sampled_from([0.5, 0.8]), + ) + @settings(verbosity=Verbosity.verbose, deadline=None) + def test_sharded_ebc_cache_purge( + self, + sharding_type: str, + cache_load_factor: float, + ) -> None: + batch_1_kjt = KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0"], + values=torch.LongTensor([5, 6, 7]), + lengths=torch.LongTensor([2, 0, 1]), + ).to(self.device) + + fused_params = { + "prefetch_pipeline": True, + "gather_uvm_cache_stats": True, + "cache_load_factor": cache_load_factor, + } + + sharder = TestEBCSharder( + sharding_type=sharding_type, + kernel_type=EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + fused_params=fused_params, + ) + sharded_model = DistributedModelParallel( + module=self.model, + env=ShardingEnv.from_process_group(self.pg), + init_data_parallel=False, + device=self.device, + sharders=[ + cast( + ModuleSharder[nn.Module], + sharder, + ) + ], + ) + sharded_ebc = sharded_model.module + lookups = sharded_ebc._lookups + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... + emb_module = lookups[0]._emb_modules[0]._emb_module + + kjt_list = KJTList([batch_1_kjt]) + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + sharded_ebc.prefetch(kjt_list) + + # Reset cache stats so that local uvm cache stats are reset + # otherwise, the number of cache misses will be non-zero after the forward pass + emb_module.reset_uvm_cache_stats() + + # No prefetch called here + sharded_ebc(batch_1_kjt) + self.assertEqual(self.get_cache_unique_misses(emb_module), 0) + + # Implicitly call cache purge by invoking pre-load_state_dict hook + sharded_ebc.load_state_dict(sharded_ebc.state_dict()) + + sharded_ebc(batch_1_kjt) + + # We should have 3 unique misses since we purged the cache after the first lookup + self.assertEqual(self.get_cache_unique_misses(emb_module), 3) diff --git a/torchrec/distributed/tests/test_comm.py b/torchrec/distributed/tests/test_comm.py index b3fe3ffe7..d110e9740 100644 --- a/torchrec/distributed/tests/test_comm.py +++ b/torchrec/distributed/tests/test_comm.py @@ -5,18 +5,160 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
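+
+# Most collective tests below run each comm op with async collectives, sync
+# collectives and (optionally) a torch.compile'd wrapper, and assert that
+# forward outputs and gradients match across the modes.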
+# pyre-strict + +import functools import itertools import multiprocessing import os import unittest -from typing import Callable +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import hypothesis.strategies as st -import numpy import torch import torch.distributed as dist +import torchrec import torchrec.distributed.comm_ops as comm_ops +from hypothesis import given, settings +from torch.distributed.distributed_c10d import GroupMember from torchrec.test_utils import get_free_port, seed_and_log +torch.ops.import_module("fbgemm_gpu.sparse_ops") + + +@dataclass +class _CompileConfig: + # backend is None means no compilation + backend: Optional[str] = "inductor" + fullgraph: bool = True + skip_sync_backward: bool = False + skip_compile_backward: bool = False + test_compiled_with_noncompiled_ranks: bool = False + + +def compile_config_to_fn_transform( + compile_config: Optional[_CompileConfig], + # pyre-ignore +) -> Callable: + if compile_config is None: + return lambda x: x + + return functools.partial( + torch.compile, + backend=compile_config.backend, + fullgraph=compile_config.fullgraph, + dynamic=True, + ) + + +# pyre-ignore +def _copy_input_tensors(t, device): + if isinstance(t, torch.Tensor): + ret = t.detach().clone().to(device) + ret.requires_grad = True + ret.retain_grad() + return ret + elif isinstance(t, list): + return [_copy_input_tensors(_t, device) for _t in t] + else: + raise ValueError(f"Unsupported type {type(t)}") + + +# pyre-ignore +def _grad_detach_clone(t): + if isinstance(t, torch.Tensor): + # pyre-ignore + return t.grad.detach().clone() + elif isinstance(t, list): + return [_grad_detach_clone(_t) for _t in t] + else: + raise ValueError(f"Unsupported type {type(t)}") + + +# pyre-ignore +def _assert_close(actual, expected) -> None: + if isinstance(expected, torch.Tensor): + assert isinstance(actual, torch.Tensor) + torch.testing.assert_close(actual, expected) + elif isinstance(expected, list): + assert isinstance(actual, list) + for _a, _e in zip(actual, expected): + _assert_close(_a, _e) + else: + raise ValueError(f"Unsupported type {type(expected)}") + + +def _test_async_sync_compile( + # pyre-ignore + fn, + input_tensor: Union[torch.Tensor, List[torch.Tensor]], + device: torch.device, + compile_config: _CompileConfig, + rank: int, + # pyre-ignore + *args, + # pyre-ignore + **kwargs, +) -> None: + input_tensor_async = _copy_input_tensors(input_tensor, device) + input_tensor_sync = _copy_input_tensors(input_tensor, device) + input_tensor_compile = _copy_input_tensors(input_tensor, device) + + # Async + torchrec.distributed.comm_ops.set_use_sync_collectives(False) + out = fn(input_tensor_async, *args, **kwargs) + out.retain_grad() + out.backward(out) + async_fwd_out = out.clone() + async_bwd_out = _grad_detach_clone(input_tensor_async) + + # Sync + torchrec.distributed.comm_ops.set_use_sync_collectives(True) + out = fn(input_tensor_sync, *args, **kwargs) + sync_fwd_out = out.clone() + _assert_close(sync_fwd_out, async_fwd_out) + + if not compile_config.skip_sync_backward: + out.retain_grad() + out.backward(out) + sync_bwd_out = _grad_detach_clone(input_tensor_sync) + _assert_close(sync_bwd_out, async_bwd_out) + + if compile_config.backend is not None: + fn_transform = compile_config_to_fn_transform(compile_config) + + with unittest.mock.patch( + "torch._dynamo.config.skip_torchrec", + False, + ): + if compile_config.test_compiled_with_noncompiled_ranks and rank == 1: + # Turn off compilation for rank==1 to test compatibility of 
compiled rank and non-compiled + fn_transform = lambda x: x + + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + + out = fn_transform(fn)( + input_tensor_compile, + *args, + **kwargs, + ) + compile_fwd_out = out.clone() + _assert_close(compile_fwd_out, sync_fwd_out) + + if ( + not compile_config.skip_sync_backward + and not compile_config.skip_compile_backward + ): + out.retain_grad() + out.backward(out) + compile_bwd_out = _grad_detach_clone(input_tensor_compile) + + # pyre-ignore + _assert_close(compile_bwd_out, sync_bwd_out) + class TestAllToAll(unittest.TestCase): @seed_and_log @@ -37,6 +179,10 @@ def _run_multi_process_test( world_size: int, backend: str, callable: Callable[[], None], + # pyre-ignore + *args, + # pyre-ignore + **kwargs, ) -> None: processes = [] ctx = multiprocessing.get_context("spawn") @@ -47,7 +193,9 @@ def _run_multi_process_test( rank, world_size, backend, + *args, ), + kwargs=kwargs, ) p.start() processes.append(p) @@ -56,98 +204,60 @@ def _run_multi_process_test( p.join() self.assertEqual(0, p.exitcode) - @classmethod - def _test_alltoallv( - cls, - rank: int, - world_size: int, - backend: str, - ) -> None: - dist.init_process_group(rank=rank, world_size=world_size, backend=backend) - device = torch.device(f"cuda:{rank}") - - torch.cuda.set_device(device) - - B_global = 10 - D0 = 8 - D1 = 9 - - input_embedding0 = torch.rand( - (B_global, D0), - device=device, - requires_grad=True, - ) - input_embedding1 = torch.rand( - (B_global, D1), - device=device, - requires_grad=True, - ) - - input_embeddings = [input_embedding0, input_embedding1] - out_split = [17, 17] - - a2a_req = comm_ops.alltoallv(input_embeddings, out_split) - v_embs_out = a2a_req.wait() - res = torch.cat(v_embs_out, dim=1).cpu() - assert tuple(res.size()) == (5, 34) - dist.destroy_process_group() - - # pyre-fixme[56]: Pyre was not able to infer the type of argument - # `torch.cuda.device_count() = 0` to decorator factory `unittest.skipIf`. 
- @unittest.skipIf( - torch.cuda.device_count() < 2, "Need at least two ranks to run this test" - ) - def test_alltoallv(self) -> None: - self._run_multi_process_test( - world_size=self.WORLD_SIZE, - backend="nccl", - # pyre-ignore [6] - callable=self._test_alltoallv, - ) - @classmethod def _test_alltoall_sequence( cls, rank: int, world_size: int, backend: str, + compile_config: _CompileConfig, + specify_pg: bool, + gradient_division: bool, + skip_dynamo_backwards: bool = False, ) -> None: dist.init_process_group(rank=rank, world_size=world_size, backend=backend) + pg = GroupMember.WORLD + if pg is None: + dist.init_process_group(rank=rank, world_size=world_size, backend=backend) + pg = GroupMember.WORLD + device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) ranks = 2 tables_mp = [[0], [1, 2]] lengths_dp = [ - numpy.array([[1, 2], [1, 1], [2, 1]]), - numpy.array([[1, 2], [2, 1], [3, 1]]), + torch.tensor([[1, 2], [1, 1], [2, 1]], dtype=torch.int), + torch.tensor([[1, 2], [2, 1], [3, 1]], dtype=torch.int), ] # W, T_g, B_l lengths_a2a = [ - numpy.array([[[1, 2]], [[1, 2]]]), # Rank 0 - numpy.array( + torch.tensor([[[1, 2]], [[1, 2]]], dtype=torch.int), # Rank 0 + torch.tensor( [ [[1, 1], [2, 1]], # from Rank 0 [[2, 1], [3, 1]], # from rank 1 - ] + ], + dtype=torch.int, ), # Rank 1 ] # W, W, T_l, B_l lengths_mp = [ - numpy.array( + torch.tensor( [ [1, 2, 1, 2], - ] + ], + dtype=torch.int, ), - numpy.array([[1, 1, 2, 1], [2, 1, 3, 1]]), + torch.tensor([[1, 1, 2, 1], [2, 1, 3, 1]], dtype=torch.int), ] # w, t_l, b_g input_seg = list(itertools.accumulate([0] + [len(i) for i in tables_mp])) input_splits = [ [ - lengths_dp[i][input_seg[j] : input_seg[j + 1], :].sum() + int(lengths_dp[i][input_seg[j] : input_seg[j + 1], :].sum()) for j in range(ranks) ] for i in range(ranks) ] - output_splits = [lengths_a2a[i].sum(axis=(1, 2)).tolist() for i in range(ranks)] + output_splits = [lengths_a2a[i].sum(dim=(1, 2)).tolist() for i in range(ranks)] table_dim = 3 num_features_per_rank = [len(features) for features in tables_mp] seq_all2all_forward_recat = [] @@ -162,36 +272,447 @@ def _test_alltoall_sequence( seq_all2all_backward_recat_tensor = torch.IntTensor(seq_all2all_backward_recat) input_embeddings = torch.rand( - lengths_mp[rank].sum(), + int(lengths_mp[rank].sum()), table_dim, device=device, requires_grad=True, ) lengths_after_sparse_data_all2all = torch.IntTensor(lengths_mp[rank]) - a2a_req = comm_ops.alltoall_sequence( - a2a_sequence_embs_tensor=input_embeddings.cuda(), + + # pyre-ignore + def fn(*args, **kwargs) -> torch.Tensor: + return comm_ops.alltoall_sequence(*args, **kwargs).wait() + + comm_ops.set_gradient_division(gradient_division) + _test_async_sync_compile( + fn, + input_embeddings, + device, + compile_config, + rank, forward_recat_tensor=seq_all2all_forward_recat_tensor.cuda(), backward_recat_tensor=seq_all2all_backward_recat_tensor.cuda(), lengths_after_sparse_data_all2all=lengths_after_sparse_data_all2all.cuda(), input_splits=input_splits[rank], output_splits=output_splits[rank], + group=pg if specify_pg else None, ) - seq_embs_out = a2a_req.wait() - seq_embs_out.backward(seq_embs_out) - grad = input_embeddings.grad - # pyre-fixme[16]: Optional type has no attribute `cpu`. - assert torch.equal(input_embeddings.cpu().detach(), grad.cpu().detach()) dist.destroy_process_group() - # pyre-fixme[56]: Pyre was not able to infer the type of argument - # `torch.cuda.device_count() = 0` to decorator factory `unittest.skipIf`. 
@unittest.skipIf( torch.cuda.device_count() < 2, "Need at least two ranks to run this test" ) - def test_alltoall_sequence(self) -> None: + # pyre-ignore + @given( + specify_pg=st.sampled_from([True]), + gradient_division=st.sampled_from([True, False]), + ) + @settings(deadline=None) + def test_alltoall_sequence( + self, + specify_pg: bool, + gradient_division: bool, + ) -> None: self._run_multi_process_test( world_size=self.WORLD_SIZE, backend="nccl", # pyre-ignore [6] callable=self._test_alltoall_sequence, + compile_config=_CompileConfig(), + specify_pg=specify_pg, + gradient_division=gradient_division, + ) + + @classmethod + def _test_alltoall_pooled( + cls, + rank: int, + world_size: int, + backend: str, + compile_config: _CompileConfig, + specify_pg: bool, + gradient_division: bool, + ) -> None: + pg = GroupMember.WORLD + if pg is None: + dist.init_process_group(rank=rank, world_size=world_size, backend=backend) + pg = GroupMember.WORLD + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + pg = dist.distributed_c10d._get_default_group() + + # Each rank's local batch size + + batch_size_per_rank = [4] * world_size + + # Global batch size is the sum of all rank's local batch size + B_global = sum(batch_size_per_rank) + # sum of dimensions of the embedding tables hosted on each rank + dim_sum_per_rank = [8] * world_size + + D_local_sum = dim_sum_per_rank[rank] + + # Construct pooled embeddings + pooled_embs = torch.randn([B_global, D_local_sum], requires_grad=True).to( + device + ) + + # pyre-ignore + def fn(*args, **kwargs) -> torch.Tensor: + return comm_ops.alltoall_pooled(*args, **kwargs).wait() + + comm_ops.set_gradient_division(gradient_division) + _test_async_sync_compile( + fn, + pooled_embs, + device, + compile_config, + rank, + batch_size_per_rank, + dim_sum_per_rank, + pg, + ) + + dist.destroy_process_group() + + @unittest.skipIf( + torch.cuda.device_count() < 2, "Need at least two ranks to run this test" + ) + # pyre-ignore + @given( + specify_pg=st.sampled_from([True]), + test_compiled_with_noncompiled_ranks=st.sampled_from([False, True]), + gradient_division=st.sampled_from([True, False]), + ) + @settings(deadline=None) + def test_alltoall_pooled( + self, + specify_pg: bool, + test_compiled_with_noncompiled_ranks: bool, + gradient_division: bool, + ) -> None: + self._run_multi_process_test( + world_size=self.WORLD_SIZE, + backend="nccl", + # pyre-ignore [6] + callable=self._test_alltoall_pooled, + compile_config=_CompileConfig( + test_compiled_with_noncompiled_ranks=test_compiled_with_noncompiled_ranks + ), + specify_pg=specify_pg, + gradient_division=gradient_division, + ) + + @classmethod + def _test_reduce_scatter_pooled( + cls, + rank: int, + world_size: int, + backend: str, + compile_config: _CompileConfig, + specify_pg: bool, + gradient_division: bool, + ) -> None: + pg = GroupMember.WORLD + if pg is None: + dist.init_process_group(rank=rank, world_size=world_size, backend=backend) + pg = GroupMember.WORLD + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + pg = dist.distributed_c10d._get_default_group() + + batch_size_per_rank = [4] * world_size + B_global = sum(batch_size_per_rank) + dim_sum_per_rank = [8] * world_size + + D_local_sum = dim_sum_per_rank[rank] + + inputs: List[torch.Tensor] = [] + for _ in range(world_size): + input = torch.randn([B_global, D_local_sum], requires_grad=True).to(device) + input.retain_grad() + inputs.append(input) + + # pyre-ignore + def fn(*args, **kwargs) -> torch.Tensor: + return 
comm_ops.reduce_scatter_pooled(*args, **kwargs).wait() + + comm_ops.set_gradient_division(gradient_division) + _test_async_sync_compile( + fn, + inputs, + device, + compile_config, + rank, + pg if specify_pg else None, + ) + + dist.destroy_process_group() + + @unittest.skipIf( + torch.cuda.device_count() < 2, "Need at least two ranks to run this test" + ) + # pyre-ignore + @given( + specify_pg=st.sampled_from([True]), + test_compiled_with_noncompiled_ranks=st.sampled_from([False, True]), + gradient_division=st.sampled_from([True, False]), + ) + @settings(deadline=None) + def test_reduce_scatter_pooled( + self, + specify_pg: bool, + test_compiled_with_noncompiled_ranks: bool, + gradient_division: bool, + ) -> None: + self._run_multi_process_test( + world_size=self.WORLD_SIZE, + backend="nccl", + # pyre-ignore [6] + callable=self._test_reduce_scatter_pooled, + compile_config=_CompileConfig( + test_compiled_with_noncompiled_ranks=test_compiled_with_noncompiled_ranks + ), + specify_pg=specify_pg, + gradient_division=gradient_division, + ) + + @classmethod + def _test_reduce_scatter_v_pooled( + cls, + rank: int, + world_size: int, + backend: str, + compile_config: _CompileConfig, + specify_pg: bool, + gradient_division: bool, + ) -> None: + pg = GroupMember.WORLD + if pg is None: + dist.init_process_group(rank=rank, world_size=world_size, backend=backend) + pg = GroupMember.WORLD + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + pg = dist.distributed_c10d._get_default_group() + + src: List[int] = [1, 2, 3] * world_size + input_splits: List[int] = src[:world_size] + inputs_dim: int = sum(input_splits) + + input: torch.Tensor = torch.randn(inputs_dim, 2, requires_grad=True).to(device) + + # pyre-ignore + def fn(*args, **kwargs) -> torch.Tensor: + return comm_ops.reduce_scatter_v_pooled(*args, **kwargs).wait() + + comm_ops.set_gradient_division(gradient_division) + _test_async_sync_compile( + fn, + input, + device, + compile_config, + rank, + input_splits, + pg if specify_pg else None, + ) + + dist.destroy_process_group() + + @unittest.skipIf( + torch.cuda.device_count() < 2, "Need at least two ranks to run this test" + ) + # pyre-ignore + @given( + specify_pg=st.sampled_from([True]), + test_compiled_with_noncompiled_ranks=st.sampled_from([False, True]), + gradient_division=st.sampled_from([True, False]), + ) + @settings(deadline=None) + def test_reduce_scatter_v_pooled( + self, + specify_pg: bool, + test_compiled_with_noncompiled_ranks: bool, + gradient_division: bool, + ) -> None: + self._run_multi_process_test( + world_size=self.WORLD_SIZE, + backend="nccl", + # pyre-ignore [6] + callable=self._test_reduce_scatter_v_pooled, + compile_config=_CompileConfig( + test_compiled_with_noncompiled_ranks=test_compiled_with_noncompiled_ranks + ), + specify_pg=specify_pg, + gradient_division=gradient_division, + ) + + @classmethod + def _test_reduce_scatter_v_per_feature_pooled( + cls, + rank: int, + world_size: int, + backend: str, + compile_config: _CompileConfig, + specify_pg: bool, + gradient_division: bool, + ) -> None: + pg = GroupMember.WORLD + if pg is None: + dist.init_process_group(rank=rank, world_size=world_size, backend=backend) + pg = GroupMember.WORLD + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + pg = dist.distributed_c10d._get_default_group() + + batch_size_per_feature: List[int] = [2, 4, 4, 7, 2] + batch_size_per_rank_per_feature: List[List[int]] = [] + for _ in range(world_size): + 
batch_size_per_rank_per_feature.append(batch_size_per_feature) + + embedding_dims: List[int] = [12] * len(batch_size_per_feature) + + n = world_size * sum( + [b * emb_dim for b, emb_dim in zip(batch_size_per_feature, embedding_dims)] + ) + input: torch.Tensor = torch.randn(n, requires_grad=True).to(device) + + # pyre-ignore + def fn(*args, **kwargs) -> torch.Tensor: + return comm_ops.reduce_scatter_v_per_feature_pooled(*args, **kwargs).wait() + + comm_ops.set_gradient_division(gradient_division) + _test_async_sync_compile( + fn, + input, + device, + compile_config, + rank, + batch_size_per_rank_per_feature, + embedding_dims, + pg if specify_pg else None, + ) + dist.destroy_process_group() + + @unittest.skipIf( + torch.cuda.device_count() < 2, "Need at least two ranks to run this test" + ) + # pyre-ignore + @given( + specify_pg=st.sampled_from([True]), + test_compiled_with_noncompiled_ranks=st.sampled_from([False, True]), + gradient_division=st.sampled_from([True, False]), + ) + @settings(deadline=None) + def test_reduce_scatter_v_per_feature_pooled( + self, + specify_pg: bool, + test_compiled_with_noncompiled_ranks: bool, + gradient_division: bool, + ) -> None: + self._run_multi_process_test( + world_size=self.WORLD_SIZE, + backend="nccl", + # pyre-ignore [6] + callable=self._test_reduce_scatter_v_per_feature_pooled, + compile_config=_CompileConfig( + test_compiled_with_noncompiled_ranks=test_compiled_with_noncompiled_ranks + ), + specify_pg=specify_pg, + gradient_division=gradient_division, + ) + + @classmethod + def _test_all_gather_base_pooled( + cls, + rank: int, + world_size: int, + backend: str, + compile_config: _CompileConfig, + specify_pg: bool, + gradient_division: bool, + ) -> None: + pg = GroupMember.WORLD + if pg is None: + dist.init_process_group(rank=rank, world_size=world_size, backend=backend) + pg = GroupMember.WORLD + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + pg = dist.distributed_c10d._get_default_group() + + input = torch.randn([4, 4], requires_grad=True).to(device) + + # pyre-ignore + def fn(*args, **kwargs) -> torch.Tensor: + return comm_ops.all_gather_base_pooled(*args, **kwargs).wait() + + comm_ops.set_gradient_division(gradient_division) + _test_async_sync_compile( + fn, input, device, compile_config, rank, pg if specify_pg else None + ) + + dist.destroy_process_group() + + @unittest.skipIf( + torch.cuda.device_count() < 2, "Need at least two ranks to run this test" + ) + # pyre-ignore + @given( + specify_pg=st.sampled_from([True]), + test_compiled_with_noncompiled_ranks=st.sampled_from([False, True]), + gradient_division=st.sampled_from([True, False]), + ) + @settings(deadline=None) + def test_all_gather_base_pooled( + self, + specify_pg: bool, + test_compiled_with_noncompiled_ranks: bool, + gradient_division: bool, + ) -> None: + self._run_multi_process_test( + world_size=self.WORLD_SIZE, + backend="nccl", + # pyre-ignore [6] + callable=self._test_all_gather_base_pooled, + compile_config=_CompileConfig( + test_compiled_with_noncompiled_ranks=test_compiled_with_noncompiled_ranks + ), + specify_pg=specify_pg, + gradient_division=gradient_division, + ) + + @classmethod + def _test_all_gather_base_pooled_cpu( + cls, + rank: int, + world_size: int, + backend: str, + ) -> None: + pg = GroupMember.WORLD + if pg is None: + dist.init_process_group(rank=rank, world_size=world_size, backend=backend) + pg = GroupMember.WORLD + + device = torch.device(f"cpu") + input_tensor = torch.randn([4, 4], requires_grad=True).to(device) + 
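+        # The CPU/gloo variant only checks that the collective completes; it
+        # does not perform the async/sync/compile comparison used by the GPU
+        # tests above.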
comm_ops.all_gather_base_pooled(input_tensor, pg).wait() + dist.destroy_process_group() + + def test_all_gather_base_pooled_cpu( + self, + ) -> None: + self._run_multi_process_test( + world_size=self.WORLD_SIZE, + backend="gloo", + # pyre-ignore [6] + callable=self._test_all_gather_base_pooled_cpu, ) diff --git a/torchrec/distributed/tests/test_dist_data.py b/torchrec/distributed/tests/test_dist_data.py index f31f37e50..38c30f47d 100644 --- a/torchrec/distributed/tests/test_dist_data.py +++ b/torchrec/distributed/tests/test_dist_data.py @@ -5,26 +5,28 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import itertools import random import unittest -from typing import Dict, Generator, List, Optional, Tuple, TypeVar, Union +from typing import Dict, Generator, Iterable, List, Optional, Tuple, TypeVar, Union import hypothesis.strategies as st -import numpy as np import torch import torch.distributed as dist from hypothesis import given, settings -# @manual=//python/wheel/numpy:numpy -from numpy.testing import assert_array_equal from torchrec.distributed.dist_data import ( + _get_recat, + JaggedTensorAllToAll, KJTAllToAll, - KJTAllToAllLengthsAwaitable, + KJTAllToAllSplitsAwaitable, PooledEmbeddingsAllGather, PooledEmbeddingsAllToAll, PooledEmbeddingsReduceScatter, SequenceEmbeddingsAllToAll, + VariableBatchPooledEmbeddingsAllToAll, ) from torchrec.distributed.fbgemm_qcomm_codec import ( CommType, @@ -32,14 +34,18 @@ QCommsConfig, ) -from torchrec.distributed.test_utils.multi_process import MultiProcessTestBase -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor + +T = TypeVar("T", int, float, List[int]) -T = TypeVar("T", int, float) # Lightly adapted from Stack Overflow #10823877 -def _flatten(iterable: List[T]) -> Generator[T, None, None]: +def _flatten(iterable: Iterable[T]) -> Generator[T, None, None]: iterator, sentinel, stack = iter(iterable), object(), [] while True: value = next(iterator, sentinel) @@ -97,9 +103,11 @@ def _generate_sparse_features_batch( keys=keys, lengths=_to_tensor([lengths[key][i] for key in keys], torch.int), values=_to_tensor([values[key][i] for key in keys], torch.int), - weights=_to_tensor([weights[key][i] for key in keys], torch.float) - if weights - else None, + weights=( + _to_tensor([weights[key][i] for key in keys], torch.float) + if weights + else None + ), ) ) key_index = [] @@ -118,12 +126,97 @@ def _generate_sparse_features_batch( [values[key][j] for key, j in key_index], torch.int, ), - weights=_to_tensor( - [weights[key][j] for key, j in key_index], - torch.float, - ) - if weights - else None, + weights=( + _to_tensor( + [weights[key][j] for key, j in key_index], + torch.float, + ) + if weights + else None + ), + ) + ) + return in_jagged, out_jagged + + +def _generate_variable_batch_sparse_features_batch( + keys: List[str], + splits: List[int], + batch_size_per_rank_per_feature: List[List[List[int]]], + is_weighted: bool = False, +) -> Tuple[List[KeyedJaggedTensor], List[KeyedJaggedTensor]]: + world_size = len(splits) + offsets = [0] + list(itertools.accumulate(splits)) + values = {} + lengths = {} + weights = {} if is_weighted else None + + for i, key in enumerate(keys): + lengths[key] = [ + [ + random.randint(0, 10) + for _ in 
range(sum(batch_size_per_rank_per_feature[rank][i])) + ] + for rank in range(world_size) + ] + values[key] = [ + [random.randint(0, 1000) for _ in range(sum(lengths[key][j]))] + for j in range(world_size) + ] + + if weights: + weights[key] = [ + [random.random() for _ in range(sum(lengths[key][j]))] + for j in range(world_size) + ] + + in_jagged: List[KeyedJaggedTensor] = [] + out_jagged: List[KeyedJaggedTensor] = [] + for i in range(world_size): + in_jagged.append( + KeyedJaggedTensor.from_lengths_sync( + keys=keys, + stride_per_key_per_rank=batch_size_per_rank_per_feature[i], + lengths=_to_tensor([lengths[key][i] for key in keys], torch.int), + values=_to_tensor([values[key][i] for key in keys], torch.int), + weights=( + _to_tensor([weights[key][i] for key in keys], torch.float) + if weights + else None + ), + ) + ) + key_index = [] + out_keys = keys[offsets[i] : offsets[i + 1]] + key_indices = [keys.index(k) for k in out_keys] + batch_sizes_by_rank = list(zip(*batch_size_per_rank_per_feature)) + for key in out_keys: + for j in range(world_size): + key_index.append((key, j)) + + out_jagged.append( + KeyedJaggedTensor.from_lengths_sync( + keys=out_keys, + stride_per_key_per_rank=[ + list(_flatten(batch_sizes_by_rank[key_idx])) + for key_idx in key_indices + ], + lengths=_to_tensor( + [lengths[key][j] for key, j in key_index], + torch.int, + ), + values=_to_tensor( + [values[key][j] for key, j in key_index], + torch.int, + ), + weights=( + _to_tensor( + [weights[key][j] for key, j in key_index], + torch.float, + ) + if weights + else None + ), ) ) return in_jagged, out_jagged @@ -173,10 +266,8 @@ class KJTAllToAllTest(MultiProcessTestBase): @classmethod def _validate( cls, - actual_output_awaitable: Union[KJTAllToAllLengthsAwaitable, KeyedJaggedTensor], - expected_output_awaitable: Union[ - KJTAllToAllLengthsAwaitable, KeyedJaggedTensor - ], + actual_output_awaitable: Union[KJTAllToAllSplitsAwaitable, KeyedJaggedTensor], + expected_output_awaitable: Union[KJTAllToAllSplitsAwaitable, KeyedJaggedTensor], ) -> None: actual_output = ( actual_output_awaitable @@ -188,26 +279,27 @@ def _validate( if isinstance(expected_output_awaitable, KeyedJaggedTensor) else expected_output_awaitable.wait().wait() ) - assert_array_equal( + torch.testing.assert_close( actual_output.values().cpu(), expected_output.values().cpu(), ) - assert_array_equal( - actual_output.weights().cpu() - if actual_output.weights_or_none() is not None - else [], - expected_output.weights().cpu() - if expected_output.weights_or_none() is not None - else [], + torch.testing.assert_close( + ( + actual_output.weights().cpu() + if actual_output.weights_or_none() is not None + else [] + ), + ( + expected_output.weights().cpu() + if expected_output.weights_or_none() is not None + else [] + ), ) - assert_array_equal( + torch.testing.assert_close( actual_output.lengths().cpu(), expected_output.lengths().cpu(), ) - assert_array_equal( - actual_output.keys(), - expected_output.keys(), - ) + assert actual_output.keys() == expected_output.keys() @classmethod def _run_test_dist( @@ -218,7 +310,6 @@ def _run_test_dist( output: KeyedJaggedTensor, backend: str, splits: List[int], - batch_size_per_rank: List[int], ) -> None: dist.init_process_group(rank=rank, world_size=world_size, backend=backend) device = torch.device(f"cuda:{rank}") @@ -232,8 +323,6 @@ def _run_test_dist( # `Optional[_distributed_c10d.ProcessGroup]`. 
pg=pg, splits=splits, - device=device, - variable_batch_size=len(set(batch_size_per_rank)) > 1, ) cls._validate(lengths_a2a(_input), output) dist.destroy_process_group() @@ -244,13 +333,13 @@ def _run_test_dist( ) # pyre-fixme[56] @given( - backend=st.sampled_from(["gloo", "nccl"]), + backend=st.sampled_from(["nccl"]), B=st.integers(min_value=1, max_value=2), features=st.integers(min_value=3, max_value=4), is_weighted=st.booleans(), variable_batch_size=st.booleans(), ) - @settings(max_examples=8, deadline=None) + @settings(max_examples=4, deadline=None) def test_features( self, backend: str, @@ -284,7 +373,66 @@ def test_features( "output": output[rank], "backend": backend, "splits": splits, - "batch_size_per_rank": batch_size_per_rank, + } + ) + + self._run_multi_process_test_per_rank( + callable=self._run_test_dist, + world_size=world_size, + kwargs_per_rank=kwargs_per_rank, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + backend=st.sampled_from(["nccl"]), + B=st.integers(min_value=1, max_value=2), + features=st.integers(min_value=3, max_value=4), + is_weighted=st.booleans(), + variable_batch_per_rank=st.booleans(), + ) + @settings(max_examples=4, deadline=None) + def test_variable_batch_features( + self, + backend: str, + B: int, + features: int, + is_weighted: bool, + variable_batch_per_rank: bool, + ) -> None: + keys = [f"F{feature}" for feature in range(features)] + rank0_split = random.randint(0, features) + splits = [rank0_split, features - rank0_split] + world_size = 2 + + if variable_batch_per_rank: + batch_size_per_rank_per_feature = [ + [[random.randint(B, B + 4)] for _ in range(features)] + for _ in range(world_size) + ] + else: + batch_size_per_rank_per_feature = [ + [[random.randint(B, B + 4)] for _ in range(features)] + ] * world_size + + _input, output = _generate_variable_batch_sparse_features_batch( + keys=keys, + splits=splits, + batch_size_per_rank_per_feature=batch_size_per_rank_per_feature, + is_weighted=is_weighted, + ) + + kwargs_per_rank = [] + for rank in range(world_size): + kwargs_per_rank.append( + { + "_input": _input[rank], + "output": output[rank], + "backend": backend, + "splits": splits, } ) @@ -397,7 +545,7 @@ def _run_test_dist( ] ), ) - @settings(max_examples=8, deadline=None) + @settings(max_examples=4, deadline=None) def test_pooled_embeddings( self, backend: str, @@ -484,7 +632,7 @@ def _run_test_dist( atol=atol, ) if qcomms_config is None: - assert_array_equal( + torch.testing.assert_close( # pyre-ignore input.grad.cpu().detach(), torch.ones(input.size()).div_(world_size), @@ -494,7 +642,6 @@ def _run_test_dist( torch.cuda.device_count() <= 1, "Not enough GPUs, this test requires at least two GPUs", ) - @settings(deadline=30000) # pyre-ignore @given( qcomms_config=st.sampled_from( @@ -538,17 +685,18 @@ def _run_test_dist( ] ), ) + @settings(max_examples=3, deadline=45000) def test_pooled_embedding_reduce_scatter( self, qcomms_config: Optional[QCommsConfig] ) -> None: world_size = 2 embeddding_dim = 10 - batch_size = 2 + batch_size = 4 embeddings = torch.rand((batch_size * world_size, embeddding_dim)) - embeddings_by_rank = list(torch.chunk(embeddings, batch_size, dim=0)) + embeddings_by_rank = list(torch.chunk(embeddings, world_size, dim=0)) expect_results = torch.chunk( torch.stack(embeddings_by_rank, dim=0).sum(dim=0), - 2, + world_size, dim=0, ) kwargs_per_rank = [] @@ -606,7 +754,7 @@ def _run_test_dist( atol=atol, ) if qcomms_config is 
None: - assert_array_equal( + torch.testing.assert_close( # pyre-ignore input.grad.cpu().detach(), torch.ones(input.size()).div_(world_size), @@ -616,7 +764,6 @@ def _run_test_dist( torch.cuda.device_count() <= 1, "Not enough GPUs, this test requires at least two GPUs", ) - @settings(deadline=30000) # pyre-ignore @given( qcomms_config=st.sampled_from( @@ -660,6 +807,7 @@ def _run_test_dist( ] ), ) + @settings(max_examples=3, deadline=45000) def test_pooled_embedding_reduce_scatter_v( self, qcomms_config: Optional[QCommsConfig] ) -> None: @@ -701,8 +849,10 @@ def _validate( input: torch.Tensor, world_size: int, ) -> None: - assert_array_equal(actual_output.cpu().detach(), expected_output.cpu().detach()) - assert_array_equal( + torch.testing.assert_close( + actual_output.cpu().detach(), expected_output.cpu().detach() + ) + torch.testing.assert_close( # pyre-fixme[16]: Optional type has no attribute `cpu`. input.grad.cpu().detach(), torch.ones(input.size()), @@ -765,9 +915,9 @@ def _generate_sequence_embedding_batch( ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: world_size = len(splits) - tensor_by_feature: Dict[ - str, List[torch.Tensor] - ] = {} # Model parallel, key as feature + tensor_by_feature: Dict[str, List[torch.Tensor]] = ( + {} + ) # Model parallel, key as feature tensor_by_rank: Dict[str, List[torch.Tensor]] = {} # Data parallel, key as rank emb_by_rank_feature = {} @@ -781,8 +931,8 @@ def _generate_sequence_embedding_batch( offset : offset + current_rank_batch_size ] offset += current_rank_batch_size - emb_by_rank_feature[f"{feature}_{str(rank)}"] = np.random.rand( - sum(current_stride_lengths), dim + emb_by_rank_feature[f"{feature}_{str(rank)}"] = torch.rand( + (sum(current_stride_lengths), dim) ).tolist() tensor_by_feature[f"{feature}"] = [] tensor_by_rank[f"{str(rank)}"] = [] @@ -802,7 +952,7 @@ def _generate_sequence_embedding_batch( out_tensor.append(torch.Tensor(v)) input_offsets = [0] + list(itertools.accumulate(splits)) - output_offsets = np.arange(0, world_size + 1, dtype=int).tolist() + output_offsets = torch.arange(0, world_size + 1, dtype=torch.int).tolist() regroup_in_tensor: List[torch.Tensor] = [] regroup_out_tensor: List[torch.Tensor] = [] @@ -849,13 +999,25 @@ def _run_test_dist( ) _input.requires_grad = True + sparse_features_recat = ( + _get_recat( + local_split=features_per_rank[rank], + num_splits=world_size, + device=device, + stagger=1, + batch_size_per_rank=batch_size_per_rank, + ) + if len(set(batch_size_per_rank)) > 1 + else None + ) + res = a2a( local_embs=_input, lengths=lengths_after_sdd_a2a, input_splits=input_splits, output_splits=output_splits, - sparse_features_recat=None, batch_size_per_rank=batch_size_per_rank, + sparse_features_recat=sparse_features_recat, ).wait() atol, rtol = None, None @@ -869,8 +1031,11 @@ def _run_test_dist( torch.testing.assert_close(res, output, rtol=rtol, atol=atol) res.backward(res) grad = _input.grad - # pyre-fixme[16]: Optional type has no attribute `cpu`. - assert_array_equal(_input.cpu().detach(), grad.cpu().detach()) + torch.testing.assert_close( + _input.cpu().detach(), + # pyre-fixme[16]: Optional type has no attribute `cpu`. 
+ grad.cpu().detach() * world_size, + ) @unittest.skipIf( torch.cuda.device_count() <= 1, @@ -910,7 +1075,7 @@ def _run_test_dist( ] ), ) - @settings(max_examples=8, deadline=None) + @settings(max_examples=4, deadline=None) def test_sequence_embeddings( self, variable_batch_size: bool, @@ -1002,3 +1167,342 @@ def test_sequence_embeddings( world_size=world_size, kwargs_per_rank=kwargs_per_rank, ) + + +class VariableBatchPooledEmbeddingsAllToAllTest(MultiProcessTestBase): + @classmethod + def _run_test_dist( + cls, + rank: int, + world_size: int, + _input: torch.Tensor, + output: torch.Tensor, + backend: str, + emb_dim_per_rank_per_feature: List[List[int]], + batch_size_per_rank_per_feature: List[List[int]], + batch_size_per_feature_pre_a2a: List[int], + qcomms_config: Optional[QCommsConfig] = None, + ) -> None: + dist.init_process_group(rank=rank, world_size=world_size, backend=backend) + pg = dist.group.WORLD + if backend == "gloo": + device = torch.device("cpu") + else: + device = torch.device(f"cuda:{rank}") + _input = _input.to(device=device) + output = output.to(device=device) + + codecs = get_qcomm_codecs(qcomms_config) + + a2a = VariableBatchPooledEmbeddingsAllToAll( + # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got + # `Optional[_distributed_c10d.ProcessGroup]`. + pg=pg, + emb_dim_per_rank_per_feature=emb_dim_per_rank_per_feature, + device=device, + codecs=codecs, + ) + _input.requires_grad = True + res = a2a( + local_embs=_input, + batch_size_per_rank_per_feature=batch_size_per_rank_per_feature, + batch_size_per_feature_pre_a2a=batch_size_per_feature_pre_a2a, + ).wait() + res.backward(res) + + atol, rtol = None, None + if qcomms_config is not None: + atol, rtol = 0.01, 0.01 + if ( + qcomms_config.forward_precision == CommType.FP8 + or qcomms_config.backward_precision == CommType.FP8 + ): + atol, rtol = 0.05, 0.05 + + torch.testing.assert_close(res, output, rtol=rtol, atol=atol) + + torch.testing.assert_close( + _input.cpu().detach().div_(world_size), + # pyre-ignore + _input.grad.cpu().detach(), + atol=atol, + rtol=rtol, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + backend=st.sampled_from(["nccl"]), + features=st.integers(min_value=3, max_value=4), + B=st.integers(min_value=2, max_value=3), + is_reversed=st.booleans(), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, + backward_precision=CommType.FP16, + ), + QCommsConfig( + forward_precision=CommType.FP16, + backward_precision=CommType.BF16, + ), + QCommsConfig( + forward_precision=CommType.FP16, + backward_precision=CommType.FP16, + backward_loss_scale=128.0, + ), + QCommsConfig( + forward_precision=CommType.FP32, + backward_precision=CommType.BF16, + ), + QCommsConfig( + forward_precision=CommType.FP8, + backward_precision=CommType.FP8, + ), + QCommsConfig( + forward_precision=CommType.FP8, + backward_precision=CommType.BF16, + ), + ] + ), + ) + @settings(max_examples=4, deadline=None) + def test_variable_batch_pooled_embeddings( + self, + backend: str, + B: int, + features: int, + is_reversed: bool, + qcomms_config: Optional[QCommsConfig], + ) -> None: + world_size = 2 + keys = [f"F{feature}" for feature in range(features)] + dims = random.sample([8, 16, 32] * features, features) + rank0_split = random.randint(1, features - 1) + splits = [rank0_split, features - rank0_split] + if is_reversed: + splits.reverse() + emb_dim_per_rank_per_feature = [] + f_id = 0 + 
for split in splits: + emb_dim_per_feature = [] + for _ in range(split): + emb_dim_per_feature.append(dims[f_id]) + f_id += 1 + emb_dim_per_rank_per_feature.append(emb_dim_per_feature) + + batch_size_per_rank_per_feature_pre_a2a = [] + for _ in range(world_size): + batch_size_per_feature = [random.randint(B, B + 4) for _ in keys] + batch_size_per_rank_per_feature_pre_a2a.append(batch_size_per_feature) + + batch_size_per_rank_per_feature_post_a2a_per_rank = [] + fid = 0 + for i in range(world_size): + batch_size_per_rank_per_feature_post_a2a = [[] for _ in range(world_size)] + split = splits[i] + for _ in range(split): + for j in range(world_size): + batch_size_per_rank_per_feature_post_a2a[j].append( + batch_size_per_rank_per_feature_pre_a2a[j][fid] + ) + fid += 1 + batch_size_per_rank_per_feature_post_a2a_per_rank.append( + batch_size_per_rank_per_feature_post_a2a + ) + + """ + before input dist: + r_0 + f_0: [1, 2], [3, 4] + f_1: [5, 6] + f_2: [1], [2], [3] + + r_1 + f_0: [1, 2] + f_1: [5, 6], [3, 4] + f_2: [1], [2] + + after input dist (splits: [1, 2]): + r_0 + f_0: [1, 2], [3, 4], [1, 2] + + r_1 + f_1: [5, 6], [5, 6], [3, 4] + f_2: [1], [2], [3], [1], [2] + + output layout + r_0: + [r_0_f_0_s_0, r_0_f_0_s_1, r_1_f_0_s_0] + + r_1: + [r_0_f_1_s_0, r_0_f_2_s_0, r_0_f_2_s_1, r_0_f_2_s_2, + r_1_f_1_s_0, r_1_f_1_s_1, r_1_f_2_s_0, r_1_f_2_s_1] + + after output dist + r_0: + [r_0_f_0_s_0, r_0_f_0_s_1, r_0_f_1_s_0, r_0_f_2_s_0, r_0_f_2_s_1, r_0_f_2_s_2] + + r_1: + [r_1_f_0_s_0, r_1_f_1_s_0, r_1_f_1_s_1, r_1_f_2_s_0, r_1_f_2_s_1] + """ + + def _generate_variable_batch_pooled_embedding_batch( + keys: List[str], + dims: List[int], + splits: List[int], + batch_size_per_rank_per_feature: List[List[int]], + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + world_size = len(splits) + offsets = [0] + list(itertools.accumulate(splits)) + local_embs = {} + + feature_ind = 0 + for key, dim in zip(keys, dims): + for rank in range(world_size): + local_batch_size = batch_size_per_rank_per_feature[rank][ + feature_ind + ] + if rank not in local_embs: + local_embs[rank] = {} + local_embs[rank][key] = torch.rand( + dim * local_batch_size, dtype=torch.float + ) + feature_ind += 1 + + in_tensor: List[torch.Tensor] = [] + out_tensor: List[torch.Tensor] = [] + for i in range(world_size): + in_keys = keys[offsets[i] : offsets[i + 1]] + input_tensor_list = [] + for rank in range(world_size): + input_tensor_list += [local_embs[rank][key] for key in in_keys] + input_tensor = torch.cat(input_tensor_list) + in_tensor.append(input_tensor) + + output_tensor = torch.cat([local_embs[i][key] for key in keys]) + out_tensor.append(output_tensor) + + return in_tensor, out_tensor + + _input, output = _generate_variable_batch_pooled_embedding_batch( + keys=keys, + dims=dims, + splits=splits, + batch_size_per_rank_per_feature=batch_size_per_rank_per_feature_pre_a2a, + ) + + kwargs_per_rank = [] + for rank in range(world_size): + kwargs_per_rank.append( + { + "_input": _input[rank], + "output": output[rank], + "backend": backend, + "emb_dim_per_rank_per_feature": emb_dim_per_rank_per_feature, + "batch_size_per_rank_per_feature": batch_size_per_rank_per_feature_post_a2a_per_rank[ + rank + ], + "batch_size_per_feature_pre_a2a": batch_size_per_rank_per_feature_pre_a2a[ + rank + ], + "qcomms_config": qcomms_config, + } + ) + + self._run_multi_process_test_per_rank( + callable=self._run_test_dist, + world_size=world_size, + kwargs_per_rank=kwargs_per_rank, + ) + + +class TestJaggedTensorAllToAll(MultiProcessTestBase): + @staticmethod 
+ def _test_jt_all_to_all( + rank: int, + world_size: int, + ) -> None: + backend = "nccl" + with MultiProcessContext( + rank, world_size, backend, local_size=world_size + ) as ctx: + device = ctx.device + if ctx.rank == 0: + # [ + # [1], [2,2], [3,3,3], [4,4,4,4] + # ] + jt = JaggedTensor( + values=torch.tensor( + [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], dtype=torch.int, device=device + ), + lengths=torch.tensor( + [1, 2, 3, 4], dtype=torch.int32, device=device + ), + ) + input_splits = torch.tensor([3, 1], dtype=torch.int32, device=device) + output_splits = torch.tensor([3, 2], dtype=torch.int32, device=device) + else: + # [ + # [5,5,5,5,5], [6,6,6,6,6,6], [7,7,7,7,7,7,7] + # ] + jt = JaggedTensor( + values=torch.tensor( + [5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7], + device=device, + dtype=torch.int, + ), + lengths=torch.tensor([5, 6, 7], dtype=torch.int, device=device), + ) + input_splits = torch.tensor([2, 1], dtype=torch.int32, device=device) + output_splits = torch.tensor([1, 1], dtype=torch.int32, device=device) + + jt_all_to_all = JaggedTensorAllToAll( + jt, + num_items_to_send=input_splits, + num_items_to_receive=output_splits, + # pyre-fixme[6]: For 4th argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + pg=ctx.pg, + ) + + jt_out = jt_all_to_all.wait() + + torch.testing.assert_close( + jt_out.values(), + torch.tensor( + ( + [1, 2, 2, 3, 3, 3, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6] + if ctx.rank == 0 + else [4, 4, 4, 4, 7, 7, 7, 7, 7, 7, 7] + ), + dtype=torch.int, + device=device, + ), + ) + + torch.testing.assert_close( + jt_out.lengths(), + torch.tensor( + [1, 2, 3, 5, 6] if ctx.rank == 0 else [4, 7], + dtype=torch.int, + device=device, + ), + ) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_jt_all_to_all( + self, + ) -> None: + world_size = 2 + self._run_multi_process_test( + callable=self._test_jt_all_to_all, world_size=world_size + ) diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py new file mode 100644 index 000000000..f9a07fc50 --- /dev/null +++ b/torchrec/distributed/tests/test_dynamic_sharding.py @@ -0,0 +1,521 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
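[Editor's note] The TestJaggedTensorAllToAll case above hard-codes the expected values and lengths after the exchange. As a reading aid, here is a small single-process sketch of the same row-routing arithmetic in plain Python; no process group is involved, and the helper name simulate_jt_all_to_all is made up purely for illustration. It reproduces the expected lengths asserted in the test.

from itertools import accumulate
from typing import List


def simulate_jt_all_to_all(
    rows_per_rank: List[List[List[int]]],
    input_splits_per_rank: List[List[int]],
) -> List[List[List[int]]]:
    # rows_per_rank[r] holds rank r's jagged rows; input_splits_per_rank[r][d]
    # is how many of those rows rank r sends to destination rank d, taken as a
    # contiguous slice. Illustrative stand-in, not the real collective.
    world_size = len(rows_per_rank)
    received: List[List[List[int]]] = [[] for _ in range(world_size)]
    for dst in range(world_size):
        for src in range(world_size):
            offsets = [0] + list(accumulate(input_splits_per_rank[src]))
            received[dst].extend(rows_per_rank[src][offsets[dst] : offsets[dst + 1]])
    return received


# Same fixture as the test: rank 0 keeps its first three rows, rank 1 contributes two.
rows = [[[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]], [[5] * 5, [6] * 6, [7] * 7]]
out = simulate_jt_all_to_all(rows, [[3, 1], [2, 1]])
assert [len(r) for r in out[0]] == [1, 2, 3, 5, 6]
assert [len(r) for r in out[1]] == [4, 7]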
+ +# pyre-strict + + +import copy + +import random +import unittest + +from typing import Any, Dict, List, Optional, Union + +import hypothesis.strategies as st + +import torch + +from hypothesis import given, settings, Verbosity +from torch import nn + +from torchrec import distributed as trec_dist, EmbeddingBagCollection, KeyedJaggedTensor +from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection + +from torchrec.distributed.sharding_plan import ( + column_wise, + construct_module_sharding_plan, + get_module_to_default_sharders, + table_wise, +) + +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.test_utils.test_input import ModelInput +from torchrec.distributed.test_utils.test_sharding import copy_state_dict + +from torchrec.distributed.types import ( + EmbeddingModuleShardingPlan, + ParameterSharding, + ShardingEnv, + ShardingType, +) +from torchrec.modules.embedding_configs import data_type_to_dtype, EmbeddingBagConfig + +from torchrec.test_utils import skip_if_asan_class +from torchrec.types import DataType + + +# Utils: +def table_name(i: int) -> str: + return "table_" + str(i) + + +def feature_name(i: int) -> str: + return "feature_" + str(i) + + +def generate_embedding_bag_config( + data_type: DataType, + num_tables: int = 3, + embedding_dim: int = 16, + num_embeddings: int = 4, +) -> List[EmbeddingBagConfig]: + embedding_bag_config = [] + for i in range(num_tables): + embedding_bag_config.append( + EmbeddingBagConfig( + name=table_name(i), + feature_names=[feature_name(i)], + embedding_dim=embedding_dim, + num_embeddings=num_embeddings, + data_type=data_type, + ), + ) + return embedding_bag_config + + +def generate_rank_placements( + world_size: int, + num_tables: int, + ranks_per_tables: List[int], +) -> List[List[int]]: + # Cannot include old/new rank generation with hypothesis library due to depedency on world_size + placements = [] + max_rank = world_size - 1 + if ranks_per_tables == [0]: + ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)] + for i in range(num_tables): + ranks_per_table = ranks_per_tables[i] + placement = sorted(random.sample(range(world_size), ranks_per_table)) + placements.append(placement) + return placements + + +def create_test_initial_state_dict( + sharded_module_type: nn.Module, + num_tables: int, + data_type: DataType, + embedding_dim: int = 16, + num_embeddings: int = 4, +) -> Dict[str, torch.Tensor]: + """ + Helpful for debugging: + + initial_state_dict = { + "embedding_bags.table_0.weight": torch.tensor( + [ + [1] * 16, + [2] * 16, + [3] * 16, + [4] * 16, + ], + ), + "embedding_bags.table_1.weight": torch.tensor( + [ + [101] * 16, + [102] * 16, + [103] * 16, + [104] * 16, + ], + dtype=data_type_to_dtype(data_type), + ), + ... 
+ } + """ + + initial_state_dict = {} + for i in range(num_tables): + # pyre-ignore + extended_name = sharded_module_type.extend_shard_name(table_name(i)) + initial_state_dict[extended_name] = torch.tensor( + [[j + (i * 100)] * embedding_dim for j in range(num_embeddings)], + dtype=data_type_to_dtype(data_type), + ) + + return initial_state_dict + + +def are_sharded_ebc_modules_identical( + module1: ShardedEmbeddingBagCollection, + module2: ShardedEmbeddingBagCollection, +) -> None: + # Check if both modules have the same parameters + params1 = list(module1.named_parameters()) + params2 = list(module2.named_parameters()) + + assert len(params1) == len(params2) + + for param1, param2 in zip(params1, params2): + # Check parameter names + assert param1[0] == param2[0] + # Check parameter values + assert torch.allclose(param1[1], param2[1]) + + # Check if both modules have the same buffers + buffers1 = list(module1.named_buffers()) + buffers2 = list(module2.named_buffers()) + + assert len(buffers1) == len(buffers2) + + for buffer1, buffer2 in zip(buffers1, buffers2): + assert buffer1[0] == buffer2[0] # Check buffer names + assert torch.allclose(buffer1[1], buffer2[1]) # Check buffer values + + # Hard-coded attributes for EmbeddingBagCollection + attribute_list = [ + "_module_fqn", + "_table_names", + "_pooling_type_to_rs_features", + "_output_dtensor", + "_sharding_types", + "_is_weighted", + "_embedding_names", + "_embedding_dims", + "_feature_splits", + "_features_order", + "_uncombined_embedding_names", + "_uncombined_embedding_dims", + "_has_mean_pooling_callback", + "_kjt_key_indices", + "_has_uninitialized_input_dist", + "_has_features_permute", + "_dim_per_key", # Tensor + "_inverse_indices_permute_indices", # Tensor + "_kjt_inverse_order", # Tensor + "_kt_key_ordering", # Tensor + # Non-primitive types which can be compared + "module_sharding_plan", + "_table_name_to_config", + # Excluding the non-primitive types that cannot be compared + # "sharding_type_to_sharding_infos", + # "_embedding_shardings" + # "_input_dists", + # "_lookups", + # "_output_dists", + # "_optim", + ] + + for attr in attribute_list: + assert hasattr(module1, attr) and hasattr(module2, attr) + + val1 = getattr(module1, attr) + val2 = getattr(module2, attr) + + assert type(val1) is type(val2) + if type(val1) is torch.Tensor: + torch.testing.assert_close(val1, val2) + else: + assert val1 == val2 + + +def output_sharding_plan_delta( + old_plan: EmbeddingModuleShardingPlan, new_plan: EmbeddingModuleShardingPlan +) -> EmbeddingModuleShardingPlan: + assert len(old_plan) == len(new_plan) + return_plan = copy.deepcopy(new_plan) + for shard_name, old_param in old_plan.items(): + if shard_name not in return_plan: + raise ValueError(f"Shard {shard_name} not found in new plan") + new_param = return_plan[shard_name] + old_ranks = old_param.ranks + new_ranks = new_param.ranks + if old_ranks == new_ranks: + del return_plan[shard_name] + + return return_plan + + +def _test_ebc_resharding( + tables: List[EmbeddingBagConfig], + initial_state_dict: Dict[str, Any], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + backend: str, + module_sharding_plan: EmbeddingModuleShardingPlan, + new_module_sharding_plan: EmbeddingModuleShardingPlan, + local_size: Optional[int] = None, +) -> None: + """ + Distributed call to test resharding for ebc by creating 2 models with identical config and + states: + m1 sharded with new_module_sharding_plan + m2 sharded with module_sharding_plan, then resharded with 
new_module_sharding_plan + + Expects m1 and resharded m2 to be the same, and predictions outputted from the same KJT + inputs to be the same. + + TODO: modify to include other modules once dynamic sharding is built out. + """ + trec_dist.comm_ops.set_gradient_division(False) + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + kjt_input_per_rank = [kjt.to(ctx.device) for kjt in kjt_input_per_rank] + # Set seed to be 0 to ensure models have the same initialization across ranks + torch.manual_seed(0) + m1 = EmbeddingBagCollection( + tables=tables, + device=ctx.device, + ) + + m2 = EmbeddingBagCollection( + tables=tables, + device=ctx.device, + ) + if initial_state_dict is not None: + initial_state_dict = { + fqn: tensor.to(ctx.device) for fqn, tensor in initial_state_dict.items() + } + + # Load initial State - making sure models are identical + m1.load_state_dict(initial_state_dict) + + m2.load_state_dict(initial_state_dict) + + else: + # Note this is the only correct behavior due to setting random seed to 0 above + # Otherwise the weights generated in EBC initialization will be different on + # Each rank, resulting in different behavior after resharding + copy_state_dict( + loc=m2.state_dict(), + glob=m1.state_dict(), + ) + + sharder = get_module_to_default_sharders()[type(m1)] + + # pyre-ignore + env = ShardingEnv.from_process_group(ctx.pg) + + sharded_m1 = sharder.shard( + module=m1, + params=new_module_sharding_plan, + env=env, + device=ctx.device, + ) + + sharded_m2 = sharder.shard( + module=m1, + params=module_sharding_plan, + env=env, + device=ctx.device, + ) + + new_module_sharding_plan_delta = output_sharding_plan_delta( + module_sharding_plan, new_module_sharding_plan + ) + + # pyre-ignore + resharded_m2 = sharder.reshard( + sharded_module=sharded_m2, + changed_shard_to_params=new_module_sharding_plan_delta, + env=env, + device=ctx.device, + ) + + are_sharded_ebc_modules_identical(sharded_m1, resharded_m2) + + feature_keys = [] + for table in tables: + feature_keys.extend(table.feature_names) + + # For current test model and inputs, the prediction should be the exact same + # rtol = 0 + # atol = 0 + + for _ in range(world_size): + # sharded model + # each rank gets a subbatch + sharded_m1_pred_kt_no_dict = sharded_m1(kjt_input_per_rank[ctx.rank]) + resharded_m2_pred_kt_no_dict = resharded_m2(kjt_input_per_rank[ctx.rank]) + + sharded_m1_pred_kt = sharded_m1_pred_kt_no_dict.to_dict() + resharded_m2_pred_kt = resharded_m2_pred_kt_no_dict.to_dict() + sharded_m1_pred = torch.stack( + [sharded_m1_pred_kt[feature] for feature in feature_keys] + ) + + resharded_m2_pred = torch.stack( + [resharded_m2_pred_kt[feature] for feature in feature_keys] + ) + # cast to CPU because when casting unsharded_model.to on the same module, there could some race conditions + # in normal author modelling code this won't be an issue because each rank would individually create + # their model. output from sharded_pred is correctly on the correct device. + + # Compare predictions of sharded vs unsharded models. 
+ torch.testing.assert_close(sharded_m1_pred.cpu(), resharded_m2_pred.cpu()) + + sharded_m1_pred.sum().backward() + resharded_m2_pred.sum().backward() + + +@skip_if_asan_class +class MultiRankDynamicShardingTest(MultiProcessTestBase): + def _run_ebc_resharding_test( + self, + per_param_sharding: Dict[str, ParameterSharding], + new_per_param_sharding: Dict[str, ParameterSharding], + num_tables: int, + world_size: int, + data_type: DataType, + embedding_dim: int = 16, + num_embeddings: int = 4, + use_debug_state_dict: bool = False, # Turn on to use dummy values for initial state dict + ) -> None: + embedding_bag_config = generate_embedding_bag_config( + data_type, num_tables, embedding_dim, num_embeddings + ) + + module_sharding_plan = construct_module_sharding_plan( + EmbeddingBagCollection(tables=embedding_bag_config), + # pyre-ignore + per_param_sharding=per_param_sharding, + local_size=world_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + ) + + new_module_sharding_plan = construct_module_sharding_plan( + EmbeddingBagCollection(tables=embedding_bag_config), + # pyre-ignore + per_param_sharding=new_per_param_sharding, + local_size=world_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + ) + + # Row-wise not supported on gloo + if ( + not torch.cuda.is_available() + and new_module_sharding_plan["table_0"].sharding_type + == ShardingType.ROW_WISE.value + ): + return + + kjt_input_per_rank = [ + ModelInput.create_standard_kjt( + batch_size=2, + tables=embedding_bag_config, + ) + for _ in range(world_size) + ] + + initial_state_dict = None + if use_debug_state_dict: + # initial_state_dict filled with deterministic dummy values + initial_state_dict = create_test_initial_state_dict( + ShardedEmbeddingBagCollection, # pyre-ignore + num_tables, + data_type, + embedding_dim, + num_embeddings, + ) + + self._run_multi_process_test( + callable=_test_ebc_resharding, + world_size=world_size, + tables=embedding_bag_config, + initial_state_dict=initial_state_dict, + kjt_input_per_rank=kjt_input_per_rank, + backend="nccl" if torch.cuda.is_available() else "gloo", + module_sharding_plan=module_sharding_plan, + new_module_sharding_plan=new_module_sharding_plan, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + @given( # pyre-ignore + num_tables=st.sampled_from([2, 3, 4]), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), + world_size=st.sampled_from([2, 4]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_dynamic_sharding_ebc_tw( + self, + num_tables: int, + data_type: DataType, + world_size: int, + ) -> None: + # Tests EBC dynamic sharding implementation for TW + + # Table wise can only have 1 rank allocated per table: + ranks_per_tables = [1 for _ in range(num_tables)] + # Cannot include old/new rank generation with hypothesis library due to depedency on world_size + old_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables) + new_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables) + + while new_ranks == old_ranks: + new_ranks = generate_rank_placements( + world_size, num_tables, ranks_per_tables + ) + per_param_sharding = {} + new_per_param_sharding = {} + + # Construct parameter shardings + for i in range(num_tables): + per_param_sharding[table_name(i)] = table_wise(rank=old_ranks[i][0]) + new_per_param_sharding[table_name(i)] = 
table_wise(rank=new_ranks[i][0]) + + self._run_ebc_resharding_test( + per_param_sharding, + new_per_param_sharding, + num_tables, + world_size, + data_type, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + @given( # pyre-ignore + num_tables=st.sampled_from([2, 3, 4]), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), + world_size=st.sampled_from([3, 4]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_dynamic_sharding_ebc_cw( + self, + num_tables: int, + data_type: DataType, + world_size: int, + ) -> None: + # Tests EBC dynamic sharding implementation for CW + + # Force the ranks per table to be consistent + ranks_per_tables = [ + random.randint(1, world_size - 1) for _ in range(num_tables) + ] + + old_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables) + new_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables) + + # Cannot include old/new rank generation with hypothesis library due to depedency on world_size + while new_ranks == old_ranks: + old_ranks = generate_rank_placements( + world_size, num_tables, ranks_per_tables + ) + new_ranks = generate_rank_placements( + world_size, num_tables, ranks_per_tables + ) + per_param_sharding = {} + new_per_param_sharding = {} + + # Construct parameter shardings + for i in range(num_tables): + per_param_sharding[table_name(i)] = column_wise(ranks=old_ranks[i]) + new_per_param_sharding[table_name(i)] = column_wise(ranks=new_ranks[i]) + + self._run_ebc_resharding_test( + per_param_sharding, + new_per_param_sharding, + num_tables, + world_size, + data_type, + ) diff --git a/torchrec/distributed/tests/test_emb_dim_bucketer.py b/torchrec/distributed/tests/test_emb_dim_bucketer.py new file mode 100644 index 000000000..2ea8c2f59 --- /dev/null +++ b/torchrec/distributed/tests/test_emb_dim_bucketer.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
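[Editor's note] The resharding tests above lean on output_sharding_plan_delta, which boils down to a dictionary diff keyed on rank placement. A simplified standalone sketch with plain dicts; the ParameterSharding objects are replaced by rank lists purely for illustration.

from typing import Dict, List


def plan_delta(
    old: Dict[str, List[int]], new: Dict[str, List[int]]
) -> Dict[str, List[int]]:
    # Keep only the tables whose rank placement actually changed; resharding
    # assumes both plans cover the same set of tables.
    assert old.keys() == new.keys()
    return {name: ranks for name, ranks in new.items() if old[name] != ranks}


old_plan = {"table_0": [0], "table_1": [1], "table_2": [0, 1]}
new_plan = {"table_0": [1], "table_1": [1], "table_2": [0, 1]}
assert plan_delta(old_plan, new_plan) == {"table_0": [1]}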
+ +# pyre-strict + + +import random +import unittest +from typing import List, Tuple + +from torchrec.distributed.embedding_dim_bucketer import ( + EmbDimBucketer, + EmbDimBucketerPolicy, +) + +from torchrec.distributed.embedding_types import ( + EmbeddingComputeKernel, + ShardedEmbeddingTable, +) +from torchrec.modules.embedding_configs import DataType + + +class TestEmbDimBucketer(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + + def gen_tables(self) -> Tuple[List[ShardedEmbeddingTable], int]: + num_tables = 103 + num_buckets = 11 + embeddings: List[ShardedEmbeddingTable] = [] + buckets = random.sample(range(1024), num_buckets) + + for i in range(num_tables): + local_cols = buckets[i % num_buckets] + local_rows = random.randint(100, 500000) + embeddings.append( + ShardedEmbeddingTable( + name=f"table_{i}", + local_cols=local_cols, + local_rows=local_rows, + embedding_dim=local_cols * random.randint(1, 20), + num_embeddings=local_rows * random.randint(1, 20), + data_type=DataType.FP16, + compute_kernel=EmbeddingComputeKernel.FUSED_UVM_CACHING, + ) + ) + return embeddings, len(buckets) + + def gen_single_dim_tables(self) -> List[ShardedEmbeddingTable]: + num_tables = 47 + embeddings: List[ShardedEmbeddingTable] = [] + for i in range(num_tables): + local_cols = 16 + local_rows = random.randint(100, 500000) + embeddings.append( + ShardedEmbeddingTable( + name=f"table_{i}", + local_cols=local_cols, + local_rows=local_rows, + embedding_dim=local_cols * random.randint(1, 20), + num_embeddings=local_rows * random.randint(1, 20), + data_type=DataType.FP16, + ) + ) + return embeddings + + def test_single_bucket_tables(self) -> None: + embedding_tables = self.gen_single_dim_tables() + emb_dim_bucketer = EmbDimBucketer( + embedding_tables, EmbDimBucketerPolicy.CACHELINE_BUCKETS + ) + self.assertTrue(emb_dim_bucketer.bucket_count() == 1) + + def test_single_bucket_policy(self) -> None: + embedding_tables, _ = self.gen_tables() + emb_dim_bucketer = EmbDimBucketer( + embedding_tables, EmbDimBucketerPolicy.SINGLE_BUCKET + ) + self.assertTrue(emb_dim_bucketer.bucket_count() == 1) + + def test_cacheline_bucket_policy(self) -> None: + embedding_tables, _ = self.gen_tables() + emb_dim_bucketer = EmbDimBucketer( + embedding_tables, EmbDimBucketerPolicy.CACHELINE_BUCKETS + ) + for i in range(emb_dim_bucketer.bucket_count()): + self.assertTrue(i in emb_dim_bucketer.emb_dim_buckets.values()) + + def test_all_bucket_policy(self) -> None: + embedding_tables, num_buckets = self.gen_tables() + emb_dim_bucketer = EmbDimBucketer( + embedding_tables, EmbDimBucketerPolicy.ALL_BUCKETS + ) + + self.assertTrue(emb_dim_bucketer.bucket_count() == num_buckets) + + for i in range(emb_dim_bucketer.bucket_count()): + self.assertTrue(i in emb_dim_bucketer.emb_dim_buckets.values()) diff --git a/torchrec/distributed/tests/test_embedding_sharding.py b/torchrec/distributed/tests/test_embedding_sharding.py new file mode 100644 index 000000000..466cf1a16 --- /dev/null +++ b/torchrec/distributed/tests/test_embedding_sharding.py @@ -0,0 +1,534 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
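[Editor's note] The EmbDimBucketer tests above rely on one property that is easy to restate in isolation: under the ALL_BUCKETS policy, the expected bucket count equals the number of distinct local column widths in the table list. The sketch below is a minimal stand-in that captures only that counting behavior, not the real bucketer (which also implements SINGLE_BUCKET and CACHELINE_BUCKETS).

from typing import Dict, List


def bucket_ids_by_width(local_cols: List[int]) -> Dict[int, int]:
    # Assign a dense bucket id to each distinct width, in order of first appearance.
    buckets: Dict[int, int] = {}
    for width in local_cols:
        buckets.setdefault(width, len(buckets))
    return buckets


widths = [64, 128, 64, 256, 128]
assert len(set(bucket_ids_by_width(widths).values())) == 3  # three distinct widths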
+ +# pyre-strict + + +import random +import unittest +from typing import Any, Dict, List +from unittest.mock import MagicMock + +import hypothesis.strategies as st +import torch + +from hypothesis import given, settings +from torchrec.distributed.embedding import EmbeddingCollectionContext + +from torchrec.distributed.embedding_lookup import EmbeddingComputeKernel + +from torchrec.distributed.embedding_sharding import ( + _get_compute_kernel_type, + _get_grouping_fused_params, + _get_weighted_avg_cache_load_factor, + _prefetch_and_cached, + _set_sharding_context_post_a2a, + group_tables, +) + +from torchrec.distributed.embedding_types import ( + GroupedEmbeddingConfig, + ShardedEmbeddingTable, +) +from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext +from torchrec.modules.embedding_configs import DataType, PoolingType +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class TestGetWeightedAverageCacheLoadFactor(unittest.TestCase): + def test_get_avg_cache_load_factor_hbm(self) -> None: + cache_load_factors = [random.random() for _ in range(5)] + embedding_tables: List[ShardedEmbeddingTable] = [ + ShardedEmbeddingTable( + num_embeddings=1000, + embedding_dim=MagicMock(), + fused_params={"cache_load_factor": cache_load_factor}, + ) + for cache_load_factor in cache_load_factors + ] + + weighted_avg_cache_load_factor = _get_weighted_avg_cache_load_factor( + embedding_tables + ) + self.assertIsNone(weighted_avg_cache_load_factor) + + def test_get_avg_cache_load_factor(self) -> None: + cache_load_factors = [random.random() for _ in range(5)] + embedding_tables: List[ShardedEmbeddingTable] = [ + ShardedEmbeddingTable( + num_embeddings=1000, + embedding_dim=MagicMock(), + compute_kernel=EmbeddingComputeKernel.FUSED_UVM_CACHING, + fused_params={"cache_load_factor": cache_load_factor}, + ) + for cache_load_factor in cache_load_factors + ] + + weighted_avg_cache_load_factor = _get_weighted_avg_cache_load_factor( + embedding_tables + ) + expected_avg = sum(cache_load_factors) / len(cache_load_factors) + self.assertIsNotNone(weighted_avg_cache_load_factor) + self.assertAlmostEqual(weighted_avg_cache_load_factor, expected_avg) + + def test_get_weighted_avg_cache_load_factor(self) -> None: + hash_sizes = [random.randint(100, 1000) for _ in range(5)] + cache_load_factors = [random.random() for _ in range(5)] + embedding_tables: List[ShardedEmbeddingTable] = [ + ShardedEmbeddingTable( + num_embeddings=hash_size, + embedding_dim=MagicMock(), + compute_kernel=EmbeddingComputeKernel.FUSED_UVM_CACHING, + fused_params={"cache_load_factor": cache_load_factor}, + ) + for cache_load_factor, hash_size in zip(cache_load_factors, hash_sizes) + ] + + weighted_avg_cache_load_factor = _get_weighted_avg_cache_load_factor( + embedding_tables + ) + expected_weighted_avg = sum( + cache_load_factor * hash_size + for cache_load_factor, hash_size in zip(cache_load_factors, hash_sizes) + ) / sum(hash_sizes) + + self.assertIsNotNone(weighted_avg_cache_load_factor) + self.assertAlmostEqual(weighted_avg_cache_load_factor, expected_weighted_avg) + + +class TestGetGroupingFusedParams(unittest.TestCase): + def test_get_grouping_fused_params(self) -> None: + fused_params_groups = [ + None, + {}, + {"stochastic_rounding": False}, + {"stochastic_rounding": False, "cache_load_factor": 0.4}, + ] + grouping_fused_params_groups = [ + _get_grouping_fused_params(fused_params, "table_1") + for fused_params in fused_params_groups + ] + expected_grouping_fused_params_groups = [ + None, + {}, + 
{"stochastic_rounding": False}, + {"stochastic_rounding": False}, + ] + + self.assertEqual( + grouping_fused_params_groups, expected_grouping_fused_params_groups + ) + + +class TestPerTBECacheLoadFactor(unittest.TestCase): + # pyre-ignore[56] + @given( + data_type=st.sampled_from([DataType.FP16, DataType.FP32]), + has_feature_processor=st.sampled_from([False, True]), + embedding_dim=st.sampled_from(list(range(160, 320, 40))), + pooling_type=st.sampled_from(list(PoolingType)), + ) + @settings(max_examples=10, deadline=10000) + def test_per_tbe_clf_weighted_average( + self, + data_type: DataType, + has_feature_processor: bool, + embedding_dim: int, + pooling_type: PoolingType, + ) -> None: + compute_kernels = [ + EmbeddingComputeKernel.FUSED_UVM_CACHING, + EmbeddingComputeKernel.FUSED_UVM_CACHING, + EmbeddingComputeKernel.FUSED, + EmbeddingComputeKernel.FUSED_UVM, + ] + fused_params_groups = [ + {"cache_load_factor": 0.5}, + {"cache_load_factor": 0.3}, + {"cache_load_factor": 0.9}, # hbm table, would have no effect + {"cache_load_factor": 0.4}, # uvm table, would have no effect + ] + tables = [ + ShardedEmbeddingTable( + name=f"table_{i}", + data_type=data_type, + pooling=pooling_type, + has_feature_processor=has_feature_processor, + fused_params=fused_params_groups[i], + compute_kernel=compute_kernels[i], + embedding_dim=embedding_dim, + num_embeddings=10000 * (2 * i + 1), # 10000 and 30000 + ) + for i in range(4) + ] + + # since we don't have access to _group_tables_per_rank + tables_per_rank: List[List[ShardedEmbeddingTable]] = [tables] + + # taking only the list for the first rank + table_groups: List[GroupedEmbeddingConfig] = group_tables(tables_per_rank)[0] + + # assert that they are grouped together + self.assertEqual(len(table_groups), 1) + + table_group = table_groups[0] + self.assertIsNotNone(table_group.fused_params) + self.assertEqual(table_group.fused_params.get("cache_load_factor"), 0.35) + + +def _get_table_names_by_groups( + embedding_tables: List[ShardedEmbeddingTable], +) -> List[List[str]]: + # since we don't have access to _group_tables_per_rank + tables_per_rank: List[List[ShardedEmbeddingTable]] = [embedding_tables] + + # taking only the list for the first rank + table_groups: List[GroupedEmbeddingConfig] = group_tables(tables_per_rank)[0] + return [table_group.table_names() for table_group in table_groups] + + +class TestGroupTablesPerRank(unittest.TestCase): + # pyre-ignore[56] + @given( + data_type=st.sampled_from([DataType.FP16, DataType.FP32]), + has_feature_processor=st.sampled_from([False, True]), + fused_params_group=st.sampled_from( + [ + { + "cache_load_factor": 0.5, + "prefetch_pipeline": False, + }, + { + "cache_load_factor": 0.3, + "prefetch_pipeline": True, + }, + ] + ), + local_dim=st.sampled_from(list(range(160, 320, 40))), + num_cw_shards=st.sampled_from([1, 2]), + pooling_type=st.sampled_from(list(PoolingType)), + compute_kernel=st.sampled_from(list(EmbeddingComputeKernel)), + ) + @settings(max_examples=10, deadline=10000) + def test_should_group_together( + self, + data_type: DataType, + has_feature_processor: bool, + fused_params_group: Dict[str, Any], + local_dim: int, + num_cw_shards: int, + pooling_type: PoolingType, + compute_kernel: EmbeddingComputeKernel, + ) -> None: + tables = [ + ShardedEmbeddingTable( + name=f"table_{i}", + data_type=data_type, + pooling=pooling_type, + has_feature_processor=has_feature_processor, + fused_params=fused_params_group, + compute_kernel=compute_kernel, + embedding_dim=local_dim * num_cw_shards, + 
local_cols=local_dim, + num_embeddings=10000, + ) + for i in range(2) + ] + + expected_table_names_by_groups = [["table_0", "table_1"]] + self.assertEqual( + _get_table_names_by_groups(tables), + expected_table_names_by_groups, + ) + + # pyre-ignore[56] + @given( + data_type=st.sampled_from([DataType.FP16, DataType.FP32]), + has_feature_processor=st.sampled_from([False, True]), + local_dim=st.sampled_from(list(range(160, 320, 40))), + num_cw_shards=st.sampled_from([1, 2]), + pooling_type=st.sampled_from(list(PoolingType)), + compute_kernel=st.sampled_from(list(EmbeddingComputeKernel)), + ) + @settings(max_examples=10, deadline=10000) + def test_should_group_together_with_prefetch( + self, + data_type: DataType, + has_feature_processor: bool, + local_dim: int, + num_cw_shards: int, + pooling_type: PoolingType, + compute_kernel: EmbeddingComputeKernel, + ) -> None: + fused_params_groups = [ + { + "cache_load_factor": 0.3, + "prefetch_pipeline": True, + }, + { + "cache_load_factor": 0.5, + "prefetch_pipeline": True, + }, + ] + tables = [ + ShardedEmbeddingTable( + name=f"table_{i}", + data_type=data_type, + pooling=pooling_type, + has_feature_processor=has_feature_processor, + fused_params=fused_params_groups[i], + compute_kernel=compute_kernel, + embedding_dim=local_dim * num_cw_shards, + local_cols=local_dim, + num_embeddings=10000, + ) + for i in range(2) + ] + + expected_table_names_by_groups = [["table_0", "table_1"]] + self.assertEqual( + _get_table_names_by_groups(tables), + expected_table_names_by_groups, + ) + + # pyre-ignore[56] + @given( + data_types=st.lists( + st.sampled_from([DataType.FP16, DataType.FP32]), + min_size=2, + max_size=2, + unique=True, + ), + has_feature_processors=st.lists( + st.sampled_from([False, True]), min_size=2, max_size=2, unique=True + ), + fused_params_group=st.sampled_from( + [ + { + "cache_load_factor": 0.5, + "prefetch_pipeline": True, + }, + { + "cache_load_factor": 0.3, + "prefetch_pipeline": True, + }, + ], + ), + local_dim=st.lists( + st.sampled_from([4, 10, 40]), + min_size=2, + max_size=2, + unique=True, + ), + embedding_dims=st.lists( + st.sampled_from(list(range(160, 320, 40))), + min_size=2, + max_size=2, + unique=True, + ), + pooling_types=st.lists( + st.sampled_from(list(PoolingType)), min_size=2, max_size=2, unique=True + ), + compute_kernels=st.lists( + st.sampled_from(list(EmbeddingComputeKernel)), + min_size=2, + max_size=2, + unique=True, + ), + distinct_key=st.sampled_from( + [ + "data_type", + "has_feature_processor", + "embedding_dim", + "local_dim", + "pooling_type", + "compute_kernel", + ] + ), + ) + @settings(max_examples=100, deadline=10000) + def test_should_not_group_together( + self, + data_types: List[DataType], + has_feature_processors: List[bool], + fused_params_group: Dict[str, Any], + local_dim: List[int], + embedding_dims: List[int], + pooling_types: List[PoolingType], + compute_kernels: List[EmbeddingComputeKernel], + distinct_key: str, + ) -> None: + tables = [ + ShardedEmbeddingTable( + name=f"table_{i}", + data_type=( + data_types[i] if distinct_key == "data_type" else data_types[0] + ), + pooling=( + pooling_types[i] + if distinct_key == "pooling_type" + else pooling_types[0] + ), + has_feature_processor=( + has_feature_processors[i] + if distinct_key == "has_feature_processor" + else has_feature_processors[0] + ), + fused_params=fused_params_group, # can't hash dicts + compute_kernel=( + compute_kernels[i] + if distinct_key == "compute_kernel" + else compute_kernels[0] + ), + embedding_dim=( + embedding_dims[i] + 
if distinct_key == "embedding_dim" + else embedding_dims[0] + ), + local_cols=( + local_dim[i] if distinct_key == "local_dim" else local_dim[0] + ), + num_embeddings=10000, + ) + for i in range(2) + ] + + if distinct_key == "compute_kernel" and _get_compute_kernel_type( + compute_kernels[0] + ) == _get_compute_kernel_type(compute_kernels[1]): + # Typically, a table with same group of kernel (e.g. FUSED vs FUSED_UVM) + # would be grouped together. But if one of them are related to CACHE, + # we'll group them separately because we don't want to add the burden of + # prefetch() + if _prefetch_and_cached(tables[0]) != _prefetch_and_cached(tables[1]): + self.assertEqual( + sorted(_get_table_names_by_groups(tables)), + [["table_0"], ["table_1"]], + ) + else: + self.assertEqual( + _get_table_names_by_groups(tables), + [["table_0", "table_1"]], + ) + return + + # emb dim bucketizier only in use when computer kernel is caching. Otherwise + # they shall be grouped into the same bucket even with different dimensions + if distinct_key == "local_dim" and not _prefetch_and_cached(tables[0]): + self.assertEqual( + _get_table_names_by_groups(tables), + [["table_0", "table_1"]], + ) + return + + # We never separate-group by embedding dim. So we don't care if it's distinct + # or not. Just group them together. + if distinct_key == "embedding_dim": + self.assertEqual( + _get_table_names_by_groups(tables), + [["table_0", "table_1"]], + ) + return + + # If both kernels are quantized, we assume this is inference which we no longer split by data_type + # So if other attributes are the same between the two tables (regardless of data type), we combine them + if ( + tables[0].compute_kernel == EmbeddingComputeKernel.QUANT + and tables[1].compute_kernel == EmbeddingComputeKernel.QUANT + and tables[0].pooling == tables[1].pooling + and tables[0].has_feature_processor == tables[1].has_feature_processor + ): + + self.assertEqual( + sorted(_get_table_names_by_groups(tables)), + [["table_0", "table_1"]], + ) + return + + self.assertEqual( + sorted(_get_table_names_by_groups(tables)), + [["table_0"], ["table_1"]], + ) + + def test_use_one_tbe_per_table( + self, + ) -> None: + + tables = [ + ShardedEmbeddingTable( + name=f"table_{i}", + data_type=DataType.FP16, + pooling=PoolingType.SUM, + has_feature_processor=False, + fused_params={"use_one_tbe_per_table": i % 2 != 0}, + compute_kernel=EmbeddingComputeKernel.FUSED_UVM_CACHING, + embedding_dim=10, + local_cols=5, + num_embeddings=10000, + ) + for i in range(5) + ] + + # table_1 has two shards in the rank + tables.append( + ShardedEmbeddingTable( + name="table_1", + data_type=DataType.FP16, + pooling=PoolingType.SUM, + has_feature_processor=False, + fused_params={"use_one_tbe_per_table": True}, + compute_kernel=EmbeddingComputeKernel.FUSED_UVM_CACHING, + embedding_dim=10, + local_cols=3, + num_embeddings=10000, + ) + ) + + # Even tables should all be grouped in a single TBE, odd tables should be in + # their own TBEs. 
+ self.assertEqual( + _get_table_names_by_groups(tables), + [["table_0", "table_2", "table_4"], ["table_1", "table_1"], ["table_3"]], + ) + + def test_set_sharding_context_post_a2a(self) -> None: + kjts = [ + KeyedJaggedTensor( + keys=["dummy_id", "video_id", "owner_id", "xray_concepts", "dummy_id2"], + values=torch.IntTensor([1] * 10), + lengths=torch.IntTensor([1] * 10), + stride_per_key_per_rank=[ + [1, 2], + [1, 2], + [2, 3], + [5, 7], + [3, 4], + ], + ), + KeyedJaggedTensor( + keys=["dummy_id", "video_id", "owner_id", "xray_concepts", "dummy_id2"], + values=torch.IntTensor([1] * 10), + lengths=torch.IntTensor([1] * 10), + stride_per_key_per_rank=[[3, 1], [5, 2], [7, 3], [1, 2], [6, 8]], + ), + ] + for kjt in kjts: + kjt._variable_stride_per_key = True + + ctx = EmbeddingCollectionContext( + sharding_contexts=[ + SequenceShardingContext(batch_size_per_rank_per_feature=[]), + SequenceShardingContext(batch_size_per_rank_per_feature=[]), + ] + ) + results = [ + [[1, 1, 2, 5, 3], [2, 2, 3, 7, 4]], + [[3, 5, 7, 1, 6], [1, 2, 3, 2, 8]], + ] + _set_sharding_context_post_a2a(kjts, ctx) + for context, result in zip(ctx.sharding_contexts, results): + self.assertEqual(context.batch_size_per_rank_per_feature, result) diff --git a/torchrec/distributed/tests/test_embedding_types.py b/torchrec/distributed/tests/test_embedding_types.py new file mode 100644 index 000000000..db9f660b7 --- /dev/null +++ b/torchrec/distributed/tests/test_embedding_types.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import Dict, List + +import torch +from torchrec.distributed.embedding_types import KJTList, ShardedEmbeddingModule +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionContext +from torchrec.distributed.types import Awaitable, LazyAwaitable + +Out = Dict[str, torch.Tensor] +CompIn = KJTList +DistOut = List[torch.Tensor] +ShrdCtx = EmbeddingBagCollectionContext + + +class FakeShardedEmbeddingModule(ShardedEmbeddingModule[CompIn, DistOut, Out, ShrdCtx]): + def __init__(self) -> None: + super().__init__() + self._lookups = [ + torch.nn.Module(), + torch.nn.Module(), + ] + + # pyre-fixme[7]: Expected `EmbeddingBagCollectionContext` but got implicit + # return value of `None`. + def create_context(self) -> ShrdCtx: + pass + + def input_dist( + self, + ctx: ShrdCtx, + # pyre-ignore[2] + *input, + # pyre-ignore[2] + **kwargs, + # pyre-fixme[7]: Expected `Awaitable[Awaitable[KJTList]]` but got implicit + # return value of `None`. + ) -> Awaitable[Awaitable[CompIn]]: + pass + + # pyre-fixme[7]: Expected `List[Tensor]` but got implicit return value of `None`. + def compute(self, ctx: ShrdCtx, dist_input: CompIn) -> DistOut: + pass + + # pyre-fixme[7]: Expected `LazyAwaitable[Dict[str, Tensor]]` but got implicit + # return value of `None`. 
+ def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]: + pass + + +class TestShardedEmbeddingModule(unittest.TestCase): + def test_train_mode(self) -> None: + embedding_module = FakeShardedEmbeddingModule() + for mode in [True, False]: + with self.subTest(mode=mode): + embedding_module.train(mode) + self.assertEqual(embedding_module.training, mode) + for lookup in embedding_module._lookups: + self.assertEqual(lookup.training, mode) diff --git a/torchrec/distributed/tests/test_fbgemm_qcomm_codec.py b/torchrec/distributed/tests/test_fbgemm_qcomm_codec.py index 36d2c26c1..b4aa290b1 100644 --- a/torchrec/distributed/tests/test_fbgemm_qcomm_codec.py +++ b/torchrec/distributed/tests/test_fbgemm_qcomm_codec.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest from typing import Optional, Tuple @@ -92,3 +94,46 @@ def test_quantized_comm_codec( rtol=rtol, atol=atol, ) + + @settings(deadline=4000) + # pyre-ignore + @given( + row_size=st.integers(4, 256), + col_size=st.integers(4, 256), + rand_seed=st.integers(0, 65534), + ) + def test_mx4_comm_codec( + self, + row_size: int, + col_size: int, + rand_seed: int, + ) -> None: + + torch.manual_seed(rand_seed) + shape = (row_size, col_size) + input_tensor = torch.rand(shape, requires_grad=False) * 2 - 1 + + quant_codec = get_qcomm_codecs( + QCommsConfig( + forward_precision=CommType.MX4, + ) + ) + dim_sum_per_rank = [shape[1]] + ctx = quant_codec.forward.create_context() + + rank = 0 + quant_codec.forward.padded_size(input_tensor, dim_sum_per_rank, rank, ctx) + quant_tensor = quant_codec.forward.encode(input_tensor, ctx) + output_tensor = quant_codec.forward.decode(quant_tensor, ctx) + output_tensor = output_tensor.view(shape[0], ctx.padded_dim_sum_per_rank[rank]) + output_tensor = output_tensor[:, : shape[1]] + + rtol = 0.1 + atol = 0.15 + + torch.testing.assert_close( + input_tensor.detach().cpu(), + output_tensor.detach().cpu(), + rtol=rtol, + atol=atol, + ) diff --git a/torchrec/distributed/tests/test_fp_embeddingbag.py b/torchrec/distributed/tests/test_fp_embeddingbag.py new file mode 100644 index 000000000..130776919 --- /dev/null +++ b/torchrec/distributed/tests/test_fp_embeddingbag.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
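[Editor's note] The new test_mx4_comm_codec case above follows the usual encode/decode round-trip pattern with loose tolerances. Below is a self-contained analogue of that pattern; plain fp16 casting stands in for the MX4 codec (which lives in fbgemm and is not re-implemented here), so the tolerances are illustrative rather than representative of MX4 error.

import torch

torch.manual_seed(0)
x = torch.rand(64, 64) * 2 - 1           # same [-1, 1) input range as the test
roundtrip = x.to(torch.float16).float()  # lossy encode/decode stand-in
torch.testing.assert_close(x, roundtrip, rtol=1e-2, atol=1e-2)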
+ +# pyre-strict + +import copy +import unittest +from operator import xor +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from hypothesis import given, settings, strategies as st, Verbosity +from torch.distributed.optim import ( + _apply_optimizer_in_backward as apply_optimizer_in_backward, +) +from torchrec import distributed as trec_dist +from torchrec.distributed.fp_embeddingbag import ( + FeatureProcessedEmbeddingBagCollectionSharder, +) +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.shard import _shard_modules + +from torchrec.distributed.sharding_plan import ( + column_wise, + construct_module_sharding_plan, + data_parallel, + table_wise, +) +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.test_utils.test_sharding import copy_state_dict +from torchrec.distributed.tests.test_fp_embeddingbag_utils import ( + create_module_and_freeze, + get_configs, + get_kjt_inputs, +) +from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan +from torchrec.modules.embedding_configs import EmbeddingBagConfig + +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.test_utils import skip_if_asan_class + + +def get_unsharded_and_sharded_module( + tables: List[EmbeddingBagConfig], + sharder: ModuleSharder[nn.Module], + use_dmp: bool, + use_fp_collection: bool, + init_device: torch.device, + ctx: MultiProcessContext, +) -> Tuple[nn.Module, nn.Module]: + sparse_arch = create_module_and_freeze( + tables, + use_fp_collection=use_fp_collection, + device=init_device, + ) + + apply_optimizer_in_backward( + torch.optim.SGD, + sparse_arch._fp_ebc._embedding_bag_collection.embedding_bags.parameters(), + {"lr": 1.0}, + ) + + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._fp_ebc, + per_param_sharding={ + "table_0": column_wise(ranks=[0, 1]), + "table_1": table_wise(rank=0), + "table_2": data_parallel(), + "table_3": column_wise(ranks=[0, 0, 1]), + }, + local_size=ctx.local_size, + world_size=ctx.world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + if use_dmp: + sharded_sparse_arch = DistributedModelParallel( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_fp_ebc": module_sharding_plan}), + # pyre-ignore + env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + else: + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"._fp_ebc": module_sharding_plan}), + # pyre-ignore + env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + from torch.distributed._composable.replicate import replicate + + replicate( + sharded_sparse_arch, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_embedding_bag_collection`. 
+ ignored_modules=[sharded_sparse_arch._fp_ebc._embedding_bag_collection], + process_group=ctx.pg, + gradient_as_bucket_view=True, + device_ids=None if ctx.device.type == "cpu" else [ctx.device], + broadcast_buffers=False, + ) + return sparse_arch, sharded_sparse_arch + + +def _test_sharding( # noqa C901 + tables: List[EmbeddingBagConfig], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + sharder: ModuleSharder[nn.Module], + backend: str, + set_gradient_division: bool, + local_size: Optional[int] = None, + use_dmp: bool = False, + use_fp_collection: bool = False, +) -> None: + + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + trec_dist.comm_ops.set_gradient_division(set_gradient_division) + + kjt_input_per_rank = [kjt.to(ctx.device) for kjt in kjt_input_per_rank] + + sparse_arch, sharded_sparse_arch = get_unsharded_and_sharded_module( + tables, + sharder, + use_dmp, + use_fp_collection, + init_device=ctx.device, + ctx=ctx, + ) + + copy_state_dict( + sharded_sparse_arch.state_dict(), + copy.deepcopy(sparse_arch.state_dict()), + ) + + unsharded_model_preds = [] + for unsharded_rank in range(ctx.world_size): + # simulate the unsharded model run on the entire batch + unsharded_model_preds.append( + sparse_arch(kjt_input_per_rank[unsharded_rank])[0] + ) + + unsharded_model_pred_this_rank = unsharded_model_preds[ctx.rank] + + # sharded model + # each rank gets a subbatch + sharded_model_pred = sharded_sparse_arch(kjt_input_per_rank[ctx.rank])[0] + + torch.testing.assert_close( + sharded_model_pred.cpu(), unsharded_model_pred_this_rank.cpu() + ) + + torch.stack(unsharded_model_preds).mean().backward() + sharded_model_pred.mean().backward() + + unsharded_named_parameters = dict(sparse_arch.named_parameters()) + sharded_named_parameters = dict(sharded_sparse_arch.named_parameters()) + + for fqn, param in unsharded_named_parameters.items(): + if "_feature_processors" not in fqn: + continue + + replicated_param = sharded_named_parameters[fqn] + + torch.testing.assert_close( + # pyre-ignore + param.grad.cpu(), + replicated_param.grad.cpu(), + msg=f"Did not match for {fqn} {param.grad=} {replicated_param.grad=}", + ) + + assert ( + sparse_arch.state_dict().keys() == sharded_sparse_arch.state_dict().keys() + ), "State dict keys are not the same" + + +def _test_sharding_from_meta( # noqa C901 + tables: List[EmbeddingBagConfig], + rank: int, + world_size: int, + sharder: ModuleSharder[nn.Module], + backend: str, + local_size: Optional[int] = None, + use_fp_collection: bool = False, +) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + sparse_arch, sharded_sparse_arch = get_unsharded_and_sharded_module( + tables, + sharder, + use_dmp=True, + use_fp_collection=use_fp_collection, + init_device=torch.device("meta"), + ctx=ctx, + ) + + state_dict = sharded_sparse_arch.state_dict() + for key, param in state_dict.items(): + if "_feature_processors" not in key: + continue + assert not param.is_meta, f"Parameter {key} is still meta after sharding" + torch.testing.assert_close(param, torch.ones_like(param)) + + +@skip_if_asan_class +class ShardedEmbeddingBagCollectionParallelTest(MultiProcessTestBase): + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + # pyre-ignore + @given( + set_gradient_division=st.booleans(), + use_dmp=st.booleans(), + use_fp_collection=st.booleans(), + ) + def 
test_sharding_ebc( + self, set_gradient_division: bool, use_dmp: bool, use_fp_collection: bool + ) -> None: + + import hypothesis + + # don't need to test entire matrix + hypothesis.assume(not (set_gradient_division and use_dmp)) + hypothesis.assume(not xor(use_dmp, use_fp_collection)) + + WORLD_SIZE = 2 + embedding_bag_config = get_configs() + kjt_input_per_rank = get_kjt_inputs() + + self._run_multi_process_test( + callable=_test_sharding, + world_size=WORLD_SIZE, + tables=embedding_bag_config, + kjt_input_per_rank=kjt_input_per_rank, + sharder=FeatureProcessedEmbeddingBagCollectionSharder(), + backend=( + "nccl" + if (torch.cuda.is_available() and torch.cuda.device_count() >= 2) + else "gloo" + ), + set_gradient_division=set_gradient_division, + use_dmp=use_dmp, + use_fp_collection=use_fp_collection, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + # pyre-ignore + @given(use_fp_collection=st.booleans(), backend=st.sampled_from(["nccl", "gloo"])) + def test_sharding_fp_ebc_from_meta( + self, use_fp_collection: bool, backend: str + ) -> None: + embedding_bag_config = get_configs() + self._run_multi_process_test( + callable=_test_sharding_from_meta, + world_size=2, + tables=embedding_bag_config, + sharder=FeatureProcessedEmbeddingBagCollectionSharder(), + backend=backend, + use_fp_collection=use_fp_collection, + ) diff --git a/torchrec/distributed/tests/test_fp_embeddingbag_single_rank.py b/torchrec/distributed/tests/test_fp_embeddingbag_single_rank.py new file mode 100644 index 000000000..85932a0d8 --- /dev/null +++ b/torchrec/distributed/tests/test_fp_embeddingbag_single_rank.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
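# --- Editorial sketch, not part of the patch -------------------------------
# The essence of get_unsharded_and_sharded_module above: per-parameter sharding
# choices for the feature-processed EBC are collected into a module plan and
# keyed by the module's FQN ("_fp_ebc") in the ShardingPlan handed to
# DistributedModelParallel. Table names follow the test's configs; the helper
# name build_fp_ebc_plan is illustrative only.
import torch
from torchrec.distributed.fp_embeddingbag import (
    FeatureProcessedEmbeddingBagCollectionSharder,
)
from torchrec.distributed.sharding_plan import (
    column_wise,
    construct_module_sharding_plan,
    data_parallel,
    table_wise,
)
from torchrec.distributed.types import ShardingPlan


def build_fp_ebc_plan(fp_ebc: torch.nn.Module, world_size: int = 2) -> ShardingPlan:
    module_plan = construct_module_sharding_plan(
        fp_ebc,
        per_param_sharding={
            "table_0": column_wise(ranks=[0, 1]),  # columns split across ranks 0 and 1
            "table_1": table_wise(rank=0),         # whole table on rank 0
            "table_2": data_parallel(),            # replicated on every rank
        },
        local_size=world_size,
        world_size=world_size,
        device_type="cuda" if torch.cuda.is_available() else "cpu",
        sharder=FeatureProcessedEmbeddingBagCollectionSharder(),
    )
    return ShardingPlan({"_fp_ebc": module_plan})
# ----------------------------------------------------------------------------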
+ +# pyre-strict + +import unittest +from typing import cast, OrderedDict + +import torch +import torch.nn as nn +from hypothesis import given, settings, strategies as st, Verbosity +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.test_utils.test_model import ModelInput +from torchrec.distributed.test_utils.test_model_parallel_base import ( + ModelParallelSingleRankBase, +) +from torchrec.distributed.tests.test_fp_embeddingbag_utils import ( + create_module_and_freeze, + get_configs, + get_kjt_inputs, + TestFPEBCSharder, +) +from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.modules.embedding_configs import DataType + + +class FPModelParallelStateDictTest(ModelParallelSingleRankBase): + def setUp(self, backend: str = "nccl") -> None: + super().setUp(backend=backend) + + self.use_fp_collection = True + + def _create_tables(self) -> None: + self.tables += get_configs() + + def _generate_batch(self) -> ModelInput: + kjt_input_per_rank = get_kjt_inputs() + batch = kjt_input_per_rank[0].to(self.device) + # pyre-ignore + return batch + + def _create_model(self) -> nn.Module: + return create_module_and_freeze( + tables=self.tables, + use_fp_collection=self.use_fp_collection, + device=torch.device("meta"), + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ] + ), + is_training=st.booleans(), + use_fp_collection=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + def test_load_state_dict( + self, + sharding_type: str, + kernel_type: str, + is_training: bool, + use_fp_collection: bool, + ) -> None: + self.use_fp_collection = use_fp_collection + sharders = [ + cast( + ModuleSharder[nn.Module], + TestFPEBCSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + ), + ), + ] + models, batch = self._generate_dmps_and_batch( + sharders=sharders, + ) + m1, m2 = models + + # load the second's (m2's) with the first (m1's) state_dict + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + + # validate the models are equivalent + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch) + self._compare_models(m1, m2) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_numerical_equivalence_between_kernel_types( + self, + sharding_type: str, + kernel_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + self._set_table_weights_precision(dtype) + fused_params = { + "stochastic_rounding": 
stochastic_rounding, + "cache_precision": dtype, + } + + fused_sharders = [ + cast( + ModuleSharder[nn.Module], + TestFPEBCSharder( + sharding_type=sharding_type, + kernel_type=EmbeddingComputeKernel.FUSED.value, + fused_params=fused_params, + ), + ), + ] + sharders = [ + cast( + ModuleSharder[nn.Module], + TestFPEBCSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + fused_params=fused_params, + ), + ), + ] + + (fused_model, _), _ = self._generate_dmps_and_batch(fused_sharders) + (model, _), batch = self._generate_dmps_and_batch(sharders) + + # load the baseline model's state_dict onto the new model + model.load_state_dict( + cast("OrderedDict[str, torch.Tensor]", fused_model.state_dict()) + ) + + if is_training: + for _ in range(4): + self._train_models(fused_model, model, batch) + # skip eval here because it will cause numerical difference + # TODO figure out why + if not is_training or not stochastic_rounding: + self._eval_models( + fused_model, model, batch, is_deterministic=not stochastic_rounding + ) + self._compare_models( + fused_model, model, is_deterministic=not stochastic_rounding + ) diff --git a/torchrec/distributed/tests/test_fp_embeddingbag_utils.py b/torchrec/distributed/tests/test_fp_embeddingbag_utils.py new file mode 100644 index 000000000..8efacdbb8 --- /dev/null +++ b/torchrec/distributed/tests/test_fp_embeddingbag_utils.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +from typing import Any, cast, Dict, List, Optional, Tuple + +import torch +from torch import nn +from torchrec.distributed.fp_embeddingbag import ( + FeatureProcessedEmbeddingBagCollectionSharder, +) +from torchrec.distributed.test_utils.test_model import TestEBCSharder +from torchrec.distributed.types import QuantizedCommCodecs +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.feature_processor_ import ( + FeatureProcessor, + PositionWeightedModule, + PositionWeightedModuleCollection, +) +from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + +DEFAULT_MAX_FEATURE_LENGTH = 12 + + +class SparseArch(nn.Module): + def __init__( + self, + tables: List[EmbeddingBagConfig], + use_fp_collection: bool, + device: torch.device, + max_feature_lengths: Optional[List[int]] = None, + ) -> None: + super().__init__() + + feature_names = [ + feature_name for table in tables for feature_name in table.feature_names + ] + + if max_feature_lengths is None: + max_feature_lengths = [DEFAULT_MAX_FEATURE_LENGTH] * len(feature_names) + + assert len(max_feature_lengths) == len( + feature_names + ), "Expect max_feature_lengths to have the same number of items as feature_names" + + self._fp_ebc: FeatureProcessedEmbeddingBagCollection = ( + FeatureProcessedEmbeddingBagCollection( + EmbeddingBagCollection( + tables=tables, + device=device, + is_weighted=True, + ), + ( + cast( + Dict[str, FeatureProcessor], + { + feature_name: PositionWeightedModule( + max_feature_length=max_feature_length + ) + for feature_name, max_feature_length in zip( + feature_names, max_feature_lengths + ) + }, + ) + if not use_fp_collection + else PositionWeightedModuleCollection( + 
max_feature_lengths=dict( + zip(feature_names, max_feature_lengths) + ), + ) + ), + ).to(device) + ) + + def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]: + fp_ebc_out = self._fp_ebc(kjt) + pred = torch.cat( + [ + fp_ebc_out[key] + for key in ["feature_0", "feature_1", "feature_2", "feature_3"] + ], + dim=1, + ) + loss = pred.mean() + return loss, pred + + +def create_module_and_freeze( + tables: List[EmbeddingBagConfig], + use_fp_collection: bool, + device: torch.device, + max_feature_lengths: Optional[List[int]] = None, +) -> SparseArch: + + sparse_arch = SparseArch(tables, use_fp_collection, device, max_feature_lengths) + + torch.manual_seed(0) + for param in sparse_arch.parameters(): + nn.init.normal_(param, mean=0, std=0.01) + torch.manual_seed(0) + + return sparse_arch + + +class TestFPEBCSharder(FeatureProcessedEmbeddingBagCollectionSharder): + def __init__( + self, + sharding_type: str, + kernel_type: str, + fused_params: Optional[Dict[str, Any]] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + if fused_params is None: + fused_params = {} + + self._sharding_type = sharding_type + self._kernel_type = kernel_type + + ebc_sharder = TestEBCSharder( + self._sharding_type, + self._kernel_type, + fused_params, + qcomm_codecs_registry, + ) + super().__init__(ebc_sharder, qcomm_codecs_registry) + + def sharding_types(self, compute_device_type: str) -> List[str]: + """ + Restricts sharding to single type only. + """ + return ( + [self._sharding_type] + if self._sharding_type + in super().sharding_types(compute_device_type=compute_device_type) + else [] + ) + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + """ + Restricts to single impl. 
+ """ + return [self._kernel_type] + + +def get_configs() -> List[EmbeddingBagConfig]: + dims = [3 * 16, 8, 8, 3 * 16] + return [ + EmbeddingBagConfig( + name=f"table_{i}", + feature_names=[f"feature_{i}"], + embedding_dim=dim, + num_embeddings=16, + ) + for i, dim in enumerate(dims) + ] + + +def get_kjt_inputs() -> List[KeyedJaggedTensor]: + # Rank 0 + # instance 0 instance 1 instance 2 + # "feature_0" [0, 1] None [2] + # "feature_1" [0, 1] None [2] + # "feature_2" [3, 1] [4,1] [5] + # "feature_3" [1] [6,1,8] [0,3,3] + + # Rank 1 + + # instance 0 instance 1 instance 2 + # "feature_0" [3, 2] [1,2] [0,1,2,3] + # "feature_1" [2, 3] None [2] + # "feature_2" [2, 7] [1,8,2] [8,1] + # "feature_3" [9] [8] [7] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2", "feature_3"], + values=torch.LongTensor( + [0, 1, 2, 0, 1, 2, 3, 1, 4, 1, 5, 1, 6, 1, 8, 0, 3, 3] + ), + lengths=torch.LongTensor( + [ + 2, + 0, + 1, + 2, + 0, + 1, + 2, + 2, + 1, + 1, + 3, + 3, + ] + ), + weights=torch.FloatTensor( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ), + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2", "feature_3"], + values=torch.LongTensor( + [3, 2, 1, 2, 0, 1, 2, 3, 2, 3, 2, 2, 7, 1, 8, 2, 8, 1, 9, 8, 7] + ), + lengths=torch.LongTensor([2, 2, 4, 2, 0, 1, 2, 3, 2, 1, 1, 1]), + weights=torch.FloatTensor( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ), + ), + ] + return kjt_input_per_rank diff --git a/torchrec/distributed/tests/test_fused_embedding_bag_collection.py b/torchrec/distributed/tests/test_fused_embedding_bag_collection.py deleted file mode 100644 index 0905c6638..000000000 --- a/torchrec/distributed/tests/test_fused_embedding_bag_collection.py +++ /dev/null @@ -1,213 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
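# --- Editorial sketch, not part of the patch -------------------------------
# How a per-feature table like the comments in get_kjt_inputs above flattens
# into from_lengths_sync arguments: lengths are laid out feature-major (all of
# feature_0's per-instance lengths, then feature_1's, ...), and values are
# concatenated in the same order. A smaller two-feature case for illustration.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

#                instance 0   instance 1   instance 2
# "feature_0"    [0, 1]       None         [2]
# "feature_1"    [3]          [4]          [5, 6, 7]
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0", "feature_1"],
    values=torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7]),
    lengths=torch.LongTensor([2, 0, 1, 1, 1, 3]),
)
assert kjt["feature_1"].values().tolist() == [3, 4, 5, 6, 7]
assert kjt["feature_1"].lengths().tolist() == [1, 1, 3]
# ----------------------------------------------------------------------------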
- -import unittest -from typing import Dict, List, Optional - -import hypothesis.strategies as st -import torch -import torch.nn as nn -from hypothesis import given, settings, Verbosity -from torchrec.distributed.model_parallel import DistributedModelParallel -from torchrec.distributed.planner import ( - EmbeddingShardingPlanner, - ParameterConstraints, - Topology, -) -from torchrec.distributed.test_utils.multi_process import ( - MultiProcessContext, - MultiProcessTestBase, -) - -from torchrec.distributed.test_utils.test_model import TestFusedEBCSharder -from torchrec.distributed.test_utils.test_sharding import copy_state_dict, SharderType -from torchrec.distributed.types import ( - ModuleSharder, - ShardingEnv, - ShardingPlan, - ShardingType, -) -from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.modules.embedding_modules import EmbeddingBagCollection -from torchrec.modules.fused_embedding_modules import ( - fuse_embedding_optimizer, - FusedEmbeddingBagCollection, -) -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor -from torchrec.test_utils import skip_if_asan_class - - -def sharding_single_rank( - rank: int, - world_size: int, - unsharded_model: nn.Module, - kjt_input: KeyedJaggedTensor, - sharders: List[ModuleSharder[nn.Module]], - backend: str, - constraints: Optional[Dict[str, ParameterConstraints]] = None, - local_size: Optional[int] = None, -) -> None: - - with MultiProcessContext(rank, world_size, backend, local_size) as ctx: - kjt_input = kjt_input.to(ctx.device) - unsharded_model = unsharded_model.to(ctx.device) - - # Shard model. - planner = EmbeddingShardingPlanner( - topology=Topology( - world_size, ctx.device.type, local_world_size=ctx.local_size - ), - constraints=constraints, - ) - plan: ShardingPlan = planner.collective_plan(unsharded_model, sharders, ctx.pg) - - sharded_model = DistributedModelParallel( - unsharded_model, - env=ShardingEnv.from_process_group(ctx.pg), - plan=plan, - sharders=sharders, - device=ctx.device, - ) - - # Load model state from the global model. - copy_state_dict(sharded_model.state_dict(), unsharded_model.state_dict()) - - # cast to CPU because when casting unsharded_model.to on the same module, there could some race conditions - # in normal author modelling code this won't be an issue because each rank would individually create - # their model. output from sharded_pred is correctly on the correct device. - unsharded_model_pred = ( - unsharded_model(kjt_input).values().detach().clone().cpu() - ) - sharded_pred = sharded_model(kjt_input).values().detach().clone().cpu() - - # Compare predictions of sharded vs unsharded models. 
- torch.testing.assert_close(sharded_pred, unsharded_model_pred) - - -@skip_if_asan_class -class FusedEmbeddingBagCollectionParallelTest(MultiProcessTestBase): - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) - # pyre-fixme[56] - @given( - sharder_type=st.sampled_from( - [ - SharderType.EMBEDDING_BAG_COLLECTION.value, - ] - ), - sharding_type=st.sampled_from( - [ - ShardingType.TABLE_WISE.value, - ShardingType.ROW_WISE.value, - ShardingType.COLUMN_WISE.value, - # ShardingType.DATA_PARALLEL.value, - # Data parallel checkpointing not yet supported - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) - def test_sharding_fused_ebc( - self, - sharder_type: str, - sharding_type: str, - ) -> None: - - fused_ebc = FusedEmbeddingBagCollection( - tables=[ - EmbeddingBagConfig( - name="table_0", - feature_names=["feature_0", "feature_1"], - embedding_dim=8, - num_embeddings=10, - ) - ], - optimizer_type=torch.optim.SGD, - optimizer_kwargs={"lr": 0.02}, - device=torch.device("cuda"), - ) - - # instance 0 instance 1 instance 2 - # "feature_0" [0, 1] None [2] - # "feature_1" [3] [4] [5,6,7] - # - - kjt_input = KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1"], - values=torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7]), - lengths=torch.LongTensor([2, 0, 1, 1, 1, 3]), - ) - - self._run_multi_process_test( - callable=sharding_single_rank, - world_size=2, - unsharded_model=fused_ebc, - kjt_input=kjt_input, - sharders=[TestFusedEBCSharder(sharding_type=sharding_type)], - backend="nccl", - ) - - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) - # pyre-fixme[56] - @given( - sharding_type=st.sampled_from( - [ - ShardingType.TABLE_WISE.value, - ShardingType.ROW_WISE.value, - ShardingType.COLUMN_WISE.value, - # ShardingType.DATA_PARALLEL.value, - # Data parallel checkpointing not yet supported - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) - def test_sharding_fused_ebc_module_replace( - self, - sharding_type: str, - ) -> None: - - ebc = EmbeddingBagCollection( - tables=[ - EmbeddingBagConfig( - name="table_0", - feature_names=["feature_0", "feature_1"], - embedding_dim=8, - num_embeddings=10, - ) - ], - ) - - fused_ebc = fuse_embedding_optimizer( - ebc, - optimizer_type=torch.optim.SGD, - optimizer_kwargs={"lr": 0.02}, - device=torch.device("cuda"), - ) - - # instance 0 instance 1 instance 2 - # "feature_0" [0, 1] None [2] - # "feature_1" [3] [4] [5,6,7] - # - - kjt_input = KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1"], - values=torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7]), - lengths=torch.LongTensor([2, 0, 1, 1, 1, 3]), - ) - - self._run_multi_process_test( - callable=sharding_single_rank, - world_size=2, - unsharded_model=fused_ebc, - kjt_input=kjt_input, - sharders=[TestFusedEBCSharder(sharding_type=sharding_type)], - backend="nccl", - ) diff --git a/torchrec/distributed/tests/test_fused_embedding_collection.py b/torchrec/distributed/tests/test_fused_embedding_collection.py deleted file mode 100644 index 91ae3f6af..000000000 --- a/torchrec/distributed/tests/test_fused_embedding_collection.py +++ /dev/null @@ -1,217 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import unittest -from typing import Dict, List, Optional - -import hypothesis.strategies as st -import torch -import torch.nn as nn -from hypothesis import given, settings, Verbosity -from torchrec.distributed.model_parallel import DistributedModelParallel -from torchrec.distributed.planner import ( - EmbeddingShardingPlanner, - ParameterConstraints, - Topology, -) -from torchrec.distributed.test_utils.multi_process import ( - MultiProcessContext, - MultiProcessTestBase, -) - -from torchrec.distributed.test_utils.test_model import TestFusedECSharder -from torchrec.distributed.test_utils.test_sharding import copy_state_dict -from torchrec.distributed.types import ( - ModuleSharder, - ShardingEnv, - ShardingPlan, - ShardingType, -) -from torchrec.modules.embedding_configs import EmbeddingConfig -from torchrec.modules.embedding_modules import EmbeddingCollection -from torchrec.modules.fused_embedding_modules import ( - fuse_embedding_optimizer, - FusedEmbeddingCollection, -) -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor -from torchrec.test_utils import skip_if_asan_class - - -@skip_if_asan_class -class FusedEmbeddingBagCollectionParallelTest(MultiProcessTestBase): - @classmethod - def sharding_single_rank( - cls, - rank: int, - world_size: int, - unsharded_model: nn.Module, - kjt_input: KeyedJaggedTensor, - sharders: List[ModuleSharder[nn.Module]], - backend: str, - constraints: Optional[Dict[str, ParameterConstraints]] = None, - local_size: Optional[int] = None, - ) -> None: - with MultiProcessContext(rank, world_size, backend, local_size) as ctx: - kjt_input = kjt_input.to(ctx.device) - unsharded_model = unsharded_model.to(ctx.device) - - # Shard model. - planner = EmbeddingShardingPlanner( - topology=Topology( - world_size, ctx.device.type, local_world_size=ctx.local_size - ), - constraints=constraints, - ) - plan: ShardingPlan = planner.collective_plan( - unsharded_model, sharders, ctx.pg - ) - - sharded_model = DistributedModelParallel( - unsharded_model, - env=ShardingEnv.from_process_group(ctx.pg), - plan=plan, - sharders=sharders, - device=ctx.device, - ) - - # Load model state from the global model. - copy_state_dict(sharded_model.state_dict(), unsharded_model.state_dict()) - - # cast to CPU because when casting unsharded_model.to on the same module, there could some race conditions - # in normal author modelling code this won't be an issue because each rank would individually create - # their model. output from sharded_pred is correctly on the correct device. 
- unsharded_model_pred = unsharded_model(kjt_input) - sharded_pred = sharded_model(kjt_input) - - assert set(unsharded_model_pred.keys()) == set(sharded_pred.keys()) - - for feature_name in unsharded_model_pred.keys(): - unsharded_jt = unsharded_model_pred[feature_name] - sharded_jt = sharded_pred[feature_name] - - torch.testing.assert_close( - unsharded_jt.values().cpu(), sharded_jt.values().cpu() - ) - torch.testing.assert_close( - unsharded_jt.lengths().cpu(), sharded_jt.lengths().cpu() - ) - - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) - # pyre-fixme[56] - @given( - sharding_type=st.sampled_from( - [ - ShardingType.TABLE_WISE.value, - ShardingType.ROW_WISE.value, - ShardingType.COLUMN_WISE.value, - # ShardingType.DATA_PARALLEL.value, - # Data parallel checkpointing not yet supported - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) - def test_sharding_fused_ec( - self, - sharding_type: str, - ) -> None: - - fused_ec = FusedEmbeddingCollection( - tables=[ - EmbeddingConfig( - name="table_0", - feature_names=["feature_0", "feature_1"], - embedding_dim=8, - num_embeddings=10, - ) - ], - optimizer_type=torch.optim.SGD, - optimizer_kwargs={"lr": 0.02}, - device=torch.device("cuda"), - ) - - # instance 0 instance 1 instance 2 - # "feature_0" [0, 1] None [2] - # "feature_1" [3] [4] [5,6,7] - # - - kjt_input = KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1"], - values=torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7]), - lengths=torch.LongTensor([2, 0, 1, 1, 1, 3]), - ) - - self._run_multi_process_test( - callable=self.sharding_single_rank, - world_size=2, - unsharded_model=fused_ec, - kjt_input=kjt_input, - sharders=[TestFusedECSharder(sharding_type=sharding_type)], - backend="nccl", - ) - - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) - # pyre-fixme[56] - @given( - sharding_type=st.sampled_from( - [ - ShardingType.TABLE_WISE.value, - ShardingType.ROW_WISE.value, - ShardingType.COLUMN_WISE.value, - # ShardingType.DATA_PARALLEL.value, - # Data parallel checkpointing not yet supported - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) - def test_sharding_fused_ebc_module_replace( - self, - sharding_type: str, - ) -> None: - - ec = EmbeddingCollection( - tables=[ - EmbeddingConfig( - name="table_0", - feature_names=["feature_0", "feature_1"], - embedding_dim=8, - num_embeddings=10, - ) - ], - ) - - fused_ec = fuse_embedding_optimizer( - ec, - optimizer_type=torch.optim.SGD, - optimizer_kwargs={"lr": 0.02}, - device=torch.device("cuda"), - ) - - # instance 0 instance 1 instance 2 - # "feature_0" [0, 1] None [2] - # "feature_1" [3] [4] [5,6,7] - # - - kjt_input = KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1"], - values=torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7]), - lengths=torch.LongTensor([2, 0, 1, 1, 1, 3]), - ) - - self._run_multi_process_test( - callable=self.sharding_single_rank, - world_size=2, - unsharded_model=fused_ec, - kjt_input=kjt_input, - sharders=[TestFusedECSharder(sharding_type=sharding_type)], - backend="nccl", - ) diff --git a/torchrec/distributed/tests/test_fused_optim.py b/torchrec/distributed/tests/test_fused_optim.py deleted file mode 100644 index 9203dd8a1..000000000 --- a/torchrec/distributed/tests/test_fused_optim.py +++ /dev/null @@ -1,300 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os -import unittest -from typing import cast, Dict, List, Optional, Union - -import hypothesis.strategies as st -import torch -import torch.distributed as dist -import torch.nn as nn -from fbgemm_gpu.split_embedding_configs import EmbOptimType -from hypothesis import given, settings, Verbosity -from torchrec.distributed.embedding_types import ( - EmbeddingComputeKernel, - EmbeddingTableConfig, -) -from torchrec.distributed.model_parallel import DistributedModelParallel -from torchrec.distributed.planner import ParameterConstraints -from torchrec.distributed.test_utils.multi_process import MultiProcessTestBase -from torchrec.distributed.test_utils.test_model import ( - _get_default_rtol_and_atol, - TestEBCSharder, - TestEBSharder, - TestSparseNN, - TestSparseNNBase, -) -from torchrec.distributed.test_utils.test_sharding import ( - copy_state_dict, - gen_full_pred_after_one_step, - gen_model_and_input, -) -from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType -from torchrec.modules.embedding_configs import BaseEmbeddingConfig, EmbeddingBagConfig -from torchrec.test_utils import ( - init_distributed_single_host, - seed_and_log, - skip_if_asan_class, -) - - -def create_test_sharder( - sharding_type: str, kernel_type: str, optim: EmbOptimType -) -> Union[TestEBSharder, TestEBCSharder]: - fused_params = {} - fused_params["optimizer"] = optim - if optim == EmbOptimType.EXACT_SGD: - fused_params["learning_rate"] = 0.1 - else: - fused_params["learning_rate"] = 0.01 - return TestEBCSharder(sharding_type, kernel_type, fused_params) - - -@skip_if_asan_class -class ModelParallelTest(MultiProcessTestBase): - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) - # pyre-fixme[56] - @given( - sharding_type=st.sampled_from( - [ - ShardingType.ROW_WISE.value, - ShardingType.TABLE_WISE.value, - ] - ), - kernel_type=st.sampled_from( - [ - EmbeddingComputeKernel.FUSED.value, - ] - ), - optim_type=st.sampled_from( - [ - EmbOptimType.EXACT_SGD, - EmbOptimType.EXACT_ROWWISE_ADAGRAD, - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) - def test_sharding_nccl_rw( - self, - sharding_type: str, - kernel_type: str, - optim_type: EmbOptimType, - ) -> None: - self._test_sharding( - # pyre-ignore[6] - sharders=[ - create_test_sharder(sharding_type, kernel_type, optim_type), - ], - backend="nccl", - optim=optim_type, - ) - - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) - # pyre-fixme[56] - @given( - sharding_type=st.sampled_from( - [ - ShardingType.TABLE_WISE.value, - ] - ), - kernel_type=st.sampled_from( - [ - EmbeddingComputeKernel.FUSED.value, - ] - ), - optim_type=st.sampled_from( - [ - EmbOptimType.EXACT_SGD, - EmbOptimType.EXACT_ROWWISE_ADAGRAD, - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) - def test_sharding_nccl_tw( - self, - sharding_type: str, - kernel_type: str, - optim_type: EmbOptimType, - ) -> None: - self._test_sharding( - # pyre-ignore[6] - sharders=[ - create_test_sharder(sharding_type, kernel_type, optim_type), - ], - backend="nccl", - optim=optim_type, - ) - - @seed_and_log - def setUp(self) -> None: - super().setUp() - torch.use_deterministic_algorithms(True) - if torch.cuda.is_available(): - 
torch.backends.cudnn.allow_tf32 = False - torch.backends.cuda.matmul.allow_tf32 = False - - num_features = 4 - num_weighted_features = 2 - - self.tables = [ - EmbeddingBagConfig( - num_embeddings=(i + 1) * 10, - embedding_dim=(i + 2) * 4, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - ) - for i in range(num_features) - ] - self.weighted_tables = [ - EmbeddingBagConfig( - num_embeddings=(i + 1) * 10, - embedding_dim=(i + 2) * 4, - name="weighted_table_" + str(i), - feature_names=["weighted_feature_" + str(i)], - ) - for i in range(num_weighted_features) - ] - - self.embedding_groups = { - "group_0": ["feature_" + str(i) for i in range(num_features)] - } - - def _test_sharding( - self, - sharders: List[ModuleSharder[nn.Module]], - optim: EmbOptimType, - backend: str = "gloo", - world_size: int = 2, - local_size: Optional[int] = None, - constraints: Optional[Dict[str, ParameterConstraints]] = None, - ) -> None: - self._run_multi_process_test( - callable=self._test_optim_single_rank, - world_size=world_size, - local_size=local_size, - model_class=TestSparseNN, - tables=self.tables, - weighted_tables=self.weighted_tables, - embedding_groups=self.embedding_groups, - sharders=sharders, - backend=backend, - optim=optim, - constraints=constraints, - ) - - @classmethod - def _test_optim_single_rank( - cls, - rank: int, - world_size: int, - model_class: TestSparseNNBase, - embedding_groups: Dict[str, List[str]], - tables: List[EmbeddingTableConfig], - sharders: List[ModuleSharder[nn.Module]], - backend: str, - optim: EmbOptimType, - weighted_tables: Optional[List[EmbeddingTableConfig]] = None, - constraints: Optional[Dict[str, ParameterConstraints]] = None, - local_size: Optional[int] = None, - ) -> None: - # Override local_size after pg construction because unit test device count - # is larger than local_size setup. This can be problematic for twrw because - # we have ShardedTensor placement check. - os.environ["LOCAL_WORLD_SIZE"] = str(world_size) - if backend == "nccl": - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) - else: - device = torch.device("cpu") - pg = init_distributed_single_host( - rank=rank, - world_size=world_size, - backend=backend, - local_size=local_size, - ) - if rank == 0: - global_pg = dist.new_group(ranks=[0], backend=backend) - dist.new_group(ranks=[1], backend=backend) - else: - dist.new_group(ranks=[0], backend=backend) - global_pg = dist.new_group(ranks=[1], backend=backend) - - # Generate model & inputs. - (global_model, inputs) = gen_model_and_input( - model_class=model_class, - tables=tables, - weighted_tables=weighted_tables, - embedding_groups=embedding_groups, - world_size=world_size, - num_float_features=16, - ) - global_model = global_model.cuda(0) - global_model = DistributedModelParallel( - global_model, - env=ShardingEnv.from_process_group(global_pg), - sharders=sharders, - device=torch.device("cuda:0"), - init_data_parallel=False, - ) - global_input = inputs[0][0].to(torch.device("cuda:0")) - local_input = inputs[0][1][rank].to(device) - - # Run single step of unsharded model to populate optimizer states. - global_opt = torch.optim.SGD(global_model.parameters(), lr=0.1) - gen_full_pred_after_one_step(global_model, global_opt, global_input) - - # Shard model. 
- local_model = model_class( - tables=cast(List[BaseEmbeddingConfig], tables), - weighted_tables=weighted_tables, - embedding_groups=embedding_groups, - dense_device=device, - sparse_device=torch.device("meta"), - num_float_features=16, - ) - local_model = DistributedModelParallel( - local_model, - env=ShardingEnv.from_process_group(pg), - sharders=sharders, - device=device, - ) - local_opt = torch.optim.SGD(local_model.parameters(), lr=0.1) - - # Load model & optimizer states from the global model. - copy_state_dict(local_model.state_dict(), global_model.state_dict()) - for param_name, local_state in local_model.fused_optimizer.state_dict()[ - "state" - ].items(): - global_state = global_model.fused_optimizer.state_dict()["state"][ - param_name - ] - copy_state_dict(local_state, global_state) - - # Run a single training step of the sharded model. - local_pred = gen_full_pred_after_one_step(local_model, local_opt, local_input) - all_local_pred = [] - for _ in range(world_size): - all_local_pred.append(torch.empty_like(local_pred)) - dist.all_gather(all_local_pred, local_pred, group=pg) - - # Run second training step of the unsharded model. - global_opt = torch.optim.SGD(global_model.parameters(), lr=0.1) - global_pred = gen_full_pred_after_one_step( - global_model, global_opt, global_input - ) - - # Compare predictions of sharded vs unsharded models. - actual, expected = global_pred.cpu(), torch.cat(all_local_pred).cpu() - rtol, atol = _get_default_rtol_and_atol(actual, expected) - torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol) diff --git a/torchrec/distributed/tests/test_fused_optimizer.py b/torchrec/distributed/tests/test_fused_optimizer.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/distributed/tests/test_fx_jit.py b/torchrec/distributed/tests/test_fx_jit.py new file mode 100644 index 000000000..e89d67013 --- /dev/null +++ b/torchrec/distributed/tests/test_fx_jit.py @@ -0,0 +1,521 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +#!/usr/bin/env python3 + +import unittest +from dataclasses import dataclass +from enum import Enum + +from typing import cast, List, Tuple + +import torch +from torch.distributed import ProcessGroup +from torchrec import EmbeddingCollection, EmbeddingConfig +from torchrec.distributed.embedding import EmbeddingCollectionSharder +from torchrec.distributed.embedding_types import ( + EmbeddingComputeKernel, + KJTList, + ListOfKJTList, + ModuleSharder, + ShardingType, +) +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology +from torchrec.distributed.planner.enumerators import EmbeddingEnumerator +from torchrec.distributed.planner.shard_estimators import ( + EmbeddingPerfEstimator, + EmbeddingStorageEstimator, +) +from torchrec.distributed.shard import _shard_modules +from torchrec.distributed.test_utils.infer_utils import ( + assert_close, + KJTInputWrapper, + model_input_to_forward_args, + model_input_to_forward_args_kjt, + prep_inputs, + quantize, + TestModelInfo, + TestQuantEBCSharder, + TestQuantECSharder, + TorchTypesModelInputWrapper, +) +from torchrec.distributed.test_utils.test_model import TestSparseNN +from torchrec.distributed.types import Awaitable, ShardingEnv +from torchrec.fx.tracer import Tracer as TorchrecFxTracer +from torchrec.fx.utils import fake_range +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor + + +class FxJitTestType(Enum): + CREATE_ONLY = 0 + FX = 1 + FX_JIT = 2 + + +@dataclass +class Context: + process_group: ProcessGroup + + +class ModelTraceScriptTest(unittest.TestCase): + def _set_up_qebc( + self, sharding_type: str, quant_state_dict_split_scale_bias: bool + ) -> TestModelInfo: + local_device = torch.device("cuda:0") + model_info = TestModelInfo( + sparse_device=local_device, + dense_device=local_device, + num_features=2, + num_float_features=10, + num_weighted_features=2, + ) + + model_info.tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=512, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(model_info.num_features) + ] + model_info.weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(model_info.num_weighted_features) + ] + model_info.model = TorchTypesModelInputWrapper( + TestSparseNN( + tables=model_info.tables, + weighted_tables=model_info.weighted_tables, + num_float_features=model_info.num_float_features, + dense_device=model_info.dense_device, + sparse_device=model_info.sparse_device, + ) + ) + + model_info.model.training = False + model_info.quant_model = quantize( + model_info.model, + inplace=True, + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, + ) + + model_info.sharders = [ + cast( + ModuleSharder[torch.nn.Module], + TestQuantEBCSharder( + sharding_type=sharding_type, + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in model_info.tables], + ), + ), + cast(ModuleSharder[torch.nn.Module], EmbeddingCollectionSharder()), + ] + + return model_info + + def _set_up_qec( + self, sharding_type: str, quant_state_dict_split_scale_bias: bool + ) -> TestModelInfo: + local_device = torch.device("cuda:0") + model_info = TestModelInfo( + 
sparse_device=local_device, + dense_device=local_device, + num_features=2, + num_float_features=10, + num_weighted_features=0, + ) + model_info.tables = [ + EmbeddingConfig( + num_embeddings=(i + 1) * 11, + embedding_dim=16, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(model_info.num_features) + ] + + model_info.model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + EmbeddingCollection( + tables=model_info.tables, + device=model_info.sparse_device, + ) + ) + ) + + model_info.model.training = False + model_info.quant_model = quantize( + model_info.model, + inplace=True, + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, + ) + + model_info.sharders = [ + cast( + ModuleSharder[torch.nn.Module], + TestQuantECSharder( + sharding_type=sharding_type, + kernel_type=EmbeddingComputeKernel.QUANT.value, + ), + ) + ] + + return model_info + + def shard_modules_QEBC( + self, + world_size: int, + sharding_type: str, + quant_state_dict_split_scale_bias: bool, + # pyre-ignore + ) -> Tuple[torch.nn.Module, torch.nn.Module, List[Tuple]]: + model_info = self._set_up_qebc(sharding_type, quant_state_dict_split_scale_bias) + sharded_model = _shard_modules( + module=model_info.quant_model, + sharders=model_info.sharders, + device=model_info.sparse_device, + env=ShardingEnv.from_local(world_size=world_size, rank=0), + ) + + inputs = prep_inputs(model_info, world_size, long_indices=False) + + return ( + model_info.quant_model, + sharded_model, + [ + model_input_to_forward_args(inp.to(model_info.sparse_device)) + for inp in inputs + ], + ) + + def shard_modules_QEC( + self, + world_size: int, + sharding_type: str, + quant_state_dict_split_scale_bias: bool, + # pyre-ignore + ) -> Tuple[torch.nn.Module, torch.nn.Module, List[Tuple]]: + model_info = self._set_up_qec(sharding_type, quant_state_dict_split_scale_bias) + sharded_model = _shard_modules( + module=model_info.quant_model, + sharders=model_info.sharders, + device=model_info.sparse_device, + env=ShardingEnv.from_local(world_size=world_size, rank=0), + ) + + inputs = prep_inputs(model_info, world_size, long_indices=False) + + return ( + model_info.quant_model, + sharded_model, + [ + model_input_to_forward_args_kjt(inp.to(model_info.sparse_device)) + for inp in inputs + ], + ) + + def DMP_QEBC( + self, + world_size: int, + sharding_type: str, + quant_state_dict_split_scale_bias: bool, + unwrap_dmp: bool, + # pyre-ignore + ) -> Tuple[torch.nn.Module, torch.nn.Module, List[Tuple]]: + model_info = self._set_up_qebc(sharding_type, quant_state_dict_split_scale_bias) + topology = Topology(world_size=world_size, compute_device="cuda") + plan = EmbeddingShardingPlanner( + topology=topology, + batch_size=10, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=1, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ).plan(model_info.quant_model, model_info.sharders) + + dmp = DistributedModelParallel( + model_info.quant_model, + plan=plan, + sharders=model_info.sharders, + device=model_info.sparse_device, + env=ShardingEnv.from_local(world_size=world_size, rank=0), + init_data_parallel=False, + ) + + dmp = dmp.copy(model_info.sparse_device) + + inputs = prep_inputs(model_info, world_size, long_indices=False) + + m = dmp.module if unwrap_dmp else dmp + return ( + model_info.quant_model, + m, + [ + model_input_to_forward_args(inp.to(model_info.sparse_device)) + for inp in inputs + ], + ) + + def DMP_QEC( 
+ self, + world_size: int, + sharding_type: str, + quant_state_dict_split_scale_bias: bool, + sharding_enabled: bool, + # pyre-ignore + ) -> Tuple[torch.nn.Module, torch.nn.Module, List[Tuple]]: + model_info = self._set_up_qec(sharding_type, quant_state_dict_split_scale_bias) + + if sharding_enabled: + topology = Topology(world_size=world_size, compute_device="cuda") + plan = EmbeddingShardingPlanner( + topology=topology, + batch_size=10, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=1, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ).plan(model_info.quant_model, model_info.sharders) + m = DistributedModelParallel( + model_info.quant_model, + plan=plan, + sharders=model_info.sharders, + device=model_info.sparse_device, + env=ShardingEnv.from_local(world_size=world_size, rank=0), + init_data_parallel=False, + ) + model_info.model = m.module + + inputs = prep_inputs(model_info, world_size, long_indices=False) + + return ( + model_info.quant_model, + model_info.model, + [ + model_input_to_forward_args_kjt(inp.to(model_info.sparse_device)) + for inp in inputs + ], + ) + + def _models_with_inputs( + self, + # pyre-ignore + *args, + # pyre-ignore + **kwargs, + # pyre-ignore + ) -> List[Tuple[torch.nn.Module, torch.nn.Module, List[Tuple], FxJitTestType]]: + return [ + (*fn(*args, **kwargs), test_type) + for fn, test_type in [ + ( + lambda world_size: self.DMP_QEBC( + world_size=world_size, + sharding_type=ShardingType.TABLE_WISE.value, + quant_state_dict_split_scale_bias=False, + unwrap_dmp=True, # preferred usage is to provide fx trace with unwrapped dmp + ), + FxJitTestType.FX_JIT, + ), + ( + lambda world_size: self.DMP_QEBC( + world_size=world_size, + sharding_type=ShardingType.TABLE_WISE.value, + quant_state_dict_split_scale_bias=False, + unwrap_dmp=False, + ), + FxJitTestType.FX_JIT, + ), + ( + lambda world_size: self.DMP_QEBC( + world_size=world_size, + sharding_type=ShardingType.ROW_WISE.value, + quant_state_dict_split_scale_bias=True, + unwrap_dmp=False, + ), + FxJitTestType.FX_JIT, + ), + ( + lambda world_size: self.DMP_QEC( + world_size=world_size, + sharding_type=ShardingType.TABLE_WISE.value, + quant_state_dict_split_scale_bias=False, + sharding_enabled=True, + ), + FxJitTestType.FX_JIT, + ), + ( + lambda world_size: self.DMP_QEC( + world_size=world_size, + sharding_type=ShardingType.TABLE_WISE.value, + quant_state_dict_split_scale_bias=True, + sharding_enabled=True, + ), + FxJitTestType.FX_JIT, + ), + ( + lambda world_size: self.DMP_QEC( + world_size=world_size, + sharding_type=ShardingType.TABLE_WISE.value, + quant_state_dict_split_scale_bias=False, + sharding_enabled=False, + ), + FxJitTestType.FX_JIT, + ), + ( + lambda world_size: self.shard_modules_QEBC( + world_size=world_size, + sharding_type=ShardingType.TABLE_WISE.value, + quant_state_dict_split_scale_bias=False, + ), + FxJitTestType.FX_JIT, + ), + ( + lambda world_size: self.shard_modules_QEBC( + world_size=world_size, + sharding_type=ShardingType.TABLE_WISE.value, + quant_state_dict_split_scale_bias=True, + ), + FxJitTestType.FX_JIT, + ), + ( + lambda world_size: self.shard_modules_QEBC( + world_size=world_size, + sharding_type=ShardingType.ROW_WISE.value, + quant_state_dict_split_scale_bias=True, + ), + FxJitTestType.FX_JIT, + ), + ( + lambda world_size: self.shard_modules_QEBC( + world_size=world_size, + sharding_type=ShardingType.COLUMN_WISE.value, + quant_state_dict_split_scale_bias=True, + ), + 
FxJitTestType.FX_JIT, + ), + ( + lambda world_size: self.shard_modules_QEC( + world_size=world_size, + sharding_type=ShardingType.TABLE_WISE.value, + quant_state_dict_split_scale_bias=False, + ), + FxJitTestType.FX_JIT, + ), + ( + lambda world_size: self.shard_modules_QEC( + world_size=world_size, + sharding_type=ShardingType.TABLE_WISE.value, + quant_state_dict_split_scale_bias=True, + ), + FxJitTestType.FX_JIT, + ), + ( + lambda world_size: self.shard_modules_QEC( + world_size=world_size, + sharding_type=ShardingType.ROW_WISE.value, + quant_state_dict_split_scale_bias=True, + ), + FxJitTestType.FX_JIT, + ), + ( + lambda world_size: self.shard_modules_QEC( + world_size=world_size, + sharding_type=ShardingType.COLUMN_WISE.value, + quant_state_dict_split_scale_bias=True, + ), + FxJitTestType.FX_JIT, + ), + ] + ] + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + def test_fxtrace_jitscript(self) -> None: + for non_sharded_model, model, inputs, test_type in self._models_with_inputs( + world_size=2, + ): + # We need more than one input to verify correctness of tracing and scripting using input different from what was used for tracing + assert len(inputs) > 1 + + # Run model first time to go through lazy initialized blocks before tracing + # Targeting only inference for this time + non_sharded_model(*inputs[0]) + eager_output = model(*inputs[0]) + + if test_type == FxJitTestType.CREATE_ONLY: + continue + + tracer = TorchrecFxTracer() + graph = tracer.trace(model) + + # pyre-ignore + gm = torch.fx.GraphModule(tracer.root, graph) + + if test_type == FxJitTestType.FX_JIT: + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + + if isinstance(eager_output, Awaitable): + eager_output = eager_output.wait() + + assert_close(eager_output, gm_script_output) + + for inp in inputs[1:]: + eager_output = model(*inp) + script_output = gm_script(*inp) + assert_close(eager_output, script_output) + + def test_jitscript(self) -> None: + # Check main types to be torch jit scriptable + for clz in [ + JaggedTensor, + KeyedJaggedTensor, + KeyedTensor, + KJTList, + ListOfKJTList, + ]: + # Using torch.jit._script._recursive_compile_class instead of torch.jit.script + # As classes later is more restrictive, checking no inheritance + # (e.g. Multistreamable which we so far do not need in jit script) etc. + # We need those classes not as it is, but as composable blocks in model. + # _recursive_compile_class for that is enough + torch.jit._script._recursive_compile_class(clz, fake_range()) + torch.jit.script(KeyedJaggedTensor.from_jt_dict) + + def test_jitscript_kjt(self) -> None: + def kjt_split(segments: List[int]) -> List[KeyedJaggedTensor]: + kjt = KeyedJaggedTensor( + keys=["a", "b", "c"], + values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), + lengths=torch.tensor([2, 0, 1, 1, 1, 2]), + ) + return kjt.split(segments) + + sm = torch.jit.script(kjt_split) + sm([1, 0, 2, 0]) diff --git a/torchrec/distributed/tests/test_infer_hetero_shardings.py b/torchrec/distributed/tests/test_infer_hetero_shardings.py new file mode 100755 index 000000000..ed3e8288d --- /dev/null +++ b/torchrec/distributed/tests/test_infer_hetero_shardings.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
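# --- Editorial sketch, not part of the patch -------------------------------
# A condensed view of the pipeline exercised by test_fxtrace_jitscript above:
# symbolically trace the (already sharded, inference-mode) module with
# TorchRec's fx Tracer, wrap the graph in a GraphModule, and TorchScript-compile
# it so the same artifact can serve inputs other than the one used for tracing.
import torch
from torchrec.fx.tracer import Tracer as TorchrecFxTracer


def trace_and_script(model: torch.nn.Module) -> torch.jit.ScriptModule:
    tracer = TorchrecFxTracer()
    graph = tracer.trace(model)
    gm = torch.fx.GraphModule(tracer.root, graph)
    return torch.jit.script(gm)
# ----------------------------------------------------------------------------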
+ +# pyre-strict + +#!/usr/bin/env python3 + +import unittest + +import hypothesis.strategies as st + +import torch +from hypothesis import given, settings +from torchrec import EmbeddingBagConfig, EmbeddingCollection, EmbeddingConfig +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.planner import ParameterConstraints +from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner +from torchrec.distributed.planner.types import Topology +from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder +from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder +from torchrec.distributed.shard import _shard_modules +from torchrec.distributed.sharding_plan import ( + construct_module_sharding_plan, + row_wise, + table_wise, +) +from torchrec.distributed.test_utils.infer_utils import KJTInputWrapper, quantize +from torchrec.distributed.types import ShardingEnv, ShardingPlan +from torchrec.modules.embedding_modules import EmbeddingBagCollection + + +class InferHeteroShardingsTest(unittest.TestCase): + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + sharding_device=st.sampled_from(["cpu"]), + ) + @settings(max_examples=4, deadline=None) + def test_sharder_different_world_sizes_for_qec(self, sharding_device: str) -> None: + num_embeddings = 10 + emb_dim = 16 + world_size = 2 + local_size = 1 + tables = [ + EmbeddingConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(3) + ] + model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + EmbeddingCollection( + tables=tables, + device=torch.device("cpu"), + ) + ) + ) + non_sharded_model = quantize( + model, + inplace=False, + quant_state_dict_split_scale_bias=True, + weight_dtype=torch.qint8, + ) + sharder = QuantEmbeddingCollectionSharder() + compute_kernel = EmbeddingComputeKernel.QUANT.value + module_plan = construct_module_sharding_plan( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS... + non_sharded_model._module_kjt_input[0], + per_param_sharding={ + "table_0": row_wise(([20, 10, 100], "cpu")), + "table_1": table_wise( + rank=0, device="cuda", compute_kernel=compute_kernel + ), + "table_2": table_wise( + rank=1, device="cuda", compute_kernel=compute_kernel + ), + }, + # pyre-ignore + sharder=sharder, + local_size=local_size, + world_size=world_size, + ) + plan = ShardingPlan(plan={"_module_kjt_input.0": module_plan}) + env_dict = { + "cpu": ShardingEnv.from_local( + 3, + 0, + ), + "cuda": ShardingEnv.from_local( + 2, + 0, + ), + } + dummy_input = ( + ["feature_0", "feature_1", "feature_2"], + torch.tensor([1, 1, 1]), + None, + torch.tensor([1, 1, 1]), + None, + ) + + sharded_model = _shard_modules( + module=non_sharded_model, + # pyre-fixme[6]: For 2nd argument expected + # `Optional[List[ModuleSharder[Module]]]` but got + # `List[QuantEmbeddingCollectionSharder]`. + sharders=[sharder], + device=torch.device(sharding_device), + plan=plan, + env=env_dict, + ) + + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... + self.assertTrue(hasattr(sharded_model._module_kjt_input[0], "_lookups")) + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... 
+ self.assertTrue(len(sharded_model._module_kjt_input[0]._lookups) == 2) + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... + self.assertTrue(hasattr(sharded_model._module_kjt_input[0], "_input_dists")) + + for i, env in enumerate(env_dict.values()): + self.assertTrue( + hasattr( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ... + sharded_model._module_kjt_input[0]._lookups[i], + "_embedding_lookups_per_rank", + ) + ) + self.assertTrue( + len( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ... + sharded_model._module_kjt_input[0] + ._lookups[i] + ._embedding_lookups_per_rank + ) + == env.world_size + ) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs available", + ) + def test_sharder_different_world_sizes_for_qebc(self) -> None: + num_embeddings = 10 + emb_dim = 16 + world_size = 2 + local_size = 1 + tables = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(3) + ] + model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + EmbeddingBagCollection( + tables=tables, + device=torch.device("cpu"), + ) + ) + ) + non_sharded_model = quantize( + model, + inplace=False, + quant_state_dict_split_scale_bias=True, + weight_dtype=torch.qint8, + ) + sharder = QuantEmbeddingBagCollectionSharder() + compute_kernel = EmbeddingComputeKernel.QUANT.value + module_plan = construct_module_sharding_plan( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS... + non_sharded_model._module_kjt_input[0], + per_param_sharding={ + "table_0": row_wise(([20, 10, 100], "cpu")), + "table_1": table_wise( + rank=0, device="cuda", compute_kernel=compute_kernel + ), + "table_2": table_wise( + rank=1, device="cuda", compute_kernel=compute_kernel + ), + }, + # pyre-ignore + sharder=sharder, + local_size=local_size, + world_size=world_size, + ) + plan = ShardingPlan(plan={"_module_kjt_input.0": module_plan}) + env_dict = { + "cpu": ShardingEnv.from_local( + 3, + 0, + ), + "cuda": ShardingEnv.from_local( + 2, + 0, + ), + } + sharded_model = _shard_modules( + module=non_sharded_model, + # pyre-fixme[6]: For 2nd argument expected + # `Optional[List[ModuleSharder[Module]]]` but got + # `List[QuantEmbeddingBagCollectionSharder]`. + sharders=[sharder], + device=torch.device("cpu"), + plan=plan, + env=env_dict, + ) + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... + self.assertTrue(hasattr(sharded_model._module_kjt_input[0], "_lookups")) + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... + self.assertTrue(len(sharded_model._module_kjt_input[0]._lookups) == 2) + for i, env in enumerate(env_dict.values()): + self.assertTrue( + hasattr( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ... + sharded_model._module_kjt_input[0]._lookups[i], + "_embedding_lookups_per_rank", + ) + ) + self.assertTrue( + len( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ... 
+ sharded_model._module_kjt_input[0] + ._lookups[i] + ._embedding_lookups_per_rank + ) + == env.world_size + ) + + def test_cpu_gpu_sharding_autoplanner(self) -> None: + num_embeddings = 10 + emb_dim = 16 + tables = [ + EmbeddingConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(3) + ] + model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + EmbeddingCollection( + tables=tables, + device=torch.device("cpu"), + ) + ) + ) + non_sharded_model = quantize( + model, + inplace=False, + quant_state_dict_split_scale_bias=True, + weight_dtype=torch.qint8, + ) + sharder = QuantEmbeddingCollectionSharder() + topo_cpu = Topology(world_size=3, compute_device="cpu") + topo_gpu = Topology(world_size=2, compute_device="cuda") + topo_groups = { + "cpu": topo_cpu, + "cuda": topo_gpu, + } + constraints = { + "table_0": ParameterConstraints(device_group="cpu"), + "table_1": ParameterConstraints(device_group="cuda"), + "table_2": ParameterConstraints(device_group="cuda"), + } + planner = HeteroEmbeddingShardingPlanner( + topology_groups=topo_groups, constraints=constraints + ) + module_plan = planner.plan( + non_sharded_model, + # pyre-ignore + sharders=[sharder], + ) + print(module_plan) + + self.assertTrue( + # pyre-ignore + module_plan.plan["_module_kjt_input.0"]["table_0"] + .sharding_spec.shards[0] + .placement.device() + .type, + "cpu", + ) + self.assertTrue( + module_plan.plan["_module_kjt_input.0"]["table_1"] + .sharding_spec.shards[0] + .placement.device() + .type, + "cuda", + ) + self.assertTrue( + module_plan.plan["_module_kjt_input.0"]["table_2"] + .sharding_spec.shards[0] + .placement.device() + .type, + "cuda", + ) diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py new file mode 100755 index 000000000..c39540714 --- /dev/null +++ b/torchrec/distributed/tests/test_infer_shardings.py @@ -0,0 +1,2411 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
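+#
+# Overview of this test module (descriptive comment added for readability;
+# the behavior described is taken from the test cases below):
+# these inference-sharding tests cover table-wise (TW), row-wise (RW,
+# including uneven row splits), column-wise (CW), and mixed sharding of
+# QuantEmbeddingBagCollection / QuantEmbeddingCollection, plus
+# feature-processed and managed-collision variants. Each case compares the
+# sharded model's output against the unsharded quantized model and checks
+# that the sharded module survives fx symbolic tracing and torch.jit.script.
+#
+# Illustrative only (not part of the CI setup): assuming a multi-GPU host
+# with pytest available, a single case can be run locally with, e.g.,
+#   python -m pytest torchrec/distributed/tests/test_infer_shardings.py -k test_tw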
+ +# pyre-strict + +#!/usr/bin/env python3 + +import math +import unittest +from typing import Dict, List, Tuple + +import hypothesis.strategies as st + +import torch +from hypothesis import given, settings +from torchrec import ( + EmbeddingBagConfig, + EmbeddingCollection, + EmbeddingConfig, + KeyedJaggedTensor, +) +from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType +from torchrec.distributed.global_settings import set_propogate_device +from torchrec.distributed.infer_utils import ( + get_path_device_tuples, + get_tbes_from_sharded_module, +) +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology +from torchrec.distributed.planner.enumerators import EmbeddingEnumerator +from torchrec.distributed.planner.shard_estimators import ( + EmbeddingPerfEstimator, + EmbeddingStorageEstimator, +) +from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder +from torchrec.distributed.quant_embeddingbag import ( + QuantEmbeddingBagCollectionSharder, + QuantFeatureProcessedEmbeddingBagCollectionSharder, +) +from torchrec.distributed.quant_state import sharded_tbes_weights_spec, WeightSpec +from torchrec.distributed.shard import _shard_modules +from torchrec.distributed.sharding_plan import ( + column_wise, + construct_module_sharding_plan, + placement, + row_wise, + table_wise, +) +from torchrec.distributed.test_utils.infer_utils import ( + assert_close, + assert_weight_spec, + create_cw_min_partition_constraints, + create_test_model, + KJTInputWrapper, + model_input_to_forward_args, + model_input_to_forward_args_kjt, + prep_inputs, + quantize, + quantize_fpebc, + shard_qebc, + shard_qec, + TestModelInfo, + TestQuantEBCSharder, + TestQuantECSharder, +) +from torchrec.distributed.test_utils.test_model import ModelInput +from torchrec.distributed.types import ShardingEnv, ShardingPlan +from torchrec.fx import symbolic_trace +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.feature_processor_ import ( + FeatureProcessorsCollection, + PositionWeightedModuleCollection, +) +from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection +from torchrec.modules.mc_modules import ( + DistanceLFU_EvictionPolicy, + ManagedCollisionCollection, + MCHManagedCollisionModule, +) + +torch.fx.wrap("len") + + +class TimeGapPoolingCollectionModule(FeatureProcessorsCollection): + def __init__( + self, + feature_pow: float, + feature_min: float, + feature_max: float, + device: torch.device, + ) -> None: + super().__init__() + self.feature_min = feature_min + self.feature_max = feature_max + self.feature_pow = feature_pow + self.device = device + + param = torch.empty( + [math.ceil(math.pow(feature_max, feature_pow)) + 2], + device=device, + ) + self.register_buffer("w", param) + + def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: + scores_list = [] + for feature_name in features.keys(): + jt = features[feature_name] + scores = jt.weights() + scores = torch.clamp( + scores, + min=self.feature_min, + max=self.feature_max, + ) + indices = torch.floor(torch.pow(scores, self.feature_pow)) + indices = indices.to(torch.int32) + scores = torch.index_select(self.w, 0, indices) + scores_list.append(scores) + + return KeyedJaggedTensor( + keys=features.keys(), + values=features.values(), + weights=( + torch.cat(scores_list) if scores_list else features.weights_or_none() + ), + 
lengths=features.lengths(), + stride=features.stride(), + ) + + +def placement_helper(device_type: str, index: int = 0) -> str: + if device_type == "cpu": + return f"rank:0/{device_type}" # cpu only use rank 0 + + return f"rank:{index}/{device_type}:{index}" + + +class InferShardingsTest(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + set_propogate_device(True) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]), + device_type=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=4, deadline=None) + def test_tw(self, weight_dtype: torch.dtype, device_type: str) -> None: + num_embeddings = 256 + emb_dim = 16 + world_size = 2 + batch_size = 4 + local_device = torch.device(f"{device_type}:0") + mi = create_test_model( + num_embeddings, + emb_dim, + world_size, + batch_size, + dense_device=local_device, + sparse_device=local_device, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + ) + + non_sharded_model = mi.quant_model + expected_shards = [ + [ + ( + (0, 0, num_embeddings, emb_dim), + placement(device_type, 0, world_size), + ), + ] + ] + sharded_model = shard_qebc( + mi, + sharding_type=ShardingType.TABLE_WISE, + device=local_device, + expected_shards=expected_shards, + ) + inputs = [ + model_input_to_forward_args(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + + sharded_model.load_state_dict(non_sharded_model.state_dict()) + + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(sharded_output, non_sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model) + assert_weight_spec( + weights_spec, + expected_shards, + "_module.sparse.ebc", + "embedding_bags", + ["table_0"], + ShardingType.TABLE_WISE.value, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + weight_dtype=st.sampled_from([torch.qint8]), + device_type=st.sampled_from(["cuda"]), + ) + @settings(max_examples=4, deadline=None) + def test_tw_ebc_full_rank_weighted_ebc_with_empty_rank( + self, weight_dtype: torch.dtype, device_type: str + ) -> None: + num_embeddings = 256 + emb_dim = 16 + world_size = 2 + batch_size = 4 + local_device = torch.device(f"{device_type}:0") + mi = create_test_model( + num_embeddings, + emb_dim, + world_size, + batch_size, + dense_device=local_device, + sparse_device=local_device, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + num_features=6, # 6 sparse features on ebc + num_weighted_features=1, # only 1 weighted sparse feature on weighted_ebc + ) + + non_sharded_model = mi.quant_model + + sharded_model = shard_qebc( + mi, + sharding_type=ShardingType.TABLE_WISE, + device=local_device, + expected_shards=None, + shard_score_ebc=True, + ) + + self.assertEqual( + len( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `sparse`. + sharded_model._module.sparse.ebc._lookups[0]._embedding_lookups_per_rank + ), + 2, + ) + self.assertEqual( + len( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `sparse`. 
+ sharded_model._module.sparse.weighted_ebc._lookups[ + 0 + ]._embedding_lookups_per_rank + ), + 1, + ) + + inputs = [ + model_input_to_forward_args(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + + sharded_model.load_state_dict(non_sharded_model.state_dict()) + + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(sharded_output, non_sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]), + device_type=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=4, deadline=None) + def test_rw(self, weight_dtype: torch.dtype, device_type: str) -> None: + num_embeddings = 256 + emb_dim = 16 + world_size = 2 + batch_size = 4 + local_device = torch.device(f"{device_type}:0") + mi = create_test_model( + num_embeddings, + emb_dim, + world_size, + batch_size, + dense_device=local_device, + sparse_device=local_device, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + ) + + non_sharded_model = mi.quant_model + num_emb_half = num_embeddings // 2 + expected_shards = [ + [ + ((0, 0, num_emb_half, emb_dim), placement(device_type, 0, world_size)), + ( + (num_emb_half, 0, num_emb_half, emb_dim), + placement(device_type, 1, world_size), + ), + ] + ] + sharded_model = shard_qebc( + mi, + sharding_type=ShardingType.ROW_WISE, + device=local_device, + expected_shards=expected_shards, + ) + inputs = [ + model_input_to_forward_args(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + + sharded_model.load_state_dict(non_sharded_model.state_dict()) + + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(sharded_output, non_sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model) + assert_weight_spec( + weights_spec, + expected_shards, + "_module.sparse.ebc", + "embedding_bags", + ["table_0"], + ShardingType.ROW_WISE.value, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-fixme[56]Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`. 
+ @given( + test_permute=st.booleans(), + weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]), + device_type=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=4, deadline=None) + def test_cw( + self, test_permute: bool, weight_dtype: torch.dtype, device_type: str + ) -> None: + num_embeddings = 64 + emb_dim = 512 + emb_dim_4 = emb_dim // 4 + local_size = 2 + world_size = 2 + batch_size = 4 + local_device = torch.device(f"{device_type}:0") + mi = create_test_model( + num_embeddings, + emb_dim, + world_size, + batch_size, + dense_device=local_device, + sparse_device=local_device, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + ) + + non_sharded_model = mi.quant_model + expected_ranks: List[int] = [0, 1, 0, 1] if not test_permute else [1, 0, 1, 0] + expected_shards = [ + [ + ( + (0, 0, num_embeddings, emb_dim_4), + placement(device_type, expected_ranks[0], world_size), + ), + ( + (0, 1 * emb_dim_4, num_embeddings, emb_dim_4), + placement(device_type, expected_ranks[1], world_size), + ), + ( + (0, 2 * emb_dim_4, num_embeddings, emb_dim_4), + placement(device_type, expected_ranks[2], world_size), + ), + ( + (0, 3 * emb_dim_4, num_embeddings, emb_dim_4), + placement(device_type, expected_ranks[3], world_size), + ), + ] + ] + + plan = None + if test_permute: + sharder = TestQuantEBCSharder( + sharding_type=ShardingType.COLUMN_WISE.value, + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in mi.tables], + ) + + module_plan = construct_module_sharding_plan( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `sparse`. + non_sharded_model._module.sparse.ebc, + per_param_sharding={ + "table_0": column_wise(ranks=[1, 0, 1, 0]), + }, + # pyre-ignore + sharder=sharder, + local_size=local_size, + world_size=world_size, + device_type=device_type, + ) + + plan = ShardingPlan(plan={"_module.sparse.ebc": module_plan}) + + sharded_model = shard_qebc( + mi=mi, + sharding_type=ShardingType.COLUMN_WISE, + device=local_device, + expected_shards=expected_shards, + plan=plan, + ) + inputs = [ + model_input_to_forward_args(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + sharded_model.load_state_dict(non_sharded_model.state_dict()) + # torchrec.distributed.test_utils.test_sharding.copy_state_dict(sharded_model.state_dict(), non_sharded_model.state_dict()) does not work for CW due to non-trivial qscaleshift copy which is handled in shardedQEBC load_state_dict + + # We need this first inference to make all lazy init in forward + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(non_sharded_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model) + assert_weight_spec( + weights_spec, + expected_shards, + "_module.sparse.ebc", + "embedding_bags", + ["table_0"], + ShardingType.COLUMN_WISE.value, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-fixme[56]Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`. 
+ @given( + emb_dim=st.sampled_from([192, 128]), + test_permute=st.booleans(), + weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]), + device_type=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=4, deadline=None) + def test_cw_with_smaller_emb_dim( + self, + emb_dim: int, + test_permute: bool, + weight_dtype: torch.dtype, + device_type: str, + ) -> None: + num_embeddings = 64 + emb_dim_4 = emb_dim // 4 + world_size = 2 + batch_size = 4 + local_device = torch.device(f"{device_type}:0") + constraints = create_cw_min_partition_constraints([("table_0", emb_dim_4)]) + mi = create_test_model( + num_embeddings, + emb_dim, + world_size, + batch_size, + dense_device=local_device, + sparse_device=local_device, + quant_state_dict_split_scale_bias=True, + constraints=constraints, + weight_dtype=weight_dtype, + ) + + non_sharded_model = mi.quant_model + expected_shards = [ + [ + ( + (0, 0, num_embeddings, emb_dim_4), + placement(device_type, 0, world_size), + ), + ( + (0, 1 * emb_dim_4, num_embeddings, emb_dim_4), + placement(device_type, 1, world_size), + ), + ( + (0, 2 * emb_dim_4, num_embeddings, emb_dim_4), + placement(device_type, 0, world_size), + ), + ( + (0, 3 * emb_dim_4, num_embeddings, emb_dim_4), + placement(device_type, 1, world_size), + ), + ] + ] + + sharded_model = shard_qebc( + mi, + sharding_type=ShardingType.COLUMN_WISE, + device=local_device, + expected_shards=expected_shards, + ) + inputs = [ + model_input_to_forward_args(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + sharded_model.load_state_dict(non_sharded_model.state_dict()) + # torchrec.distributed.test_utils.test_sharding.copy_state_dict(sharded_model.state_dict(), non_sharded_model.state_dict()) does not work for CW due to non-trivial qscaleshift copy which is handled in shardedQEBC load_state_dict + + # We need this first inference to make all lazy init in forward + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(non_sharded_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model) + assert_weight_spec( + weights_spec, + expected_shards, + "_module.sparse.ebc", + "embedding_bags", + ["table_0"], + ShardingType.COLUMN_WISE.value, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]), + device_type=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=4, deadline=None) + def test_cw_multiple_tables_with_permute( + self, weight_dtype: torch.dtype, device_type: str + ) -> None: + num_embeddings = 64 + emb_dim = 512 + emb_dim_2 = 512 // 2 + local_size = 2 + world_size = 2 + batch_size = 4 + local_device = torch.device(f"{device_type}:0") + mi = create_test_model( + num_embeddings, + emb_dim, + world_size, + batch_size, + dense_device=local_device, + sparse_device=local_device, + quant_state_dict_split_scale_bias=True, + num_features=2, + weight_dtype=weight_dtype, + ) + + non_sharded_model = mi.quant_model + expected_shards = [ + [ + ( + (0, 0, num_embeddings, emb_dim_2), + placement(device_type, 1, world_size), + ), + ( + (0, 1 * emb_dim_2, num_embeddings, emb_dim_2), + placement(device_type, 0, world_size), + ), + ], + [ 
+ ( + (0, 0, num_embeddings, emb_dim_2), + placement(device_type, 0, world_size), + ), + ( + (0, 1 * emb_dim_2, num_embeddings, emb_dim_2), + placement(device_type, 1, world_size), + ), + ], + ] + + sharder = TestQuantEBCSharder( + sharding_type=ShardingType.COLUMN_WISE.value, + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in mi.tables], + ) + + module_plan = construct_module_sharding_plan( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. + non_sharded_model._module.sparse.ebc, + per_param_sharding={ + "table_0": column_wise(ranks=[1, 0]), + "table_1": column_wise(ranks=[0, 1]), + }, + # pyre-ignore + sharder=sharder, + local_size=local_size, + world_size=world_size, + device_type=device_type, + ) + + plan = ShardingPlan(plan={"_module.sparse.ebc": module_plan}) + + sharded_model = shard_qebc( + mi, + sharding_type=ShardingType.COLUMN_WISE, + device=local_device, + expected_shards=expected_shards, + plan=plan, + ) + inputs = [ + model_input_to_forward_args(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + sharded_model.load_state_dict(non_sharded_model.state_dict()) + # torchrec.distributed.test_utils.test_sharding.copy_state_dict(sharded_model.state_dict(), non_sharded_model.state_dict()) does not work for CW due to non-trivial qscaleshift copy which is handled in shardedQEBC load_state_dict + + # We need this first inference to make all lazy init in forward + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(non_sharded_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model) + assert_weight_spec( + weights_spec, + expected_shards, + "_module.sparse.ebc", + "embedding_bags", + ["table_0", "table_1"], + ShardingType.COLUMN_WISE.value, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]), + device_type=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=4, deadline=None) + def test_cw_irregular_shard_placement( + self, weight_dtype: torch.dtype, device_type: str + ) -> None: + num_embeddings = 64 + emb_dim = 384 + emb_dim_2 = emb_dim // 2 + emb_dim_3 = emb_dim // 3 + local_size = 4 + world_size = 4 + batch_size = 4 + local_device = torch.device(f"{device_type}:0") + mi = create_test_model( + num_embeddings, + emb_dim, + world_size, + batch_size, + dense_device=local_device, + sparse_device=local_device, + quant_state_dict_split_scale_bias=True, + num_features=3, + weight_dtype=weight_dtype, + ) + + non_sharded_model = mi.quant_model + expected_shards = [ + [ + ( + (0, 0, num_embeddings, emb_dim_2), + placement(device_type, 2, world_size), + ), + ( + (0, 1 * emb_dim_2, num_embeddings, emb_dim_2), + placement(device_type, 1, world_size), + ), + ], + [ + ( + (0, 0, num_embeddings, emb_dim_2), + placement(device_type, 0, world_size), + ), + ( + (0, 1 * emb_dim_2, num_embeddings, emb_dim_2), + placement(device_type, 3, world_size), + ), + ], + [ + ( + (0, 0, num_embeddings, emb_dim_3), + placement(device_type, 0, world_size), + ), + ( + (0, 1 * emb_dim_3, num_embeddings, emb_dim_3), + placement(device_type, 2, world_size), + ), + ( + 
(0, 2 * emb_dim_3, num_embeddings, emb_dim_3), + placement(device_type, 3, world_size), + ), + ], + ] + + sharder = TestQuantEBCSharder( + sharding_type=ShardingType.COLUMN_WISE.value, + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in mi.tables], + ) + + module_plan = construct_module_sharding_plan( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. + non_sharded_model._module.sparse.ebc, + per_param_sharding={ + "table_0": column_wise(ranks=[2, 1]), + "table_1": column_wise(ranks=[0, 3]), + "table_2": column_wise(ranks=[0, 2, 3]), + }, + # pyre-ignore + sharder=sharder, + local_size=local_size, + world_size=world_size, + device_type=device_type, + ) + + plan = ShardingPlan(plan={"_module.sparse.ebc": module_plan}) + + sharded_model = shard_qebc( + mi, + sharding_type=ShardingType.COLUMN_WISE, + device=local_device, + expected_shards=expected_shards, + plan=plan, + ) + inputs = [ + model_input_to_forward_args(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + sharded_model.load_state_dict(non_sharded_model.state_dict()) + # torchrec.distributed.test_utils.test_sharding.copy_state_dict(sharded_model.state_dict(), non_sharded_model.state_dict()) does not work for CW due to non-trivial qscaleshift copy which is handled in shardedQEBC load_state_dict + + # We need this first inference to make all lazy init in forward + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(non_sharded_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + device_type_weight_dtype=st.sampled_from( + [ + ("cuda", torch.qint8), + ("cuda", torch.quint4x2), + # ("cpu", torch.qint8), column sharding is currently not supported in CPU inference sharding + # ("cpu", torch.quint4x2), + ] + ), + ) + @settings(max_examples=4, deadline=None) + def test_cw_sequence( + self, device_type_weight_dtype: Tuple[str, torch.dtype] + ) -> None: + device_type, weight_dtype = device_type_weight_dtype + num_embeddings = 4 + emb_dim = 512 + emb_dim_4 = emb_dim // 4 + world_size = 2 + batch_size = 2 + local_device = torch.device(f"{device_type}:0") + + topology: Topology = Topology(world_size=world_size, compute_device=device_type) + mi = TestModelInfo( + dense_device=local_device, + sparse_device=local_device, + num_features=2, + num_float_features=10, + num_weighted_features=0, + topology=topology, + ) + + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + mi.tables = [ + EmbeddingConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(mi.num_features) + ] + + mi.model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + EmbeddingCollection( + tables=mi.tables, + device=mi.sparse_device, + ) + ) + ) + + mi.model.training = False + mi.quant_model = quantize( + mi.model, + inplace=False, + quant_state_dict_split_scale_bias=True, + 
weight_dtype=weight_dtype, + ) + non_sharded_model = mi.quant_model + expected_shards = [ + [ + ( + (0, 0, num_embeddings, emb_dim_4), + placement(device_type, 0, world_size), + ), + ( + (0, 1 * emb_dim_4, num_embeddings, emb_dim_4), + placement(device_type, 1, world_size), + ), + ( + (0, 2 * emb_dim_4, num_embeddings, emb_dim_4), + placement(device_type, 0, world_size), + ), + ( + (0, 3 * emb_dim_4, num_embeddings, emb_dim_4), + placement(device_type, 1, world_size), + ), + ], + ] * 2 + sharded_model = shard_qec( + mi, + sharding_type=ShardingType.COLUMN_WISE, # column wise sharding the model + device=local_device, + expected_shards=expected_shards, + ) + inputs = [ + model_input_to_forward_args_kjt(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + + sharded_model.load_state_dict(non_sharded_model.state_dict()) + + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(sharded_output, non_sharded_output) + + weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model) + assert_weight_spec( + weights_spec, + expected_shards, + "_module_kjt_input.0", + "embeddings", + ["table_0", "table_1"], + ShardingType.COLUMN_WISE.value, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]), + device_type=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=4, deadline=None) + def test_tw_sequence(self, weight_dtype: torch.dtype, device_type: str) -> None: + num_embeddings = 10 + emb_dim = 16 + world_size = 2 + batch_size = 2 + local_device = torch.device(f"{device_type}:0") + + topology: Topology = Topology(world_size=world_size, compute_device=device_type) + mi = TestModelInfo( + dense_device=local_device, + sparse_device=local_device, + num_features=2, + num_float_features=10, + num_weighted_features=0, + topology=topology, + ) + + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + mi.tables = [ + EmbeddingConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(mi.num_features) + ] + + mi.model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + EmbeddingCollection( + tables=mi.tables, + device=mi.sparse_device, + ) + ) + ) + + mi.model.training = False + mi.quant_model = quantize( + mi.model, + inplace=False, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + ) + non_sharded_model = mi.quant_model + expected_shards = [ + [ + ( + (0, 0, num_embeddings, emb_dim), + placement(device_type, 0, world_size), + ), + ], + [ + ( + (0, 0, num_embeddings, emb_dim), + placement(device_type, 1, world_size), + ), + ], + ] + sharded_model = shard_qec( + mi, + sharding_type=ShardingType.TABLE_WISE, + device=local_device, + expected_shards=expected_shards, + ) + + inputs = [ + model_input_to_forward_args_kjt(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + + sharded_model.load_state_dict(non_sharded_model.state_dict()) + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + 
assert_close(non_sharded_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model) + assert_weight_spec( + weights_spec, + expected_shards, + "_module_kjt_input.0", + "embeddings", + ["table_0", "table_1"], + ShardingType.TABLE_WISE.value, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + device_type_weight_dtype=st.sampled_from( + [ + ("cuda", torch.qint8), + ("cuda", torch.quint4x2), + ("cpu", torch.qint8), + ("cpu", torch.quint4x2), + ] + ), + ) + @settings(max_examples=4, deadline=None) + def test_rw_sequence( + self, device_type_weight_dtype: Tuple[str, torch.dtype] + ) -> None: + device_type, weight_dtype = device_type_weight_dtype + num_embeddings = 10 + emb_dim = 16 + world_size = 2 + batch_size = 2 + local_device = torch.device(f"{device_type}:0") + + topology: Topology = Topology(world_size=world_size, compute_device=device_type) + mi = TestModelInfo( + dense_device=local_device, + sparse_device=local_device, + num_features=2, + num_float_features=10, + num_weighted_features=0, + topology=topology, + ) + + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + mi.tables = [ + EmbeddingConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(mi.num_features) + ] + + mi.model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + EmbeddingCollection( + tables=mi.tables, + device=mi.sparse_device, + ) + ) + ) + + mi.model.training = False + mi.quant_model = quantize( + mi.model, + inplace=False, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + ) + non_sharded_model = mi.quant_model + num_emb_half = num_embeddings // 2 + expected_shards = [ + [ + ((0, 0, num_emb_half, emb_dim), placement(device_type, 0, world_size)), + ( + (num_emb_half, 0, num_emb_half, emb_dim), + placement(device_type, 1, world_size), + ), + ], + ] * 2 + sharded_model = shard_qec( + mi, + sharding_type=ShardingType.ROW_WISE, + device=local_device, + expected_shards=expected_shards, + ) + + inputs = [ + model_input_to_forward_args_kjt(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + + sharded_model.load_state_dict(non_sharded_model.state_dict()) + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(non_sharded_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model) + assert_weight_spec( + weights_spec, + expected_shards, + "_module_kjt_input.0", + "embeddings", + ["table_0", "table_1"], + ShardingType.ROW_WISE.value, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]), + 
device=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=4, deadline=None) + def test_rw_sequence_uneven(self, weight_dtype: torch.dtype, device: str) -> None: + num_embeddings = 512 + emb_dim = 64 + world_size = 4 + local_size = 4 + batch_size = 4 + local_device = torch.device("cuda:0" if device == "cuda" else device) + + topology: Topology = Topology(world_size=world_size, compute_device=device) + mi = TestModelInfo( + dense_device=local_device, + sparse_device=local_device, + num_features=4, + num_float_features=10, + num_weighted_features=0, + topology=topology, + ) + + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + mi.tables = [ + EmbeddingConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(mi.num_features) + ] + + mi.model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + EmbeddingCollection( + tables=mi.tables, + device=mi.sparse_device, + ) + ) + ) + + mi.model.training = False + mi.quant_model = quantize( + mi.model, + inplace=False, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + ) + non_sharded_model = mi.quant_model + expected_shards = [ + [ + ( + (0, 0, 256, 64), + placement_helper(device, 0), + ), + ( + (256, 0, 128, 64), + placement_helper(device, 1), + ), + ( + (384, 0, 64, 64), + placement_helper(device, 2), + ), + ( + (448, 0, 64, 64), + placement_helper(device, 3), + ), + ], + [ + ( + (0, 0, 128, 64), + placement_helper(device, 0), + ), + ( + (128, 0, 128, 64), + placement_helper(device, 1), + ), + ( + (256, 0, 128, 64), + placement_helper(device, 2), + ), + ( + (384, 0, 128, 64), + placement_helper(device, 3), + ), + ], + [ + ( + (0, 0, 256, 64), + placement_helper(device, 0), + ), + ( + (256, 0, 128, 64), + placement_helper(device, 1), + ), + ( + (384, 0, 128, 64), + placement_helper(device, 2), + ), + ( + (512, 0, 0, 64), + placement_helper(device, 3), + ), + ], + [ + ( + (0, 0, 0, 64), + placement_helper(device, 0), + ), + ( + (0, 0, 128, 64), + placement_helper(device, 1), + ), + ( + (128, 0, 128, 64), + placement_helper(device, 2), + ), + ( + (256, 0, 256, 64), + placement_helper(device, 3), + ), + ], + ] + sharder = TestQuantECSharder( + sharding_type=ShardingType.ROW_WISE.value, + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in mi.tables], + ) + + module_plan = construct_module_sharding_plan( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS... 
+ non_sharded_model._module_kjt_input[0], + per_param_sharding={ + "table_0": row_wise( + ([256, 128, 64, 64], device), + ), + "table_1": row_wise(([128, 128, 128, 128], device)), + "table_2": row_wise(([256, 128, 128, 0], device)), + "table_3": row_wise(([0, 128, 128, 256], device)), + }, + # pyre-ignore + sharder=sharder, + local_size=local_size, + world_size=world_size, + ) + + plan = ShardingPlan(plan={"_module_kjt_input.0": module_plan}) + + sharded_model = shard_qec( + mi, + sharding_type=ShardingType.ROW_WISE, + device=local_device, + expected_shards=expected_shards, + plan=plan, + ) + + inputs = [ + model_input_to_forward_args_kjt(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + + sharded_model.load_state_dict(non_sharded_model.state_dict()) + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + + assert_close(non_sharded_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... + tbes = get_tbes_from_sharded_module(sharded_model._module_kjt_input[0]) + for tbe in tbes: + self.assertTrue(tbe.weight_initialized) + + path_device_lists = get_path_device_tuples(sharded_model) + + for path_device in path_device_lists: + assert device in path_device[1] + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]), + device_type=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=4, deadline=None) + def test_mix_tw_rw_sequence( + self, weight_dtype: torch.dtype, device_type: str + ) -> None: + num_embeddings = 10 + emb_dim = 16 + world_size = 2 + local_size = 2 + batch_size = 2 + local_device = torch.device(f"{device_type}:0") + + topology: Topology = Topology(world_size=world_size, compute_device=device_type) + mi = TestModelInfo( + dense_device=local_device, + sparse_device=local_device, + num_features=3, + num_float_features=10, + num_weighted_features=0, + topology=topology, + ) + + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + mi.tables = [ + EmbeddingConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(mi.num_features) + ] + + mi.model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + EmbeddingCollection( + tables=mi.tables, + device=mi.sparse_device, + ) + ) + ) + + mi.model.training = False + mi.quant_model = quantize( + mi.model, + inplace=False, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + ) + non_sharded_model = mi.quant_model + + sharder = QuantEmbeddingCollectionSharder() + + module_plan = construct_module_sharding_plan( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS... 
+ non_sharded_model._module_kjt_input[0], + per_param_sharding={ + "table_0": row_wise(), + "table_1": table_wise(rank=0), + "table_2": table_wise(rank=1), + }, + # pyre-ignore + sharder=sharder, + local_size=local_size, + world_size=world_size, + ) + + plan = ShardingPlan(plan={"_module_kjt_input.0": module_plan}) + + sharded_model = shard_qec( + mi, + sharding_type=ShardingType.ROW_WISE, + device=local_device, + plan=plan, + expected_shards=None, + ) + + inputs = [ + model_input_to_forward_args_kjt(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + + sharded_model.load_state_dict(non_sharded_model.state_dict()) + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(non_sharded_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]), + device_type=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=4, deadline=None) + def test_mix_tw_rw_sequence_missing_feature_on_rank( + self, weight_dtype: torch.dtype, device_type: str + ) -> None: + num_embeddings = 10 + emb_dim = 16 + world_size = 2 + local_size = 2 + batch_size = 2 + local_device = torch.device(f"{device_type}:0") + + topology: Topology = Topology(world_size=world_size, compute_device=device_type) + mi = TestModelInfo( + dense_device=local_device, + sparse_device=local_device, + num_features=2, + num_float_features=10, + num_weighted_features=0, + topology=topology, + ) + + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + mi.tables = [ + EmbeddingConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(mi.num_features) + ] + + mi.model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + EmbeddingCollection( + tables=mi.tables, + device=mi.sparse_device, + ) + ) + ) + + mi.model.training = False + mi.quant_model = quantize( + mi.model, + inplace=False, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + ) + non_sharded_model = mi.quant_model + + sharder = QuantEmbeddingCollectionSharder() + + module_plan = construct_module_sharding_plan( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS... 
+ non_sharded_model._module_kjt_input[0], + per_param_sharding={ + "table_0": row_wise(), + "table_1": table_wise(rank=1), + }, + # pyre-ignore + sharder=sharder, + local_size=local_size, + world_size=world_size, + ) + + plan = ShardingPlan(plan={"_module_kjt_input.0": module_plan}) + + sharded_model = shard_qec( + mi, + sharding_type=ShardingType.ROW_WISE, + device=local_device, + plan=plan, + expected_shards=None, + ) + + inputs = [ + model_input_to_forward_args_kjt(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + + sharded_model.load_state_dict(non_sharded_model.state_dict()) + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(non_sharded_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + @unittest.skipIf( + torch.cuda.device_count() <= 2, + "Not enough GPUs available", + ) + # pyre-fixme[56]Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`. + @given( + weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]), + uneven_shard_pattern=st.sampled_from( + [ + (512, 256, 128, 128), + (500, 256, 128, 128), + ] + ), + device=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=4, deadline=None) + def test_rw_uneven_sharding( + self, + weight_dtype: torch.dtype, + uneven_shard_pattern: Tuple[int, int, int, int], + device: str, + ) -> None: + num_embeddings, size0, size1, size2 = uneven_shard_pattern + size2 = min(size2, num_embeddings - size0 - size1) + emb_dim = 64 + local_size = 3 + world_size = 3 + batch_size = 4 + local_device = torch.device("cuda:0" if device == "cuda" else device) + mi = create_test_model( + num_embeddings, + emb_dim, + world_size, + batch_size, + dense_device=local_device, + sparse_device=local_device, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + num_features=1, + ) + + non_sharded_model = mi.quant_model + expected_shards = [ + [ + ( + (0, 0, size0, 64), + placement_helper(device, 0), + ), + ( + (size0, 0, size1, 64), + placement_helper(device, 1), + ), + ( + (size0 + size1, 0, size2, 64), + placement_helper(device, 2), + ), + ], + ] + + sharder = TestQuantEBCSharder( + sharding_type=ShardingType.ROW_WISE.value, + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in mi.tables], + ) + + module_plan = construct_module_sharding_plan( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. 
+ non_sharded_model._module.sparse.ebc, + per_param_sharding={ + "table_0": row_wise(([size0, size1, size2], device)), + }, + # pyre-ignore + sharder=sharder, + local_size=local_size, + world_size=world_size, + ) + + plan = ShardingPlan(plan={"_module.sparse.ebc": module_plan}) + + sharded_model = shard_qebc( + mi=mi, + sharding_type=ShardingType.ROW_WISE, + device=local_device, + expected_shards=expected_shards, + plan=plan, + ) + inputs = [ + model_input_to_forward_args(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + sharded_model.load_state_dict(non_sharded_model.state_dict()) + + # We need this first inference to make all lazy init in forward + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(non_sharded_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs available", + ) + # pyre-fixme[56]Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`. + @given( + weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]), + device=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=4, deadline=None) + def test_rw_uneven_sharding_mutiple_table( + self, + weight_dtype: torch.dtype, + device: str, + ) -> None: + num_embeddings = 512 + emb_dim = 64 + local_size = 4 + world_size = 4 + batch_size = 1 + local_device = torch.device("cuda:0" if device == "cuda" else device) + mi = create_test_model( + num_embeddings, + emb_dim, + world_size, + batch_size, + dense_device=local_device, + sparse_device=local_device, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + num_features=4, + ) + + non_sharded_model = mi.quant_model + expected_shards = [ + [ + ( + (0, 0, 256, 64), + placement_helper(device, 0), + ), + ( + (256, 0, 128, 64), + placement_helper(device, 1), + ), + ( + (384, 0, 64, 64), + placement_helper(device, 2), + ), + ( + (448, 0, 64, 64), + placement_helper(device, 3), + ), + ], + [ + ( + (0, 0, 128, 64), + placement_helper(device, 0), + ), + ( + (128, 0, 128, 64), + placement_helper(device, 1), + ), + ( + (256, 0, 128, 64), + placement_helper(device, 2), + ), + ( + (384, 0, 128, 64), + placement_helper(device, 3), + ), + ], + [ + ( + (0, 0, 256, 64), + placement_helper(device, 0), + ), + ( + (256, 0, 128, 64), + placement_helper(device, 1), + ), + ( + (384, 0, 128, 64), + placement_helper(device, 2), + ), + ( + (512, 0, 0, 64), + placement_helper(device, 3), + ), + ], + [ + ( + (0, 0, 0, 64), + placement_helper(device, 0), + ), + ( + (0, 0, 128, 64), + placement_helper(device, 1), + ), + ( + (128, 0, 128, 64), + placement_helper(device, 2), + ), + ( + (256, 0, 256, 64), + placement_helper(device, 3), + ), + ], + ] + + sharder = TestQuantEBCSharder( + sharding_type=ShardingType.ROW_WISE.value, + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in mi.tables], + ) + + module_plan = construct_module_sharding_plan( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. 
+ non_sharded_model._module.sparse.ebc, + per_param_sharding={ + "table_0": row_wise( + ([256, 128, 64, 64], device), + ), + "table_1": row_wise(([128, 128, 128, 128], device)), + "table_2": row_wise(([256, 128, 128, 0], device)), + "table_3": row_wise(([0, 128, 128, 256], device)), + }, + # pyre-ignore + sharder=sharder, + local_size=local_size, + world_size=world_size, + ) + + plan = ShardingPlan(plan={"_module.sparse.ebc": module_plan}) + + sharded_model = shard_qebc( + mi=mi, + sharding_type=ShardingType.ROW_WISE, + device=local_device, + expected_shards=expected_shards, + plan=plan, + ) + inputs = [ + model_input_to_forward_args(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + sharded_model.load_state_dict(non_sharded_model.state_dict()) + + # We need this first inference to make all lazy init in forward + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(non_sharded_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs available", + ) + # pyre-fixme[56]Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`. + @given( + weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]), + device=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=4, deadline=None) + def test_mix_sharding_mutiple_table( + self, + weight_dtype: torch.dtype, + device: str, + ) -> None: + num_embeddings = 512 + emb_dim = 64 + local_size = 4 + world_size = 4 + batch_size = 1 + local_device = torch.device("cuda:0" if device == "cuda" else device) + mi = create_test_model( + num_embeddings, + emb_dim, + world_size, + batch_size, + dense_device=local_device, + sparse_device=local_device, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + num_features=4, + ) + + non_sharded_model = mi.quant_model + + sharder = QuantEmbeddingBagCollectionSharder() + + module_plan = construct_module_sharding_plan( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. 
+ non_sharded_model._module.sparse.ebc, + per_param_sharding={ + "table_0": row_wise( + ([256, 128, 64, 64], device), + ), + "table_1": row_wise(([128, 128, 128, 128], device)), + "table_2": column_wise(ranks=[0, 1]), + "table_3": table_wise(rank=0), + }, + # pyre-ignore + sharder=sharder, + local_size=local_size, + world_size=world_size, + ) + + plan = ShardingPlan(plan={"_module.sparse.ebc": module_plan}) + + sharded_model = shard_qebc( + mi=mi, + sharding_type=ShardingType.ROW_WISE, + device=local_device, + expected_shards=None, # expected_shards, + plan=plan, + ) + inputs = [ + model_input_to_forward_args(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + sharded_model.load_state_dict(non_sharded_model.state_dict()) + + # We need this first inference to make all lazy init in forward + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(non_sharded_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + weight_dtype=st.sampled_from([torch.qint8]), + device_type=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=1, deadline=None) + def test_sharded_quant_fp_ebc_tw( + self, weight_dtype: torch.dtype, device_type: str + ) -> None: + num_embeddings = 10 + emb_dim = 16 + world_size = 2 + batch_size = 2 + local_device = torch.device(f"{device_type}:0") + topology: Topology = Topology(world_size=world_size, compute_device=device_type) + mi = TestModelInfo( + dense_device=local_device, + sparse_device=local_device, + num_features=2, + num_float_features=10, + num_weighted_features=0, + topology=topology, + ) + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + mi.tables = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(mi.num_features) + ] + + mi.model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + FeatureProcessedEmbeddingBagCollection( + EmbeddingBagCollection( + tables=mi.tables, + is_weighted=True, + device=mi.sparse_device, + ), + TimeGapPoolingCollectionModule( + feature_pow=1.0, + feature_min=-1.0, + feature_max=1.0, + device=mi.sparse_device, + ), + ) + ) + ) + model_inputs: List[ModelInput] = prep_inputs( + mi, world_size, batch_size, long_indices=False + ) + inputs = [] + for model_input in model_inputs: + kjt = model_input.idlist_features + assert isinstance(kjt, KeyedJaggedTensor) + kjt = kjt.to(local_device) + weights = torch.rand( + kjt._values.size(0), dtype=torch.float, device=local_device + ) + inputs.append( + ( + kjt._keys, + kjt._values, + weights, + kjt._lengths, + kjt._offsets, + ) + ) + + mi.model(*inputs[0]) + print(f"model:\n{mi.model}") + + mi.quant_model = quantize_fpebc( + module=mi.model, + inplace=False, + register_tbes=True, + quant_state_dict_split_scale_bias=False, + weight_dtype=weight_dtype, + ) + quant_model = mi.quant_model + print(f"quant_model:\n{quant_model}") + non_sharded_output = 
mi.quant_model(*inputs[0]) + + topology: Topology = Topology(world_size=world_size, compute_device=device_type) + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + sharder = QuantFeatureProcessedEmbeddingBagCollectionSharder() + # pyre-ignore + plan = mi.planner.plan( + mi.quant_model, + [sharder], + ) + + sharded_model = _shard_modules( + module=quant_model, + # pyre-fixme[6]: For 2nd argument expected + # `Optional[List[ModuleSharder[Module]]]` but got + # `List[QuantFeatureProcessedEmbeddingBagCollectionSharder]`. + sharders=[sharder], + device=local_device, + plan=plan, + # pyre-ignore + env=ShardingEnv.from_local(world_size=mi.topology.world_size, rank=0), + ) + print(f"sharded_model:\n{sharded_model}") + for n, m in sharded_model.named_modules(): + print(f"sharded_model.MODULE[{n}]:{type(m)}") + + # Check that FP is registered as module + count_registered_fp: int = 0 + for _, m in sharded_model.named_modules(): + if isinstance(m, TimeGapPoolingCollectionModule): + count_registered_fp += 1 + + assert count_registered_fp == world_size + + sharded_output = sharded_model(*inputs[0]) + assert_close(non_sharded_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace( + sharded_model, + leaf_modules=[ + "TimeGapPoolingCollectionModule", + "IntNBitTableBatchedEmbeddingBagsCodegen", + ], + ) + + # Check that FP was traced as a call_module + fp_call_module: int = 0 + for node in gm.graph.nodes: + if node.op == "call_module": + m = gm + for attr in node.target.split("."): + m = getattr(m, attr) + if isinstance(m, TimeGapPoolingCollectionModule): + fp_call_module += 1 + + assert fp_call_module == world_size + print(f"fx.graph:\n{gm.graph}") + + gm_script = torch.jit.script(gm) + print(f"gm_script:\n{gm_script}") + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + weight_dtype=st.sampled_from([torch.qint8]), + device_type=st.sampled_from(["cpu", "cuda"]), + ) + @settings(max_examples=2, deadline=None) + def test_sharded_quant_mc_ec_rw( + self, weight_dtype: torch.dtype, device_type: str + ) -> None: + num_embeddings = 10 + emb_dim = 16 + world_size = 2 + batch_size = 2 + local_device = torch.device(f"{device_type}:0") + + topology: Topology = Topology(world_size=world_size, compute_device=device_type) + mi = TestModelInfo( + dense_device=local_device, + sparse_device=local_device, + num_features=1, + num_float_features=10, + num_weighted_features=0, + topology=topology, + ) + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + mi.tables = [ + EmbeddingConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(mi.num_features) + ] + + mi.model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + ManagedCollisionEmbeddingCollection( + EmbeddingCollection( + tables=mi.tables, + device=mi.sparse_device, + ), + ManagedCollisionCollection( 
+ managed_collision_modules={ + "table_0": MCHManagedCollisionModule( + zch_size=num_embeddings, + input_hash_size=4000, + device=mi.sparse_device, + eviction_interval=2, + eviction_policy=DistanceLFU_EvictionPolicy(), + ) + }, + embedding_configs=mi.tables, + ), + ) + ) + ) + model_inputs: List[ModelInput] = prep_inputs( + mi, world_size, batch_size, long_indices=True + ) + inputs = [] + for model_input in model_inputs: + kjt = model_input.idlist_features + assert isinstance(kjt, KeyedJaggedTensor) + kjt = kjt.to(local_device) + weights = None + inputs.append( + ( + kjt._keys, + kjt._values, + weights, + kjt._lengths, + kjt._offsets, + ) + ) + + mi.model(*inputs[0]) + print(f"model:\n{mi.model}") + assert mi.model.training is True + mi.quant_model = quantize( + module=mi.model, + inplace=False, + register_tbes=False, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + ) + quant_model = mi.quant_model + assert quant_model.training is False + non_sharded_output = mi.quant_model(*inputs[0]) + + topology: Topology = Topology(world_size=world_size, compute_device=device_type) + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + sharder = QuantEmbeddingCollectionSharder() + # pyre-ignore + plan = mi.planner.plan( + mi.quant_model, + [sharder], + ) + + sharded_model = shard_qec( + mi, + sharding_type=ShardingType.ROW_WISE, + device=local_device, + expected_shards=None, + plan=plan, + ) + + print(f"sharded_model:\n{sharded_model}") + for n, m in sharded_model.named_modules(): + print(f"sharded_model.MODULE[{n}]:{type(m)}") + + sharded_model.load_state_dict(quant_model.state_dict()) + sharded_output = sharded_model(*inputs[0]) + + assert_close(non_sharded_output[0], sharded_output[0]) + gm: torch.fx.GraphModule = symbolic_trace( + sharded_model, + leaf_modules=[ + "IntNBitTableBatchedEmbeddingBagsCodegen", + "ComputeJTDictToKJT", + ], + ) + + print(f"fx.graph:\n{gm.graph}") + gm_script = torch.jit.script(gm) + print(f"gm_script:\n{gm_script}") + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output[0], gm_script_output[0]) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + compute_device=st.sampled_from(["cuda", "cpu"]), + ) + @settings(max_examples=2, deadline=None) + def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None: + # Simulate inference, take unsharded cpu model and shard on meta + # Use PositionWeightedModuleCollection, FP used in production + + num_embeddings = 10 + emb_dim = 16 + world_size = 2 + batch_size = 2 + local_device = torch.device(compute_device) + + mi = TestModelInfo( + dense_device=local_device, + sparse_device=local_device, + num_features=2, + num_float_features=10, + num_weighted_features=0, + ) + + mi.tables = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(mi.num_features) + ] + + max_feature_lengths = {"feature_0": 20} + + mi.model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + FeatureProcessedEmbeddingBagCollection( + EmbeddingBagCollection( + tables=mi.tables, + is_weighted=True, + device=mi.sparse_device, + ), + PositionWeightedModuleCollection( + 
max_feature_lengths=max_feature_lengths, + device=mi.sparse_device, + ), + ) + ) + ) + model_inputs: List[ModelInput] = prep_inputs( + mi, world_size, batch_size, long_indices=False, count=1 + ) + inputs = [] + kjt = model_inputs[0].idlist_features + assert isinstance(kjt, KeyedJaggedTensor) + kjt = kjt.to(local_device) + weights = torch.rand( + kjt._values.size(0), dtype=torch.float, device=local_device + ) + + inputs = [ + kjt._keys, + kjt._values, + weights, + kjt._lengths, + kjt._offsets, + ] + + mi.model(*inputs) + print(f"model:\n{mi.model}") + + mi.quant_model = quantize_fpebc( + module=mi.model, + inplace=False, + register_tbes=True, + quant_state_dict_split_scale_bias=False, + weight_dtype=torch.int8, + ) + quant_model = mi.quant_model + print(f"quant_model:\n{quant_model}") + + topology: Topology = Topology( + world_size=world_size, compute_device=compute_device + ) + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + sharder = QuantFeatureProcessedEmbeddingBagCollectionSharder() + # pyre-ignore + plan = mi.planner.plan( + mi.quant_model, + [sharder], + ) + + sharded_model = _shard_modules( + module=quant_model, + # pyre-fixme[6]: For 2nd argument expected + # `Optional[List[ModuleSharder[Module]]]` but got + # `List[QuantFeatureProcessedEmbeddingBagCollectionSharder]`. + sharders=[sharder], + # shard on meta to simulate device movement from cpu -> meta QFPEBC + device=torch.device("meta"), + plan=plan, + env=ShardingEnv.from_local(world_size=topology.world_size, rank=0), + ) + print(f"sharded_model:\n{sharded_model}") + for n, m in sharded_model.named_modules(): + print(f"sharded_model.MODULE[{n}]:{type(m)}") + + # Check that FP is registered as module + count_registered_fp: int = 0 + for _, m in sharded_model.named_modules(): + if isinstance(m, PositionWeightedModuleCollection): + count_registered_fp += 1 + + assert count_registered_fp == world_size + + # Move inputs to meta now that we shard on meta + for i, input in enumerate(inputs): + if isinstance(input, torch.Tensor): + inputs[i] = input.to(torch.device("meta")) + + # move dense params also to meta + sharded_model.to("meta") + sharded_model(*inputs) + # Don't care about the output since we are sharding on meta + + gm: torch.fx.GraphModule = symbolic_trace( + sharded_model, + leaf_modules=[ + "PositionWeightedModuleCollection", + "IntNBitTableBatchedEmbeddingBagsCodegen", + ], + ) + + # Check that FP was traced as a call_module + fp_call_module: int = 0 + for node in gm.graph.nodes: + if node.op == "call_module": + m = gm + for attr in node.target.split("."): + m = getattr(m, attr) + if isinstance(m, PositionWeightedModuleCollection): + fp_call_module += 1 + + assert fp_call_module == world_size + print(f"fx.graph:\n{gm.graph}") + + gm_script = torch.jit.script(gm) + print(f"gm_script:\n{gm_script}") + gm_script(*inputs) diff --git a/torchrec/distributed/tests/test_infer_utils.py b/torchrec/distributed/tests/test_infer_utils.py new file mode 100644 index 000000000..0fb374878 --- /dev/null +++ b/torchrec/distributed/tests/test_infer_utils.py @@ -0,0 +1,382 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +import unittest +from typing import cast + +import torch + +from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ModuleSharder +from torchrec.distributed.infer_utils import ( + get_all_torchrec_modules, + get_tbe_specs_from_sharded_module, +) +from torchrec.distributed.quant_embeddingbag import ( + QuantEmbeddingBagCollection, + ShardedQuantEmbeddingBagCollection, +) +from torchrec.distributed.shard import _shard_modules +from torchrec.distributed.sharding_plan import ( + construct_module_sharding_plan, + table_wise, +) +from torchrec.distributed.test_utils.infer_utils import ( + quantize, + TestModelInfo, + TestQuantEBCSharder, + TestQuantECSharder, + TorchTypesModelInputWrapper, +) +from torchrec.distributed.test_utils.test_model import TestSparseNN +from torchrec.distributed.types import ShardingEnv, ShardingPlan, ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) + + +class UtilsTest(unittest.TestCase): + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + def test_get_tbe_specs_from_sqebc(self) -> None: + device = torch.device("cuda:0") + + num_features = 3 + + tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 20, + embedding_dim=(i + 1) * 10, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + + model = torch.nn.Sequential( + EmbeddingBagCollection( + tables=tables, + device=device, + ) + ) + model.training = False + + quant_model = quantize( + model, + inplace=True, + output_type=torch.float, + quant_state_dict_split_scale_bias=True, + ) + + sharder = TestQuantEBCSharder( + sharding_type=ShardingType.TABLE_WISE.value, + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[f"table_{i}" for i in range(num_features)], + ) + + module_plan = construct_module_sharding_plan( + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. + quant_model[0], + per_param_sharding={ + "table_0": table_wise(rank=1), + "table_1": table_wise(rank=0), + "table_2": table_wise(rank=0), + }, + # pyre-ignore + sharder=sharder, + local_size=2, + world_size=2, + ) + + plan = ShardingPlan(plan={"": module_plan}) + + sharded_model = _shard_modules( + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. + module=quant_model[0], + # pyre-fixme[6]: For 2nd argument expected + # `Optional[List[ModuleSharder[Module]]]` but got + # `List[TestQuantEBCSharder]`. 
+            sharders=[sharder],
+            device=device,
+            plan=plan,
+            env=ShardingEnv.from_local(world_size=2, rank=0),
+        )
+
+        specs = get_tbe_specs_from_sharded_module(sharded_model)
+
+        expected_specs = [
+            ("table_1", 40, 20, "int8", "EmbeddingLocation.DEVICE"),
+            ("table_2", 60, 30, "int8", "EmbeddingLocation.DEVICE"),
+            ("table_0", 20, 10, "int8", "EmbeddingLocation.DEVICE"),
+        ]
+
+        self.assertEqual(specs, expected_specs)
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs available",
+    )
+    def test_get_tbe_specs_from_sqec(self) -> None:
+        device = torch.device("cuda:0")
+
+        num_features = 3
+
+        tables = [
+            EmbeddingConfig(
+                num_embeddings=(i + 1) * 20,
+                embedding_dim=10,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(num_features)
+        ]
+
+        model = torch.nn.Sequential(
+            EmbeddingCollection(
+                tables=tables,
+                device=device,
+            )
+        )
+        model.training = False
+
+        quant_model = quantize(
+            model,
+            inplace=True,
+            output_type=torch.float,
+            quant_state_dict_split_scale_bias=True,
+        )
+
+        sharder = TestQuantECSharder(
+            sharding_type=ShardingType.TABLE_WISE.value,
+            kernel_type=EmbeddingComputeKernel.QUANT.value,
+            shardable_params=[f"table_{i}" for i in range(num_features)],
+        )
+
+        module_plan = construct_module_sharding_plan(
+            # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
+            quant_model[0],
+            per_param_sharding={
+                "table_0": table_wise(rank=1),
+                "table_1": table_wise(rank=0),
+                "table_2": table_wise(rank=0),
+            },
+            # pyre-ignore
+            sharder=sharder,
+            local_size=2,
+            world_size=2,
+        )
+
+        plan = ShardingPlan(plan={"": module_plan})
+
+        sharded_model = _shard_modules(
+            # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
+            module=quant_model[0],
+            # pyre-fixme[6]: For 2nd argument expected
+            # `Optional[List[ModuleSharder[Module]]]` but got
+            # `List[TestQuantECSharder]`.
+            sharders=[sharder],
+            device=device,
+            plan=plan,
+            env=ShardingEnv.from_local(world_size=2, rank=0),
+        )
+
+        specs = get_tbe_specs_from_sharded_module(sharded_model)
+
+        expected_specs = [
+            ("table_1", 40, 10, "int8", "EmbeddingLocation.DEVICE"),
+            ("table_2", 60, 10, "int8", "EmbeddingLocation.DEVICE"),
+            ("table_0", 20, 10, "int8", "EmbeddingLocation.DEVICE"),
+        ]
+
+        self.assertEqual(specs, expected_specs)
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs available",
+    )
+    def test_get_all_torchrec_modules_for_single_module(self) -> None:
+        device = torch.device("cuda:0")
+
+        num_features = 2
+
+        tables = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 20,
+                embedding_dim=(i + 1) * 10,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(num_features)
+        ]
+
+        model = torch.nn.Sequential(
+            EmbeddingBagCollection(
+                tables=tables,
+                device=device,
+            )
+        )
+        model.training = False
+
+        # Smoke-test retrieval on the unsharded and quantized models; the
+        # sharded case below is the one asserted against.
+        all_trec_modules = get_all_torchrec_modules(model)
+
+        quant_model = quantize(
+            model,
+            inplace=True,
+            output_type=torch.float,
+            quant_state_dict_split_scale_bias=True,
+        )
+
+        all_trec_modules = get_all_torchrec_modules(quant_model)
+
+        sharder = TestQuantEBCSharder(
+            sharding_type=ShardingType.TABLE_WISE.value,
+            kernel_type=EmbeddingComputeKernel.QUANT.value,
+            shardable_params=[f"table_{i}" for i in range(num_features)],
+        )
+
+        module_plan = construct_module_sharding_plan(
+            # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
+            quant_model[0],
+            per_param_sharding={
+                "table_0": table_wise(rank=0),
+                "table_1": table_wise(rank=1),
+            },
+            # pyre-ignore
+            sharder=sharder,
+            local_size=2,
+            world_size=2,
+        )
+
+        plan = ShardingPlan(plan={"": module_plan})
+
+        sharded_model = _shard_modules(
+            # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
+            module=quant_model[0],
+            # pyre-fixme[6]: For 2nd argument expected
+            # `Optional[List[ModuleSharder[Module]]]` but got
+            # `List[TestQuantEBCSharder]`.
+            sharders=[sharder],
+            device=device,
+            plan=plan,
+            env=ShardingEnv.from_local(world_size=2, rank=0),
+        )
+
+        all_trec_modules = get_all_torchrec_modules(sharded_model)
+        self.assertDictEqual(all_trec_modules, {"": sharded_model})
+
+        all_trec_modules = get_all_torchrec_modules(
+            sharded_model, [QuantEmbeddingBagCollection]
+        )
+        self.assertEqual(all_trec_modules, {})
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs available",
+    )
+    def test_get_all_torchrec_modules_for_test_SparseNN_model(self) -> None:
+        local_device = torch.device("cuda:0")
+        model_info = TestModelInfo(
+            sparse_device=local_device,
+            dense_device=local_device,
+            num_features=2,
+            num_float_features=10,
+            num_weighted_features=2,
+        )
+
+        model_info.tables = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 10,
+                embedding_dim=512,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(model_info.num_features)
+        ]
+        model_info.weighted_tables = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 10,
+                embedding_dim=(i + 1) * 4,
+                name="weighted_table_" + str(i),
+                feature_names=["weighted_feature_" + str(i)],
+            )
+            for i in range(model_info.num_weighted_features)
+        ]
+        model_info.model = TorchTypesModelInputWrapper(
+            TestSparseNN(
+                tables=model_info.tables,
+                weighted_tables=model_info.weighted_tables,
+                num_float_features=model_info.num_float_features,
+                dense_device=model_info.dense_device,
+                sparse_device=model_info.sparse_device,
+            )
+        )
+
+        model_info.model.training = False
+        model_info.quant_model = quantize(
+            model_info.model,
+            inplace=True,
+            quant_state_dict_split_scale_bias=True,
+        )
+
+        model_info.sharders = [
+            cast(
+                ModuleSharder[torch.nn.Module],
+                TestQuantEBCSharder(
+                    sharding_type=ShardingType.TABLE_WISE.value,
+                    kernel_type=EmbeddingComputeKernel.QUANT.value,
+                    shardable_params=[table.name for table in model_info.tables],
+                ),
+            ),
+        ]
+
+        sharded_model = _shard_modules(
+            module=model_info.quant_model,
+            sharders=model_info.sharders,
+            device=model_info.sparse_device,
+            env=ShardingEnv.from_local(world_size=2, rank=0),
+        )
+
+        all_trec_modules = get_all_torchrec_modules(sharded_model)
+
+        expected_all_trec_modules = {
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+            # `sparse`.
+            "_module.sparse.ebc": sharded_model._module.sparse.ebc,
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+            # `sparse`.
+            "_module.sparse.weighted_ebc": sharded_model._module.sparse.weighted_ebc,
+        }
+
+        self.assertDictEqual(
+            all_trec_modules,
+            expected_all_trec_modules,
+        )
+
+        all_trec_modules = get_all_torchrec_modules(
+            sharded_model, [ShardedQuantEmbeddingBagCollection]
+        )
+
+        self.assertDictEqual(
+            all_trec_modules,
+            {
+                # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+                # attribute `sparse`.
+ "_module.sparse.ebc": sharded_model._module.sparse.ebc, + }, + ) diff --git a/torchrec/distributed/tests/test_init_parameters.py b/torchrec/distributed/tests/test_init_parameters.py new file mode 100644 index 000000000..515c45a36 --- /dev/null +++ b/torchrec/distributed/tests/test_init_parameters.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import cast, List, Optional, Union + +import torch +from hypothesis import given, settings, strategies as st, Verbosity +from torch import nn +from torch.distributed._tensor import DTensor +from torchrec.distributed.embedding import EmbeddingCollectionSharder +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.sharding_plan import ( + column_wise, + construct_module_sharding_plan, + data_parallel, + ParameterShardingGenerator, + row_wise, + table_wise, +) +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.types import ( + ModuleSharder, + ShardedTensor, + ShardingEnv, + ShardingPlan, + ShardingType, +) +from torchrec.modules.embedding_configs import ( + DataType, + EmbeddingBagConfig, + EmbeddingConfig, +) +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) +from torchrec.test_utils import skip_if_asan_class + + +def initialize_and_test_parameters( + rank: int, + world_size: int, + backend: str, + embedding_tables: Union[EmbeddingCollection, EmbeddingBagCollection], + sharding_type: str, + sharders: List[ModuleSharder[nn.Module]], + table_name: str, + local_size: Optional[int] = None, +) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + + module_sharding_plan = construct_module_sharding_plan( + embedding_tables, + per_param_sharding={ + table_name: _select_sharding_type(sharding_type), + }, + local_size=ctx.local_size, + world_size=ctx.world_size, + device_type=ctx.device.type, + ) + + model = DistributedModelParallel( + module=embedding_tables, + plan=ShardingPlan({"": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. 
+ env=ShardingEnv.from_process_group(ctx.pg), + sharders=sharders, + device=ctx.device, + ) + + key = ( + f"embeddings.{table_name}.weight" + if isinstance(embedding_tables, EmbeddingCollection) + else f"embedding_bags.{table_name}.weight" + ) + + if isinstance(model.state_dict()[key], DTensor): + if ctx.rank == 0: + gathered_tensor = torch.empty(model.state_dict()[key].size()) + else: + gathered_tensor = None + + gathered_tensor = model.state_dict()[key].full_tensor() + if ctx.rank == 0: + torch.testing.assert_close( + gathered_tensor, + embedding_tables.state_dict()[key], + ) + elif isinstance(model.state_dict()[key], ShardedTensor): + if ctx.rank == 0: + gathered_tensor = torch.empty_like(embedding_tables.state_dict()[key]) + else: + gathered_tensor = None + + model.state_dict()[key].gather(dst=0, out=gathered_tensor) + + if ctx.rank == 0: + torch.testing.assert_close( + gathered_tensor, + embedding_tables.state_dict()[key], + ) + elif isinstance(model.state_dict()[key], torch.Tensor): + torch.testing.assert_close( + embedding_tables.state_dict()[key].cpu(), + model.state_dict()[key].cpu(), + ) + else: + raise AssertionError( + f"Model state dict contains unsupported type for key: {key}" + ) + + +def _select_sharding_type(sharding_type: str) -> ParameterShardingGenerator: + if sharding_type == "table_wise": + return table_wise(rank=0) + elif sharding_type == "column_wise": + return column_wise(ranks=[0, 1]) + elif sharding_type == "row_wise": + return row_wise() + elif sharding_type == "data_parallel": + return data_parallel() + else: + raise AssertionError(f"Invalid sharding type specified: {sharding_type}") + + +@skip_if_asan_class +class ParameterInitializationTest(MultiProcessTestBase): + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.DATA_PARALLEL.value, + ShardingType.ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.TABLE_WISE.value, + ] + ) + ) + @settings(verbosity=Verbosity.verbose, deadline=None) + def test_initialize_parameters_ec(self, sharding_type: str) -> None: + world_size = 2 + backend = "nccl" + table_name = "free_parameters" + + # Initialize embedding table on non-meta device, in this case cuda:0 + embedding_tables = EmbeddingCollection( + device=torch.device("cuda:0"), + tables=[ + EmbeddingConfig( + name=table_name, + embedding_dim=64, + num_embeddings=10, + data_type=DataType.FP32, + ) + ], + ) + + embedding_tables.load_state_dict( + {f"embeddings.{table_name}.weight": torch.randn(10, 64)} + ) + + self._run_multi_process_test( + callable=initialize_and_test_parameters, + embedding_tables=embedding_tables, + sharding_type=sharding_type, + sharders=[ + cast(ModuleSharder[torch.nn.Module], EmbeddingCollectionSharder()) + ], + world_size=world_size, + backend=backend, + table_name=table_name, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.DATA_PARALLEL.value, + ShardingType.ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.TABLE_WISE.value, + ] + ) + ) + @settings(verbosity=Verbosity.verbose, deadline=None) + def test_initialize_parameters_ebc(self, sharding_type: str) -> None: + world_size = 2 + backend = "nccl" + table_name = "free_parameters" + + # Initialize embedding bag on non-meta device, in this case cuda:0 + embedding_tables = 
EmbeddingBagCollection( + device=torch.device("cuda:0"), + tables=[ + EmbeddingBagConfig( + name=table_name, + embedding_dim=64, + num_embeddings=10, + data_type=DataType.FP32, + ) + ], + ) + + embedding_tables.load_state_dict( + {f"embedding_bags.{table_name}.weight": torch.randn(10, 64)} + ) + + self._run_multi_process_test( + callable=initialize_and_test_parameters, + embedding_tables=embedding_tables, + sharding_type=sharding_type, + sharders=[ + cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()) + ], + world_size=world_size, + backend=backend, + table_name=table_name, + ) diff --git a/torchrec/distributed/tests/test_keyed_jagged_tensor_pool.py b/torchrec/distributed/tests/test_keyed_jagged_tensor_pool.py new file mode 100644 index 000000000..d5a41c76d --- /dev/null +++ b/torchrec/distributed/tests/test_keyed_jagged_tensor_pool.py @@ -0,0 +1,835 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +from typing import cast, Dict, List + +import torch +from hypothesis import given, settings, strategies as st +from torchrec.distributed.keyed_jagged_tensor_pool import ( + KeyedJaggedTensorPoolSharder, + ShardedInferenceKeyedJaggedTensorPool, + ShardedKeyedJaggedTensorPool, +) +from torchrec.distributed.shard import _shard_modules + +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.types import ( + ModuleSharder, + ObjectPoolShardingPlan, + ObjectPoolShardingType, + ShardingEnv, + ShardingPlan, +) +from torchrec.modules.keyed_jagged_tensor_pool import KeyedJaggedTensorPool +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class TestShardedKeyedJaggedTensorPool(MultiProcessTestBase): + @staticmethod + def _test_sharded_keyed_jagged_tensor_pool( + rank: int, + world_size: int, + backend: str, + pool_size: int, + feature_max_lengths: Dict[str, int], + values_dtype: torch.dtype, + is_weighted: bool, + sharding_plan: ObjectPoolShardingPlan, + input_per_rank: List[torch.Tensor], + enable_uvm: bool = False, + ) -> None: + with MultiProcessContext( + rank, world_size, backend, local_size=world_size + ) as ctx: + input_per_rank = [id.to(ctx.device) for id in input_per_rank] + keyed_jagged_tensor_pool = KeyedJaggedTensorPool( + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + values_dtype=values_dtype, + is_weighted=is_weighted, + device=torch.device("meta"), + enable_uvm=enable_uvm, + ) + + # pyre-ignore + sharded_keyed_jagged_tensor_pool: ( + ShardedKeyedJaggedTensorPool + ) = KeyedJaggedTensorPoolSharder().shard( + keyed_jagged_tensor_pool, + plan=sharding_plan, + device=ctx.device, + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but + # got `Optional[ProcessGroup]`. 
+ env=ShardingEnv.from_process_group(ctx.pg), + ) + + # rank 0 + # 0 1 + # "f1" [1] [3,3] + # "f2" [11] [13,13,13] + + # rank 1 + # 0 1 + # "f1" [2,2] [4] + # "f2" [12,12] [14,14,14,14] + values = KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.tensor( + ( + [1, 3, 3, 11, 13, 13, 13] + if ctx.rank == 0 + else [2, 2, 4, 12, 12, 14, 14, 14, 14] + ), + dtype=values_dtype, + device=ctx.device, + ), + lengths=torch.tensor( + [1, 2, 1, 3] if ctx.rank == 0 else [2, 1, 2, 4], + dtype=torch.int, + device=ctx.device, + ), + ) + + sharded_keyed_jagged_tensor_pool.update( + ids=torch.tensor( + [2, 0] if ctx.rank == 0 else [1, 3], + dtype=torch.int, + device=ctx.device, + ), + values=values, + ) + + # init global state is + # 4 8 + # f1 f2 + # [3,3] . [13,13,13] + # [2,2] . [12,12] + # [1] . [11] + # [4] [14,14,14,14] + + kjt = sharded_keyed_jagged_tensor_pool.lookup(input_per_rank[ctx.rank]) + + # expected values + # rank 0: KeyedJaggedTensor({ + # "f1": [[1], [3, 3]], + # "f2": [[11], [13, 13, 13]] + # }) + + # rank 1: KeyedJaggedTensor({ + # "f1": [[2, 2], [4], [3, 3], [1]], + # "f2": [[12, 12], [14, 14, 14, 14], [13, 13, 13], [11]] + # }) + + torch.testing.assert_close( + kjt.values().cpu(), + torch.tensor( + ( + [1, 3, 3, 11, 13, 13, 13] + if ctx.rank == 0 + else [2, 2, 4, 3, 3, 1, 12, 12, 14, 14, 14, 14, 13, 13, 13, 11] + ), + dtype=values_dtype, + device=torch.device("cpu"), + ), + ) + + torch.testing.assert_close( + kjt.lengths().cpu(), + torch.tensor( + [1, 2, 1, 3] if ctx.rank == 0 else [2, 1, 2, 1, 2, 4, 3, 1], + dtype=torch.int, + device=torch.device("cpu"), + ), + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-ignore + @given( + enable_uvm=st.booleans(), + values_dtype=st.sampled_from([torch.int32, torch.int64]), + ) + @settings(max_examples=4, deadline=None) + def test_sharded_keyed_jagged_tensor_pool_rw( + self, enable_uvm: bool, values_dtype: torch.dtype + ) -> None: + input_per_rank = [ + torch.tensor([2, 0], dtype=torch.int), + torch.tensor([1, 3, 0, 2], dtype=torch.int), + ] + + pool_size, feature_max_lengths = 4, {"f1": 2, "f2": 4} + + self._run_multi_process_test( + callable=self._test_sharded_keyed_jagged_tensor_pool, + world_size=2, + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + values_dtype=values_dtype, + is_weighted=False, + input_per_rank=input_per_rank, + sharding_plan=ObjectPoolShardingPlan( + sharding_type=ObjectPoolShardingType.ROW_WISE + ), + backend="nccl", + enable_uvm=enable_uvm, + ) + + @staticmethod + def _test_input_permute( + rank: int, + world_size: int, + backend: str, + pool_size: int, + feature_max_lengths: Dict[str, int], + values_dtype: torch.dtype, + is_weighted: bool, + sharding_plan: ObjectPoolShardingPlan, + input_per_rank: List[torch.Tensor], + ) -> None: + with MultiProcessContext( + rank, world_size, backend, local_size=world_size + ) as ctx: + input_per_rank = [id.to(ctx.device) for id in input_per_rank] + keyed_jagged_tensor_pool = KeyedJaggedTensorPool( + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + values_dtype=values_dtype, + is_weighted=is_weighted, + device=torch.device("meta"), + ) + + # pyre-ignore + sharded_keyed_jagged_tensor_pool: ( + ShardedKeyedJaggedTensorPool + ) = KeyedJaggedTensorPoolSharder().shard( + keyed_jagged_tensor_pool, + plan=sharding_plan, + device=ctx.device, + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but + # got `Optional[ProcessGroup]`. 
+ env=ShardingEnv.from_process_group(ctx.pg), + ) + + sharded_keyed_jagged_tensor_pool.update( + ids=torch.tensor( + [2, 0] if ctx.rank == 0 else [1, 3], + dtype=torch.int, + device=ctx.device, + ), + values=KeyedJaggedTensor.from_lengths_sync( + keys=["f3", "f2", "f1"], + values=torch.tensor( + ( + [21, 11, 13, 13, 13, 1, 3, 3] + if ctx.rank == 0 + else [22, 22, 24, 12, 12, 14, 14, 14, 14, 2, 2, 4] + ), + dtype=values_dtype, + device=ctx.device, + ), + lengths=torch.tensor( + [1, 0, 1, 3, 1, 2] if ctx.rank == 0 else [2, 1, 2, 4, 2, 1], + dtype=torch.int, + device=ctx.device, + ), + ), + ) + + # init global state is + # 4 8 + # f1 f2 + # [3,3] . [13,13,13] + # [2,2] . [12,12] + # [1] . [11] + # [4] [14,14,14,14] + + kjt = sharded_keyed_jagged_tensor_pool(input_per_rank[ctx.rank]) + + # expected values + # rank 0: KeyedJaggedTensor({ + # "f1": [[1], [3, 3]], + # "f2": [[11], [13, 13, 13]] + # }) + + # rank 1: KeyedJaggedTensor({ + # "f1": [[2, 2], [4], [3, 3], [1]], + # "f2": [[12, 12], [14, 14, 14, 14], [13, 13, 13], [11]] + # }) + + torch.testing.assert_close( + kjt.values().cpu(), + torch.tensor( + ( + [1, 3, 3, 11, 13, 13, 13] + if ctx.rank == 0 + else [2, 2, 4, 3, 3, 1, 12, 12, 14, 14, 14, 14, 13, 13, 13, 11] + ), + dtype=values_dtype, + device=torch.device("cpu"), + ), + ) + + torch.testing.assert_close( + kjt.lengths().cpu(), + torch.tensor( + [1, 2, 1, 3] if ctx.rank == 0 else [2, 1, 2, 1, 2, 4, 3, 1], + dtype=torch.int, + device=torch.device("cpu"), + ), + ) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs, this test requires at least four GPUs", + ) + def test_input_permute( + self, + ) -> None: + input_per_rank = [ + torch.tensor([2, 0], dtype=torch.int), + torch.tensor([1, 3, 0, 2], dtype=torch.int), + ] + + pool_size, feature_max_lengths = 4, {"f1": 2, "f2": 4} + + self._run_multi_process_test( + callable=self._test_input_permute, + world_size=2, + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + values_dtype=torch.int64, + is_weighted=False, + input_per_rank=input_per_rank, + sharding_plan=ObjectPoolShardingPlan( + sharding_type=ObjectPoolShardingType.ROW_WISE + ), + backend="nccl", + ) + + @staticmethod + def _test_sharded_KJT_pool_input_conflict( + rank: int, + world_size: int, + backend: str, + pool_size: int, + feature_max_lengths: Dict[str, int], + values_dtype: torch.dtype, + is_weighted: bool, + sharding_plan: ObjectPoolShardingPlan, + input_per_rank: List[torch.Tensor], + ) -> None: + with MultiProcessContext( + rank, world_size, backend, local_size=world_size + ) as ctx: + input_per_rank = [id.to(ctx.device) for id in input_per_rank] + keyed_jagged_tensor_pool = KeyedJaggedTensorPool( + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + values_dtype=values_dtype, + is_weighted=is_weighted, + device=torch.device("meta"), + ) + + # pyre-ignore + sharded_keyed_jagged_tensor_pool: ( + ShardedKeyedJaggedTensorPool + ) = KeyedJaggedTensorPoolSharder().shard( + keyed_jagged_tensor_pool, + plan=sharding_plan, + device=ctx.device, + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but + # got `Optional[ProcessGroup]`. 
+ env=ShardingEnv.from_process_group(ctx.pg), + ) + + # rank 0 input: + # ids f1 f2 + # 2 1 11 + # 1 3, 3 13, 13, 13 + + # rank 1 input: + # ids f1 f2 + # 1 2, 2 12, 12 + # 3 4 14, 14, 14, 14 + + sharded_keyed_jagged_tensor_pool.update( + ids=torch.tensor( + [2, 1] if ctx.rank == 0 else [1, 3], + dtype=torch.int, + device=ctx.device, + ), + values=KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.tensor( + ( + [1, 3, 3, 11, 13, 13, 13] + if ctx.rank == 0 + else [2, 2, 4, 12, 12, 14, 14, 14, 14] + ), + dtype=values_dtype, + device=ctx.device, + ), + lengths=torch.tensor( + [1, 2, 1, 3] if ctx.rank == 0 else [2, 1, 2, 4], + dtype=torch.int, + device=ctx.device, + ), + ), + ) + + kjt = sharded_keyed_jagged_tensor_pool(input_per_rank[ctx.rank]) + # expected values + # rank 0: KeyedJaggedTensor({ + # "f1": [[1], [3, 3]], + # "f2": [[11], [13, 13, 13]] + # }) + + # rank 1: KeyedJaggedTensor({ + # "f1": [[2, 2], [4], [3, 3], [1]], + # "f2": [[12, 12], [14, 14, 14, 14], [13, 13, 13], [11]] + # }) + + torch.testing.assert_close( + kjt.values().cpu(), + torch.tensor( + ( + [1, 11] + if ctx.rank == 0 + else [2, 2, 4, 1, 12, 12, 14, 14, 14, 14, 11] + ), + dtype=values_dtype, + device=torch.device("cpu"), + ), + ) + + torch.testing.assert_close( + kjt.lengths().cpu(), + torch.tensor( + [1, 0, 1, 0] if ctx.rank == 0 else [2, 1, 0, 1, 2, 4, 0, 1], + dtype=torch.int, + device=torch.device("cpu"), + ), + ) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs, this test requires at least four GPUs", + ) + def test_sharded_KJT_pool_input_conflict( + self, + ) -> None: + input_per_rank = [ + torch.tensor([2, 0], dtype=torch.int), + torch.tensor([1, 3, 0, 2], dtype=torch.int), + ] + + pool_size, feature_max_lengths = 4, {"f1": 2, "f2": 4} + + self._run_multi_process_test( + callable=self._test_sharded_KJT_pool_input_conflict, + world_size=2, + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + values_dtype=torch.int64, + is_weighted=False, + input_per_rank=input_per_rank, + sharding_plan=ObjectPoolShardingPlan( + sharding_type=ObjectPoolShardingType.ROW_WISE + ), + backend="nccl", + ) + + @staticmethod + def _test_sharded_KJT_pool_input_empty( + rank: int, + world_size: int, + backend: str, + pool_size: int, + feature_max_lengths: Dict[str, int], + values_dtype: torch.dtype, + is_weighted: bool, + sharding_plan: ObjectPoolShardingPlan, + input_per_rank: List[torch.Tensor], + ) -> None: + with MultiProcessContext( + rank, world_size, backend, local_size=world_size + ) as ctx: + input_per_rank = [id.to(ctx.device) for id in input_per_rank] + keyed_jagged_tensor_pool = KeyedJaggedTensorPool( + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + values_dtype=values_dtype, + is_weighted=is_weighted, + device=torch.device("meta"), + ) + + # pyre-ignore + sharded_keyed_jagged_tensor_pool: ( + ShardedKeyedJaggedTensorPool + ) = KeyedJaggedTensorPoolSharder().shard( + keyed_jagged_tensor_pool, + plan=sharding_plan, + device=ctx.device, + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but + # got `Optional[ProcessGroup]`. 
+ env=ShardingEnv.from_process_group(ctx.pg), + ) + + # rank 0 input: + # ids f1 f2 + # 2 1 11 + # 1 3, 3 13, 13, 13 + + # rank 1 input: + # ids f1 f2 + # 1 2, 2 12, 12 + # 3 4 14, 14, 14, 14 + + sharded_keyed_jagged_tensor_pool.update( + ids=torch.tensor( + [2, 1] if ctx.rank == 0 else [1, 3], + dtype=torch.int, + device=ctx.device, + ), + values=KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.tensor( + ( + [1, 3, 3, 11, 13, 13, 13] + if ctx.rank == 0 + else [2, 2, 4, 12, 12, 14, 14, 14, 14] + ), + dtype=values_dtype, + device=ctx.device, + ), + lengths=torch.tensor( + [1, 2, 1, 3] if ctx.rank == 0 else [2, 1, 2, 4], + dtype=torch.int, + device=ctx.device, + ), + ), + ) + + kjt = sharded_keyed_jagged_tensor_pool(input_per_rank[ctx.rank]) + # expected values + # rank 0: KeyedJaggedTensor({ + # "f1": [[1], [3, 3]], + # "f2": [[11], [13, 13, 13]] + # }) + + # rank 1: KeyedJaggedTensor({ + # "f1": [[2, 2], [4], [3, 3], [1]], + # "f2": [[12, 12], [14, 14, 14, 14], [13, 13, 13], [11]] + # }) + + torch.testing.assert_close( + kjt.values().cpu(), + torch.tensor( + [1, 11] if ctx.rank == 0 else [], + dtype=values_dtype, + device=torch.device("cpu"), + ), + ) + + torch.testing.assert_close( + kjt.lengths().cpu(), + torch.tensor( + [1, 0, 1, 0] if ctx.rank == 0 else [], + dtype=torch.int, + device=torch.device("cpu"), + ), + ) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs, this test requires at least four GPUs", + ) + def test_sharded_KJT_pool_input_empty(self) -> None: + input_per_rank = [ + torch.tensor([2, 0], dtype=torch.int), + torch.tensor([], dtype=torch.int), + ] + pool_size, feature_max_lengths = 4, {"f1": 2, "f2": 4} + self._run_multi_process_test( + callable=self._test_sharded_KJT_pool_input_empty, + world_size=2, + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + values_dtype=torch.int64, + is_weighted=False, + input_per_rank=input_per_rank, + sharding_plan=ObjectPoolShardingPlan( + sharding_type=ObjectPoolShardingType.ROW_WISE + ), + backend="nccl", + ) + + @staticmethod + def _test_sharded_keyed_jagged_tensor_pool_replicated_rw( + rank: int, + world_size: int, + local_world_size: int, + backend: str, + pool_size: int, + feature_max_lengths: Dict[str, int], + values_dtype: torch.dtype, + is_weighted: bool, + sharding_plan: ObjectPoolShardingPlan, + ) -> None: + with MultiProcessContext( + rank, world_size, backend, local_size=local_world_size + ) as ctx: + keyed_jagged_tensor_pool = KeyedJaggedTensorPool( + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + values_dtype=values_dtype, + is_weighted=is_weighted, + device=torch.device("meta"), + ) + + # pyre-ignore + sharded_keyed_jagged_tensor_pool: ( + ShardedKeyedJaggedTensorPool + ) = KeyedJaggedTensorPoolSharder().shard( + keyed_jagged_tensor_pool, + plan=sharding_plan, + device=ctx.device, + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but + # got `Optional[ProcessGroup]`. + env=ShardingEnv.from_process_group(ctx.pg), + ) + + # init global state is + # 4 8 + # f1 f2 + # [3,3] . [13,13,13] + # [2,2] . [12,12] + # [1] . 
[11] + # [4] [14,14,14,14] + + ids = [[1], [0], [2], [3]] + + values_and_lengths = [ + ([2, 2, 12, 12], [2, 2]), + ([3, 3, 13, 13, 13], [2, 3]), + ([1, 11], [1, 1]), + ([4, 14, 14, 14, 14], [1, 4]), + ] + + sharded_keyed_jagged_tensor_pool.update( + ids=torch.tensor( + ids[ctx.rank], + dtype=torch.int, + device=ctx.device, + ), + values=KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.tensor( + values_and_lengths[ctx.rank][0], + dtype=values_dtype, + device=ctx.device, + ), + lengths=torch.tensor( + values_and_lengths[ctx.rank][1], + dtype=torch.int, + device=ctx.device, + ), + ), + ) + + lookup_per_rank = [[0, 1, 2, 3], [0, 2], [3, 1], [0]] + + kjt = sharded_keyed_jagged_tensor_pool.lookup( + torch.tensor( + lookup_per_rank[ctx.rank], device=ctx.device, dtype=torch.int32 + ) + ).wait() + + # expected values + # rank 0: + # kjt KeyedJaggedTensor({ + # "f1": [[3, 3], [2, 2], [1], [4]], + # "f2": [[13, 13, 13], [12, 12], [11], [14, 14, 14, 14]] + # }) + # rank 1: + # kjt KeyedJaggedTensor({ + # "f1": [[3, 3], [1]], + # "f2": [[13, 13, 13], [11]] + # }) + # rank 2: + # kjt KeyedJaggedTensor({ + # "f1": [[4], [2, 2]], + # "f2": [[14, 14, 14, 14], [12, 12]] + # }) + # rank 3: + # kjt KeyedJaggedTensor({ + # "f1": [[3, 3]], + # "f2": [[13, 13, 13]] + # }) + + expected_values_and_lengths = [ + ( + [3, 3, 2, 2, 1, 4, 13, 13, 13, 12, 12, 11, 14, 14, 14, 14], + [2, 2, 1, 1, 3, 2, 1, 4], + ), + ([3, 3, 1, 13, 13, 13, 11], [2, 1, 3, 1]), + ([4, 2, 2, 14, 14, 14, 14, 12, 12], [1, 2, 4, 2]), + ([3, 3, 13, 13, 13], [2, 3]), + ] + + torch.testing.assert_close( + kjt.values(), + torch.tensor( + expected_values_and_lengths[ctx.rank][0], + dtype=kjt.values().dtype, + device=kjt.values().device, + ), + ) + + torch.testing.assert_close( + kjt.lengths(), + torch.tensor( + expected_values_and_lengths[ctx.rank][1], + dtype=kjt.lengths().dtype, + device=kjt.lengths().device, + ), + ) + + assert list(sharded_keyed_jagged_tensor_pool.state_dict().keys()) == [ + "values", + "key_lengths", + ] + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs, this test requires at least four GPUs", + ) + def test_sharded_keyed_jagged_tensor_pool_replicated_rw( + self, + ) -> None: + pool_size, feature_max_lengths = 4, {"f1": 2, "f2": 4} + + self._run_multi_process_test( + callable=self._test_sharded_keyed_jagged_tensor_pool_replicated_rw, + world_size=4, + local_world_size=4, + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + values_dtype=torch.int64, + is_weighted=False, + sharding_plan=ObjectPoolShardingPlan( + sharding_type=ObjectPoolShardingType.REPLICATED_ROW_WISE + ), + backend="nccl", + ) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs, this test requires at least four GPUs", + ) + def test_sharded_kjt_pool_inference(self) -> None: + world_size = 2 + pool_size = 4 + device = torch.device("cpu") + cuda_device = torch.device("cuda:0") + + # init global state is + # 4 8 + # f1 f2 + # [3,3] . [13,13,13] + # [2,2] . [12,12] + # [1] . 
[11] + # [4] [14,14,14,14] + + kjt_pool_orig = KeyedJaggedTensorPool( + pool_size=pool_size, + feature_max_lengths={"f1": 2, "f2": 4}, + values_dtype=torch.int, + is_weighted=False, + device=torch.device("cpu"), + ) + kjt_pool_orig.update( + ids=torch.tensor([0, 1, 2, 3], dtype=torch.int, device=device), + values=KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.tensor( + [3, 3, 2, 2, 1, 4, 13, 13, 13, 12, 12, 11, 14, 14, 14, 14], + dtype=torch.int, + device=torch.device("cpu"), + ), + lengths=torch.tensor( + [2, 2, 1, 1, 3, 2, 1, 4], + dtype=torch.int, + device=torch.device("cpu"), + ), + ), + ) + + sharded_inference_kjt_pool = _shard_modules( + kjt_pool_orig, + plan=ShardingPlan( + plan={ + "": ObjectPoolShardingPlan( + ObjectPoolShardingType.ROW_WISE, inference=True + ), + } + ), + device=cuda_device, + env=ShardingEnv.from_local(world_size=world_size, rank=0), + sharders=[ + cast(ModuleSharder[torch.nn.Module], KeyedJaggedTensorPoolSharder()) + ], + ) + self.assertIsInstance( + sharded_inference_kjt_pool, ShardedInferenceKeyedJaggedTensorPool + ) + + self.assertEqual(sharded_inference_kjt_pool.dtype, torch.int) + + from torchrec.fx.tracer import symbolic_trace + + sharded_inference_kjt_pool_gm: torch.fx.GraphModule = symbolic_trace( + sharded_inference_kjt_pool + ) + sharded_inference_kjt_pool_gm_script = torch.jit.script( + sharded_inference_kjt_pool_gm + ) # noqa + + input_cases = [[0, 1, 2, 3], [0, 2, 1, 3]] + for input in input_cases: + input = torch.tensor(input, dtype=torch.int64) + ref = kjt_pool_orig.lookup(input) + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. + val = sharded_inference_kjt_pool.lookup(input.to(cuda_device)) + + torch.testing.assert_close(ref.values().cpu(), val.values().cpu()) + torch.testing.assert_close(ref.length_per_key(), val.length_per_key()) + + val_gm_script = sharded_inference_kjt_pool_gm_script(input.to(cuda_device)) + torch.testing.assert_close(ref.values().cpu(), val_gm_script.values().cpu()) + torch.testing.assert_close( + ref.length_per_key(), val_gm_script.length_per_key() + ) + + assert hasattr(sharded_inference_kjt_pool_gm_script, "_local_kjt_pool_shards") + assert hasattr(sharded_inference_kjt_pool_gm_script._local_kjt_pool_shards, "0") + assert hasattr(sharded_inference_kjt_pool_gm_script._local_kjt_pool_shards, "1") diff --git a/torchrec/distributed/tests/test_lazy_awaitable.py b/torchrec/distributed/tests/test_lazy_awaitable.py index 1b176fa4a..2237e4887 100644 --- a/torchrec/distributed/tests/test_lazy_awaitable.py +++ b/torchrec/distributed/tests/test_lazy_awaitable.py @@ -5,12 +5,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import unittest from typing import Dict import torch import torch.fx -from torchrec.distributed.types import LazyAwaitable +from torchrec.distributed.types import LazyAwaitable, LazyGetItemMixin class NeedWait(LazyAwaitable[torch.Tensor]): @@ -252,3 +254,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.assertTrue(torch.equal(ref_res, 17 * torch.ones(3, 4))) tempFile.close() + + def test_lazy_getitem_mixin(self) -> None: + class LazyGetItemAwaitable( + LazyGetItemMixin[str, torch.Tensor], LazyAwaitable[Dict[str, torch.Tensor]] + ): + def __init__(self, actual_value: Dict[str, torch.Tensor]): + super().__init__() + self.actual_value = actual_value + + def _wait_impl(self) -> Dict[str, torch.Tensor]: + for v in self.actual_value.values(): + v *= 3 + return self.actual_value + + actual_value = {"foo": torch.tensor(1), "bar": torch.tensor(2)} + a = LazyGetItemAwaitable(actual_value) + lazy_foo = a["foo"] + lazy_bar = a["bar"] + # The returned value should be lazy + self.assertIsInstance(lazy_foo, LazyAwaitable) + self.assertIsInstance(lazy_bar, LazyAwaitable) + + # Our lazy values should not have been waited yet + self.assertIsNone(lazy_foo._result) + self.assertIsNone(lazy_bar._result) + self.assertIsNone(a._result) + + # The use of a torch op should trigger exactly one wait on the parent object. + result = torch.add(lazy_foo, lazy_bar) + self.assertEqual(result, torch.tensor(1 * 3 + 2 * 3)) diff --git a/torchrec/distributed/tests/test_mc_embedding.py b/torchrec/distributed/tests/test_mc_embedding.py new file mode 100644 index 000000000..64c3ca14e --- /dev/null +++ b/torchrec/distributed/tests/test_mc_embedding.py @@ -0,0 +1,1283 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import copy +import unittest +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from hypothesis import given, settings, strategies as st +from torchrec.distributed.embedding import ShardedEmbeddingCollection +from torchrec.distributed.mc_embedding import ( + ManagedCollisionEmbeddingCollectionSharder, + ShardedManagedCollisionEmbeddingCollection, +) +from torchrec.distributed.mc_modules import ShardedManagedCollisionCollection +from torchrec.distributed.shard import _shard_modules + +from torchrec.distributed.sharding_plan import ( + construct_module_sharding_plan, + EmbeddingCollectionSharder, + row_wise, +) + +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.types import ( + ModuleSharder, + ShardedTensor, + ShardingEnv, + ShardingPlan, +) +from torchrec.modules.embedding_configs import EmbeddingConfig +from torchrec.modules.embedding_modules import EmbeddingCollection +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection +from torchrec.modules.mc_modules import ( + DistanceLFU_EvictionPolicy, + ManagedCollisionCollection, + MCHManagedCollisionModule, +) +from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward +from torchrec.optim.rowwise_adagrad import RowWiseAdagrad +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from torchrec.test_utils import skip_if_asan_class + + +class SparseArch(nn.Module): + def __init__( + self, + tables: List[EmbeddingConfig], + device: torch.device, + return_remapped: bool = False, + input_hash_size: int = 4000, + allow_in_place_embed_weight_update: bool = False, + ) -> None: + super().__init__() + self._return_remapped = return_remapped + + mc_modules = {} + mc_modules["table_0"] = MCHManagedCollisionModule( + zch_size=(tables[0].num_embeddings), + input_hash_size=input_hash_size, + device=device, + eviction_interval=2, + eviction_policy=DistanceLFU_EvictionPolicy(), + ) + + mc_modules["table_1"] = MCHManagedCollisionModule( + zch_size=(tables[1].num_embeddings), + device=device, + input_hash_size=input_hash_size, + eviction_interval=2, + eviction_policy=DistanceLFU_EvictionPolicy(), + ) + + self._mc_ec: ManagedCollisionEmbeddingCollection = ( + ManagedCollisionEmbeddingCollection( + EmbeddingCollection( + tables=tables, + device=device, + ), + ManagedCollisionCollection( + managed_collision_modules=mc_modules, + embedding_configs=tables, + ), + return_remapped_features=self._return_remapped, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, + ) + ) + + def forward( + self, kjt: KeyedJaggedTensor + ) -> Tuple[torch.Tensor, Optional[Dict[str, JaggedTensor]]]: + ec_out, remapped_ids_out = self._mc_ec(kjt) + pred = torch.cat( + [ec_out[key].values() for key in ["feature_0", "feature_1"]], + dim=0, + ) + loss = pred.mean() + return loss, remapped_ids_out + + +def _test_sharding_and_remapping( # noqa C901 + output_keys: List[str], + tables: List[EmbeddingConfig], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]], + initial_state_per_rank: List[Dict[str, torch.Tensor]], + final_state_per_rank: List[Dict[str, torch.Tensor]], + sharder: ModuleSharder[nn.Module], + backend: str, + local_size: Optional[int] = None, + input_hash_size: int = 4000, +) -> None: + + with MultiProcessContext(rank, world_size, backend, local_size) as 
ctx: + kjt_input = kjt_input_per_rank[rank].to(ctx.device) + kjt_out_per_iter = [ + kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank + ] + return_remapped: bool = True + sparse_arch = SparseArch( + tables, + torch.device("meta"), + return_remapped=return_remapped, + input_hash_size=input_hash_size, + ) + + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + sparse_arch._mc_ec._embedding_collection.embeddings["table_0"].weight, + sparse_arch._mc_ec._embedding_collection.embeddings["table_1"].weight, + ], + {"lr": 0.01}, + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._mc_ec, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ec": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + + assert isinstance( + sharded_sparse_arch._mc_ec, ShardedManagedCollisionEmbeddingCollection + ) + assert isinstance( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_embedding_collection`. + sharded_sparse_arch._mc_ec._embedding_collection, + ShardedEmbeddingCollection, + ) + assert ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_embedding_collection`. + sharded_sparse_arch._mc_ec._embedding_collection._has_uninitialized_input_dist + is False + ) + assert ( + not hasattr( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `_embedding_collection`. + sharded_sparse_arch._mc_ec._embedding_collection, + "_input_dists", + ) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_embedding_collection`. + or len(sharded_sparse_arch._mc_ec._embedding_collection._input_dists) == 0 + ) + + assert isinstance( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_managed_collision_collection`. + sharded_sparse_arch._mc_ec._managed_collision_collection, + ShardedManagedCollisionCollection, + ) + + assert ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_managed_collision_collection`. + sharded_sparse_arch._mc_ec._managed_collision_collection._use_index_dedup + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_embedding_collection`. 
+ == sharded_sparse_arch._mc_ec._embedding_collection._use_index_dedup + ) + + initial_state_dict = sharded_sparse_arch.state_dict() + for key, sharded_tensor in initial_state_dict.items(): + postfix = ".".join(key.split(".")[-2:]) + if postfix in initial_state_per_rank[ctx.rank]: + tensor = sharded_tensor.local_shards()[0].tensor.cpu() + assert torch.equal( + tensor, initial_state_per_rank[ctx.rank][postfix] + ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {initial_state_per_rank[rank][postfix]}" + + sharded_sparse_arch.load_state_dict(initial_state_dict) + + # sharded model + # each rank gets a subbatch + loss1, remapped_ids1 = sharded_sparse_arch(kjt_input) + loss1.backward() + loss2, remapped_ids2 = sharded_sparse_arch(kjt_input) + loss2.backward() + + final_state_dict = sharded_sparse_arch.state_dict() + for key, sharded_tensor in final_state_dict.items(): + postfix = ".".join(key.split(".")[-2:]) + if postfix in final_state_per_rank[ctx.rank]: + tensor = sharded_tensor.local_shards()[0].tensor.cpu() + assert torch.equal( + tensor, final_state_per_rank[ctx.rank][postfix] + ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {final_state_per_rank[rank][postfix]}" + + remapped_ids = [remapped_ids1, remapped_ids2] + for key in output_keys: + for i, kjt_out in enumerate(kjt_out_per_iter): + assert torch.equal( + remapped_ids[i][key].values(), + kjt_out[key].values(), + ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}" + + # TODO: validate embedding rows, and eviction + + +def _test_in_place_embd_weight_update( # noqa C901 + output_keys: List[str], + tables: List[EmbeddingConfig], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]], + initial_state_per_rank: List[Dict[str, torch.Tensor]], + final_state_per_rank: List[Dict[str, torch.Tensor]], + sharder: ModuleSharder[nn.Module], + backend: str, + local_size: Optional[int] = None, + input_hash_size: int = 4000, + allow_in_place_embed_weight_update: bool = True, +) -> None: + + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + kjt_input = kjt_input_per_rank[rank].to(ctx.device) + kjt_out_per_iter = [ + kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank + ] + return_remapped: bool = True + sparse_arch = SparseArch( + tables, + torch.device("meta"), + return_remapped=return_remapped, + input_hash_size=input_hash_size, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, + ) + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + sparse_arch._mc_ec._embedding_collection.embeddings["table_0"].weight, + sparse_arch._mc_ec._embedding_collection.embeddings["table_1"].weight, + ], + {"lr": 0.01}, + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._mc_ec, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ec": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. 
+ env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + + initial_state_dict = sharded_sparse_arch.state_dict() + for key, sharded_tensor in initial_state_dict.items(): + postfix = ".".join(key.split(".")[-2:]) + if postfix in initial_state_per_rank[ctx.rank]: + tensor = sharded_tensor.local_shards()[0].tensor.cpu() + assert torch.equal( + tensor, initial_state_per_rank[ctx.rank][postfix] + ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {initial_state_per_rank[rank][postfix]}" + + sharded_sparse_arch.load_state_dict(initial_state_dict) + + # sharded model + # each rank gets a subbatch + loss1, remapped_ids1 = sharded_sparse_arch(kjt_input) + loss2, remapped_ids2 = sharded_sparse_arch(kjt_input) + + if not allow_in_place_embed_weight_update: + # Without in-place overwrite the backward pass will fail due to tensor version mismatch + with unittest.TestCase().assertRaisesRegex( + RuntimeError, + "one of the variables needed for gradient computation has been modified by an inplace operation", + ): + loss1.backward() + else: + loss1.backward() + loss2.backward() + final_state_dict = sharded_sparse_arch.state_dict() + for key, sharded_tensor in final_state_dict.items(): + postfix = ".".join(key.split(".")[-2:]) + if postfix in final_state_per_rank[ctx.rank]: + tensor = sharded_tensor.local_shards()[0].tensor.cpu() + assert torch.equal( + tensor, final_state_per_rank[ctx.rank][postfix] + ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {final_state_per_rank[rank][postfix]}" + + remapped_ids = [remapped_ids1, remapped_ids2] + for key in output_keys: + for i, kjt_out in enumerate(kjt_out_per_iter): + assert torch.equal( + remapped_ids[i][key].values(), + kjt_out[key].values(), + ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}" + + +def _test_sharding_and_resharding( # noqa C901 + tables: List[EmbeddingConfig], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]], + initial_state_per_rank: List[Dict[str, torch.Tensor]], + final_state_per_rank: List[Dict[str, torch.Tensor]], + sharder: ModuleSharder[nn.Module], + backend: str, + local_size: Optional[int] = None, +) -> None: + + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + + kjt_input = kjt_input_per_rank[rank].to(ctx.device) + kjt_out_per_iter = [ + kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank + ] + return_remapped: bool = True + sparse_arch = SparseArch( + tables, + torch.device("meta"), + return_remapped=return_remapped, + ) + + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + sparse_arch._mc_ec._embedding_collection.embeddings["table_0"].weight, + sparse_arch._mc_ec._embedding_collection.embeddings["table_1"].weight, + ], + {"lr": 0.01}, + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._mc_ec, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ec": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. 
+ env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + + assert isinstance( + sharded_sparse_arch._mc_ec, ShardedManagedCollisionEmbeddingCollection + ) + assert isinstance( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_embedding_collection`. + sharded_sparse_arch._mc_ec._embedding_collection, + ShardedEmbeddingCollection, + ) + assert ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_embedding_collection`. + sharded_sparse_arch._mc_ec._embedding_collection._has_uninitialized_input_dist + is False + ) + assert ( + not hasattr( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `_embedding_collection`. + sharded_sparse_arch._mc_ec._embedding_collection, + "_input_dists", + ) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_embedding_collection`. + or len(sharded_sparse_arch._mc_ec._embedding_collection._input_dists) == 0 + ) + + assert isinstance( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_managed_collision_collection`. + sharded_sparse_arch._mc_ec._managed_collision_collection, + ShardedManagedCollisionCollection, + ) + # sharded model + # each rank gets a subbatch + loss1, remapped_ids1 = sharded_sparse_arch(kjt_input) + loss1.backward() + loss2, remapped_ids2 = sharded_sparse_arch(kjt_input) + loss2.backward() + remapped_ids = [remapped_ids1, remapped_ids2] + for key in kjt_input.keys(): + for i, kjt_out in enumerate(kjt_out_per_iter[:2]): # first two iterations + assert torch.equal( + remapped_ids[i][key].values(), + kjt_out[key].values(), + ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}" + + state_dict = sharded_sparse_arch.state_dict() + cpu_state_dict = {} + for key, tensor in state_dict.items(): + if isinstance(tensor, ShardedTensor): + tensor = tensor.local_shards()[0].tensor + cpu_state_dict[key] = tensor.to("cpu") + gather_list = [None, None] if ctx.rank == 0 else None + torch.distributed.gather_object(cpu_state_dict, gather_list) + + if rank == 0: + with MultiProcessContext(rank, 1, backend, 1) as ctx: + kjt_input = kjt_input_per_rank[rank].to(ctx.device) + sparse_arch = SparseArch( + tables, + torch.device("meta"), + return_remapped=return_remapped, + ) + + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + sparse_arch._mc_ec._embedding_collection.embeddings[ + "table_0" + ].weight, + sparse_arch._mc_ec._embedding_collection.embeddings[ + "table_1" + ].weight, + ], + {"lr": 0.01}, + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._mc_ec, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=1, + world_size=1, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ec": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. 
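The resharding helper gathers each rank's CPU shards to rank 0 and then, in the world-size-1 re-run, rebuilds every row-wise table by concatenating the two gathered shards along dim 0 before `load_state_dict`. A toy version of that reconstruction (tensor values and key name are made up for illustration):

```python
import torch

# Two row-wise shards of a 16-row table, as rank 0 and rank 1 would hold them.
full_table = torch.arange(32.0).reshape(16, 2)
gather_list = [
    {"table_0.weight": full_table[:8].clone()},   # gathered from rank 0
    {"table_0.weight": full_table[8:].clone()},   # gathered from rank 1
]

# Same concatenation the rank-0 branch performs for each ShardedTensor entry.
replacement = torch.cat(
    [gather_list[0]["table_0.weight"], gather_list[1]["table_0.weight"]], dim=0
)
assert torch.equal(replacement, full_table)
```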
+ env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + state_dict = sharded_sparse_arch.state_dict() + + for key in state_dict.keys(): + if isinstance(state_dict[key], ShardedTensor): + replacement_tensor = torch.cat( + # pyre-ignore [16] + [gather_list[0][key], gather_list[1][key]], + dim=0, + ).to(ctx.device) + state_dict[key].local_shards()[0].tensor.copy_(replacement_tensor) + else: + state_dict[key] = gather_list[0][key].to(ctx.device) + + sharded_sparse_arch.load_state_dict(state_dict) + loss3, remapped_ids3 = sharded_sparse_arch(kjt_input) + final_state_dict = sharded_sparse_arch.state_dict() + for key, sharded_tensor in final_state_dict.items(): + postfix = ".".join(key.split(".")[-2:]) + if postfix in final_state_per_rank[ctx.rank]: + tensor = sharded_tensor.local_shards()[0].tensor.cpu() + assert torch.equal( + tensor, final_state_per_rank[ctx.rank][postfix] + ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {final_state_per_rank[rank][postfix]}" + + remapped_ids = [remapped_ids3] + for key in kjt_input.keys(): + for i, kjt_out in enumerate(kjt_out_per_iter[-1:]): # last iteration + assert torch.equal( + remapped_ids[i][key].values(), + kjt_out[key].values(), + ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}" + + +def _test_sharding_dedup( # noqa C901 + tables: List[EmbeddingConfig], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + sharder: ModuleSharder[nn.Module], + dedup_sharder: ModuleSharder[nn.Module], + backend: str, + local_size: Optional[int] = None, + input_hash_size: int = 4000, +) -> None: + + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + return_remapped: bool = True + kjt_input = kjt_input_per_rank[rank].to(ctx.device) + sparse_arch = SparseArch( + tables, + torch.device("meta"), + return_remapped=return_remapped, + input_hash_size=input_hash_size, + ) + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + sparse_arch._mc_ec._embedding_collection.embeddings["table_0"].weight, + sparse_arch._mc_ec._embedding_collection.embeddings["table_1"].weight, + ], + {"lr": 0.01}, + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._mc_ec, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ec": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + dedup_sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ec": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + env=ShardingEnv.from_process_group(ctx.pg), + sharders=[dedup_sharder], + device=ctx.device, + ) + + assert ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_managed_collision_collection`. + sharded_sparse_arch._mc_ec._managed_collision_collection._use_index_dedup + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_embedding_collection`. 
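`use_index_dedup=True` on the `EmbeddingCollectionSharder` makes the sharded collection look up each distinct id once and expand the result back afterwards, which is why the dedup test only needs `torch.allclose(loss1, dedup_loss1)` to hold. A rough sketch of the idea with `torch.unique` (not the sharder's actual code path):

```python
import torch

ids = torch.tensor([7, 3, 7, 7, 3, 1])
table = torch.randn(16, 8)

# Look up each distinct id once, then scatter the rows back in input order.
unique_ids, inverse = torch.unique(ids, return_inverse=True)
restored = table[unique_ids][inverse]

assert torch.equal(restored, table[ids])  # same result as the non-dedup lookup
```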
+ == sharded_sparse_arch._mc_ec._embedding_collection._use_index_dedup + ) + + assert ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_managed_collision_collection`. + sharded_sparse_arch._mc_ec._managed_collision_collection._use_index_dedup + is False + ) + + assert ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_managed_collision_collection`. + dedup_sharded_sparse_arch._mc_ec._managed_collision_collection._use_index_dedup + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_embedding_collection`. + == dedup_sharded_sparse_arch._mc_ec._embedding_collection._use_index_dedup + ) + + assert ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_managed_collision_collection`. + dedup_sharded_sparse_arch._mc_ec._managed_collision_collection._use_index_dedup + is True + ) + + # sync state_dict() + state_dict = sharded_sparse_arch.state_dict() + dedup_state_dict = dedup_sharded_sparse_arch.state_dict() + for key, sharded_tensor in state_dict.items(): + if isinstance(sharded_tensor, ShardedTensor): + dedup_state_dict[key].local_shards()[ + 0 + ].tensor = sharded_tensor.local_shards()[0].tensor.clone() + dedup_state_dict[key] = sharded_tensor.clone() + dedup_sharded_sparse_arch.load_state_dict(dedup_state_dict) + + loss1, remapped_1 = sharded_sparse_arch(kjt_input) + loss1.backward() + dedup_loss1, dedup_remapped_1 = dedup_sharded_sparse_arch(kjt_input) + dedup_loss1.backward() + + assert torch.allclose(loss1, dedup_loss1) + # deduping is not being used right now + # assert torch.allclose(remapped_1.values(), dedup_remapped_1.values()) + # assert torch.allclose(remapped_1.lengths(), dedup_remapped_1.lengths()) + + +@skip_if_asan_class +class ShardedMCEmbeddingCollectionParallelTest(MultiProcessTestBase): + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given(backend=st.sampled_from(["nccl"])) + @settings(deadline=None) + def test_sharding_zch_mc_ec_reshard(self, backend: str) -> None: + + WORLD_SIZE = 2 + + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=16, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=32, + ), + ] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [1000, 2000, 1001, 2000, 2001, 2002], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [ + 1000, + 1002, + 1004, + 2000, + 2002, + 2004, + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = [] + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 15, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 7, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + ) + # TODO: cleanup sorting so more dedugable/logical initial fill + + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", 
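The per-rank inputs above are built with `KeyedJaggedTensor.from_lengths_sync`, whose flat `values` are laid out key-major: with two keys and six unit-length bags, the first three ids belong to `feature_0` and the last three to `feature_1`. A small check of that layout, using the rank-0 fixture values:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0", "feature_1"],
    values=torch.LongTensor([1000, 2000, 1001, 2000, 2001, 2002]),
    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
)
print(kjt["feature_0"].values())  # tensor([1000, 2000, 1001])
print(kjt["feature_1"].values())  # tensor([2000, 2001, 2002])
```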
"feature_1"], + values=torch.LongTensor( + [3, 14, 4, 27, 29, 28], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 5, 6, 27, 28, 30], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + ) + + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 14, 4, 27, 29, 28], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.empty(), + ] + ) + + max_int = torch.iinfo(torch.int64).max + + final_state_per_rank = [ + { + "table_0._mch_sorted_raw_ids": torch.LongTensor( + [1000, 1001, 1002, 1004, 2000] + [max_int] * (16 - 5) + ), + "table_1._mch_sorted_raw_ids": torch.LongTensor( + [2000, 2001, 2002, 2004] + [max_int] * (32 - 4) + ), + "table_0._mch_remapped_ids_mapping": torch.LongTensor( + [3, 4, 5, 6, 14, 0, 1, 2, 7, 8, 9, 10, 11, 12, 13, 15], + ), + "table_1._mch_remapped_ids_mapping": torch.LongTensor( + [ + 27, + 29, + 28, + 30, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 31, + ], + ), + }, + ] + + self._run_multi_process_test( + callable=_test_sharding_and_resharding, + world_size=WORLD_SIZE, + tables=embedding_config, + kjt_input_per_rank=kjt_input_per_rank, + kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank, + initial_state_per_rank=None, + final_state_per_rank=final_state_per_rank, + sharder=ManagedCollisionEmbeddingCollectionSharder(), + backend=backend, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given(backend=st.sampled_from(["nccl"])) + @settings(deadline=None) + def test_sharding_zch_mc_ec_remap(self, backend: str) -> None: + + WORLD_SIZE = 2 + + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=16, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=32, + ), + ] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [ + 1000, + 1002, + 1004, + 2000, + 2002, + 2004, + 2, + 2, + 2, + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = [] + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 15, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 7, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + ) + # TODO: cleanup sorting so more dedugable/logical initial fill + + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 14, 4, 27, 29, 28], + ), + 
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 5, 6, 27, 28, 30], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + ) + + initial_state_per_rank = [ + { + "table_0._mch_remapped_ids_mapping": torch.arange(8, dtype=torch.int64), + "table_1._mch_remapped_ids_mapping": torch.arange( + 16, dtype=torch.int64 + ), + }, + { + "table_0._mch_remapped_ids_mapping": torch.arange( + start=8, end=16, dtype=torch.int64 + ), + "table_1._mch_remapped_ids_mapping": torch.arange( + start=16, end=32, dtype=torch.int64 + ), + }, + ] + max_int = torch.iinfo(torch.int64).max + + final_state_per_rank = [ + { + "table_0._mch_sorted_raw_ids": torch.LongTensor( + [1000, 1001, 1002, 1004] + [max_int] * 4 + ), + "table_1._mch_sorted_raw_ids": torch.LongTensor([max_int] * 16), + "table_0._mch_remapped_ids_mapping": torch.LongTensor( + [3, 4, 5, 6, 0, 1, 2, 7] + ), + "table_1._mch_remapped_ids_mapping": torch.arange( + 16, dtype=torch.int64 + ), + }, + { + "table_0._mch_sorted_raw_ids": torch.LongTensor([2000] + [max_int] * 7), + "table_1._mch_sorted_raw_ids": torch.LongTensor( + [2000, 2001, 2002, 2004] + [max_int] * 12 + ), + "table_0._mch_remapped_ids_mapping": torch.LongTensor( + [14, 8, 9, 10, 11, 12, 13, 15] + ), + "table_1._mch_remapped_ids_mapping": torch.LongTensor( + [27, 29, 28, 30, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 31] + ), + }, + ] + + self._run_multi_process_test( + callable=_test_sharding_and_remapping, + output_keys=["feature_0", "feature_1"], + world_size=WORLD_SIZE, + tables=embedding_config, + kjt_input_per_rank=kjt_input_per_rank, + kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank, + initial_state_per_rank=initial_state_per_rank, + final_state_per_rank=final_state_per_rank, + sharder=ManagedCollisionEmbeddingCollectionSharder(), + backend=backend, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given(backend=st.sampled_from(["nccl"])) + @settings(deadline=None) + def test_sharding_zch_mc_ec_dedup(self, backend: str) -> None: + + WORLD_SIZE = 2 + + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0", "feature_2"], + embedding_dim=8, + num_embeddings=16, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=32, + ), + ] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [1000, 1000, 2000, 1001, 1000, 2001, 2002, 3000, 2000, 1000], + ), + lengths=torch.LongTensor([2, 1, 1, 1, 1, 1, 2, 0, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [ + 1002, + 1002, + 1004, + 2000, + 1002, + 2004, + 3999, + 2000, + 2000, + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 0, 0, 3]), + weights=None, + ), + ] + + self._run_multi_process_test( + callable=_test_sharding_dedup, + world_size=WORLD_SIZE, + tables=embedding_config, + kjt_input_per_rank=kjt_input_per_rank, + sharder=ManagedCollisionEmbeddingCollectionSharder( + ec_sharder=EmbeddingCollectionSharder( + use_index_dedup=False, + ) + ), + dedup_sharder=ManagedCollisionEmbeddingCollectionSharder( + ec_sharder=EmbeddingCollectionSharder( + use_index_dedup=True, + ) + ), + backend=backend, + ) + + @unittest.skipIf( + 
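The expected `_mch_sorted_raw_ids` entries above follow a fixed shape: each rank's buffer has one slot per zch row it owns, admitted raw ids are stored sorted at the front, and the remaining slots are padded with the int64 maximum. A sketch of how such an expectation is assembled, mirroring the rank-0 `table_0` entry above:

```python
import torch

max_int = torch.iinfo(torch.int64).max
slots_owned_by_rank = 16 // 2            # table_0 rows split row-wise over 2 ranks
admitted = [1000, 1001, 1002, 1004]

expected = torch.LongTensor(
    admitted + [max_int] * (slots_owned_by_rank - len(admitted))
)
print(expected)  # tensor([1000, 1001, 1002, 1004, 9223372036854775807, ...])
```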
torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given(backend=st.sampled_from(["nccl"])) + @settings(deadline=None) + def test_sharding_zch_mc_ec_dedup_input_error(self, backend: str) -> None: + + WORLD_SIZE = 2 + + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0", "feature_2"], + embedding_dim=8, + num_embeddings=16, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=32, + ), + ] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [1000, 1000, 2000, 1001, 1000, 2001, 2002, 3000, 2000, 1000], + ), + lengths=torch.LongTensor([2, 1, 1, 1, 1, 1, 2, 0, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [ + 1002, + 1002, + 1004, + 2000, + 1002, + 2004, + 3999, + 2000, + 2000, + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 0, 0, 3]), + weights=None, + ), + ] + + try: + self._run_multi_process_test( + callable=_test_sharding_dedup, + world_size=WORLD_SIZE, + tables=embedding_config, + kjt_input_per_rank=kjt_input_per_rank, + sharder=ManagedCollisionEmbeddingCollectionSharder( + ec_sharder=EmbeddingCollectionSharder( + use_index_dedup=False, + ) + ), + dedup_sharder=ManagedCollisionEmbeddingCollectionSharder( + ec_sharder=EmbeddingCollectionSharder( + use_index_dedup=True, + ) + ), + backend=backend, + input_hash_size=(2**62) - 1 + 10, + ), + except AssertionError as e: + self.assertTrue("0 != 1" in str(e)) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given( + backend=st.sampled_from(["nccl"]), + allow_in_place_embed_weight_update=st.booleans(), + ) + @settings(deadline=None) + def test_in_place_embd_weight_update( + self, backend: str, allow_in_place_embed_weight_update: bool + ) -> None: + + WORLD_SIZE = 2 + + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=16, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=32, + ), + ] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [ + 1000, + 1002, + 1004, + 2000, + 2002, + 2004, + 2, + 2, + 2, + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = [] + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 15, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 7, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + ) + # TODO: cleanup sorting so more dedugable/logical initial fill + + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + 
keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 14, 4, 27, 29, 28], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 5, 6, 27, 28, 30], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + ) + + initial_state_per_rank = [ + { + "table_0._mch_remapped_ids_mapping": torch.arange(8, dtype=torch.int64), + "table_1._mch_remapped_ids_mapping": torch.arange( + 16, dtype=torch.int64 + ), + }, + { + "table_0._mch_remapped_ids_mapping": torch.arange( + start=8, end=16, dtype=torch.int64 + ), + "table_1._mch_remapped_ids_mapping": torch.arange( + start=16, end=32, dtype=torch.int64 + ), + }, + ] + max_int = torch.iinfo(torch.int64).max + + final_state_per_rank = [ + { + "table_0._mch_sorted_raw_ids": torch.LongTensor( + [1000, 1001, 1002, 1004] + [max_int] * 4 + ), + "table_1._mch_sorted_raw_ids": torch.LongTensor([max_int] * 16), + "table_0._mch_remapped_ids_mapping": torch.LongTensor( + [3, 4, 5, 6, 0, 1, 2, 7] + ), + "table_1._mch_remapped_ids_mapping": torch.arange( + 16, dtype=torch.int64 + ), + }, + { + "table_0._mch_sorted_raw_ids": torch.LongTensor([2000] + [max_int] * 7), + "table_1._mch_sorted_raw_ids": torch.LongTensor( + [2000, 2001, 2002, 2004] + [max_int] * 12 + ), + "table_0._mch_remapped_ids_mapping": torch.LongTensor( + [14, 8, 9, 10, 11, 12, 13, 15] + ), + "table_1._mch_remapped_ids_mapping": torch.LongTensor( + [27, 29, 28, 30, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 31] + ), + }, + ] + + self._run_multi_process_test( + callable=_test_in_place_embd_weight_update, + output_keys=["feature_0", "feature_1"], + world_size=WORLD_SIZE, + tables=embedding_config, + kjt_input_per_rank=kjt_input_per_rank, + kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank, + initial_state_per_rank=initial_state_per_rank, + final_state_per_rank=final_state_per_rank, + sharder=ManagedCollisionEmbeddingCollectionSharder(), + backend=backend, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, + ) diff --git a/torchrec/distributed/tests/test_mc_embeddingbag.py b/torchrec/distributed/tests/test_mc_embeddingbag.py new file mode 100644 index 000000000..a24caf2cc --- /dev/null +++ b/torchrec/distributed/tests/test_mc_embeddingbag.py @@ -0,0 +1,539 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import copy +import unittest +from typing import Dict, Final, List, Optional, Tuple + +import torch +import torch.nn as nn +from hypothesis import given, settings, strategies as st +from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection +from torchrec.distributed.mc_embeddingbag import ( + ManagedCollisionEmbeddingBagCollectionSharder, + ShardedManagedCollisionEmbeddingBagCollection, +) +from torchrec.distributed.mc_modules import ShardedManagedCollisionCollection +from torchrec.distributed.shard import _shard_modules + +from torchrec.distributed.sharding_plan import construct_module_sharding_plan, row_wise + +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection +from torchrec.modules.mc_modules import ( + DistanceLFU_EvictionPolicy, + ManagedCollisionCollection, + MCHManagedCollisionModule, +) +from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward +from torchrec.optim.rowwise_adagrad import RowWiseAdagrad +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from torchrec.test_utils import skip_if_asan_class + + +# Global constants for testing ShardedManagedCollisionEmbeddingBagCollection + +WORLD_SIZE = 2 + +# Input KeyedJaggedTensors for each rank in distributed tests +embedding_bag_config: Final[List[EmbeddingBagConfig]] = [ + EmbeddingBagConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=16, + ), + EmbeddingBagConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=32, + ), +] + +# Expected remapped outputs per iteration per rank for validation +kjt_input_per_rank: Final[List[KeyedJaggedTensor]] = [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [ + 1000, + 1002, + 1004, + 2000, + 2002, + 2004, + 1, + 1, + 1, + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), + weights=None, + ), +] + +kjt_out_per_iter_per_rank: Final[List[List[KeyedJaggedTensor]]] = [ + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 15, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 7, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ], + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 14, 4, 27, 29, 28], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 5, 6, 27, 28, 30], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ], +] + + +class SparseArch(nn.Module): + def __init__( + self, + tables: 
List[EmbeddingBagConfig], + device: torch.device, + return_remapped: bool = False, + allow_in_place_embed_weight_update: bool = False, + ) -> None: + super().__init__() + self._return_remapped = return_remapped + + mc_modules = {} + mc_modules["table_0"] = MCHManagedCollisionModule( + zch_size=tables[0].num_embeddings, + input_hash_size=4000, + device=device, + eviction_interval=2, + eviction_policy=DistanceLFU_EvictionPolicy(), + ) + + mc_modules["table_1"] = MCHManagedCollisionModule( + zch_size=tables[1].num_embeddings, + device=device, + input_hash_size=4000, + eviction_interval=2, + eviction_policy=DistanceLFU_EvictionPolicy(), + ) + + self._mc_ebc: ManagedCollisionEmbeddingBagCollection = ( + ManagedCollisionEmbeddingBagCollection( + EmbeddingBagCollection( + tables=tables, + device=device, + ), + ManagedCollisionCollection( + managed_collision_modules=mc_modules, + embedding_configs=tables, + ), + return_remapped_features=self._return_remapped, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, + ) + ) + + def forward( + self, kjt: KeyedJaggedTensor + ) -> Tuple[torch.Tensor, Optional[Dict[str, JaggedTensor]]]: + if self._return_remapped: + ebc_out, remapped_ids_out = self._mc_ebc(kjt) + else: + ebc_out = self._mc_ebc(kjt) + remapped_ids_out = None + pred = torch.cat( + [ebc_out[key] for key in ["feature_0", "feature_1"]], + dim=1, + ) + loss = pred.mean() + return loss, remapped_ids_out + + +def _test_sharding( # noqa C901 + tables: List[EmbeddingBagConfig], + rank: int, + world_size: int, + sharder: ModuleSharder[nn.Module], + backend: str, + local_size: Optional[int] = None, +) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + return_remapped: bool = True + sparse_arch = SparseArch( + tables, + torch.device("meta"), + return_remapped=return_remapped, + ) + + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + sparse_arch._mc_ebc._embedding_bag_collection.embedding_bags[ + "table_0" + ].weight, + sparse_arch._mc_ebc._embedding_bag_collection.embedding_bags[ + "table_1" + ].weight, + ], + {"lr": 0.01}, + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._mc_ebc, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ebc": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + + assert isinstance( + sharded_sparse_arch._mc_ebc, ShardedManagedCollisionEmbeddingBagCollection + ) + assert isinstance( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_managed_collision_collection`. 
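`construct_module_sharding_plan(..., per_param_sharding={"table_0": row_wise(), ...})` asks the planner to split each table's rows across the ranks; the `test_uneven_sharding` case further down (17 and 33 rows over two ranks) exists precisely because that split cannot be equal. A rough back-of-the-envelope helper for the resulting shard sizes (illustrative only, not the torchrec planner's exact algorithm):

```python
from typing import List


def row_wise_split(num_embeddings: int, world_size: int) -> List[int]:
    # Spread rows as evenly as possible, giving earlier ranks the remainder.
    base, rem = divmod(num_embeddings, world_size)
    return [base + (1 if r < rem else 0) for r in range(world_size)]


print(row_wise_split(16, 2))  # [8, 8] -- the even-sharding configs
print(row_wise_split(17, 2))  # [9, 8] -- the uneven-sharding config
```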
+ sharded_sparse_arch._mc_ebc._managed_collision_collection, + ShardedManagedCollisionCollection, + ) + + +def _test_sharding_and_remapping( # noqa C901 + output_keys: List[str], + tables: List[EmbeddingBagConfig], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]], + sharder: ModuleSharder[nn.Module], + backend: str, + local_size: Optional[int] = None, +) -> None: + + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + kjt_input = kjt_input_per_rank[rank].to(ctx.device) + kjt_out_per_iter = [ + kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank + ] + return_remapped: bool = True + sparse_arch = SparseArch( + tables, + torch.device("meta"), + return_remapped=return_remapped, + ) + + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + sparse_arch._mc_ebc._embedding_bag_collection.embedding_bags[ + "table_0" + ].weight, + sparse_arch._mc_ebc._embedding_bag_collection.embedding_bags[ + "table_1" + ].weight, + ], + {"lr": 0.01}, + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._mc_ebc, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ebc": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + + assert isinstance( + sharded_sparse_arch._mc_ebc, ShardedManagedCollisionEmbeddingBagCollection + ) + assert isinstance( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_embedding_bag_collection`. + sharded_sparse_arch._mc_ebc._embedding_bag_collection, + ShardedEmbeddingBagCollection, + ) + assert ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_embedding_bag_collection`. + sharded_sparse_arch._mc_ebc._embedding_bag_collection._has_uninitialized_input_dist + is False + ) + assert ( + not hasattr( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `_embedding_bag_collection`. + sharded_sparse_arch._mc_ebc._embedding_bag_collection, + "_input_dists", + ) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_embedding_bag_collection`. + or len(sharded_sparse_arch._mc_ebc._embedding_bag_collection._input_dists) + == 0 + ) + + assert isinstance( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_managed_collision_collection`. 
+ sharded_sparse_arch._mc_ebc._managed_collision_collection, + ShardedManagedCollisionCollection, + ) + + test_state_dict = sharded_sparse_arch.state_dict() + sharded_sparse_arch.load_state_dict(test_state_dict) + + # sharded model + # each rank gets a subbatch + loss1, remapped_ids1 = sharded_sparse_arch(kjt_input) + loss1.backward() + loss2, remapped_ids2 = sharded_sparse_arch(kjt_input) + loss2.backward() + remapped_ids = [remapped_ids1, remapped_ids2] + for key in output_keys: + for i, kjt_out in enumerate(kjt_out_per_iter): + assert torch.equal( + remapped_ids[i][key].values(), + kjt_out[key].values(), + ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}" + + # TODO: validate embedding rows, and eviction + + +def _test_in_place_embd_weight_update( # noqa C901 + output_keys: List[str], + tables: List[EmbeddingBagConfig], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]], + sharder: ModuleSharder[nn.Module], + backend: str, + local_size: Optional[int] = None, + allow_in_place_embed_weight_update: bool = True, +) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + kjt_input = kjt_input_per_rank[rank].to(ctx.device) + kjt_out_per_iter = [ + kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank + ] + return_remapped: bool = True + sparse_arch = SparseArch( + tables, + torch.device("meta"), + return_remapped=return_remapped, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, + ) + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + sparse_arch._mc_ebc._embedding_bag_collection.embedding_bags[ + "table_0" + ].weight, + sparse_arch._mc_ebc._embedding_bag_collection.embedding_bags[ + "table_1" + ].weight, + ], + {"lr": 0.01}, + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._mc_ebc, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ebc": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. 
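Both `_test_in_place_embd_weight_update` helpers defer `loss1.backward()` until after a second forward because the managed-collision modules are built with `eviction_interval=2`: the second call is the one that triggers the eviction pass and the in-place rewrite of embedding rows that the `allow_in_place_embed_weight_update` flag guards. A purely illustrative counter for that cadence (not the `MCHManagedCollisionModule` implementation):

```python
eviction_interval = 2  # matches the value used when building the mc modules

for step in range(1, 5):
    runs_eviction = step % eviction_interval == 0
    print(f"forward {step}: eviction pass -> {runs_eviction}")
```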
+ env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + + test_state_dict = sharded_sparse_arch.state_dict() + sharded_sparse_arch.load_state_dict(test_state_dict) + + # sharded model + # each rank gets a subbatch + loss1, remapped_ids1 = sharded_sparse_arch(kjt_input) + loss2, remapped_ids2 = sharded_sparse_arch(kjt_input) + if not allow_in_place_embed_weight_update: + # Without in-place overwrite the backward pass will fail due to tensor version mismatch + with unittest.TestCase().assertRaisesRegex( + RuntimeError, + "one of the variables needed for gradient computation has been modified by an inplace operation", + ): + loss1.backward() + else: + loss1.backward() + loss2.backward() + remapped_ids = [remapped_ids1, remapped_ids2] + for key in output_keys: + for i, kjt_out in enumerate(kjt_out_per_iter): + assert torch.equal( + remapped_ids[i][key].values(), + kjt_out[key].values(), + ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}" + + +@skip_if_asan_class +class ShardedMCEmbeddingBagCollectionParallelTest(MultiProcessTestBase): + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given(backend=st.sampled_from(["nccl"])) + @settings(deadline=None) + def test_uneven_sharding(self, backend: str) -> None: + WORLD_SIZE = 2 + + embedding_bag_config = [ + EmbeddingBagConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=17, + ), + EmbeddingBagConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=33, + ), + ] + + self._run_multi_process_test( + callable=_test_sharding, + world_size=WORLD_SIZE, + tables=embedding_bag_config, + sharder=ManagedCollisionEmbeddingBagCollectionSharder(), + backend=backend, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given(backend=st.sampled_from(["nccl"])) + @settings(deadline=None) + def test_even_sharding(self, backend: str) -> None: + + self._run_multi_process_test( + callable=_test_sharding, + world_size=WORLD_SIZE, + tables=embedding_bag_config, + sharder=ManagedCollisionEmbeddingBagCollectionSharder(), + backend=backend, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given(backend=st.sampled_from(["nccl"])) + @settings(deadline=None) + def test_sharding_zch_mc_ebc(self, backend: str) -> None: + self._run_multi_process_test( + callable=_test_sharding_and_remapping, + output_keys=["feature_0", "feature_1"], + world_size=WORLD_SIZE, + tables=embedding_bag_config, + kjt_input_per_rank=kjt_input_per_rank, + kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank, + sharder=ManagedCollisionEmbeddingBagCollectionSharder(), + backend=backend, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given( + backend=st.sampled_from(["nccl"]), + allow_in_place_embed_weight_update=st.booleans(), + ) + @settings(deadline=None) + def test_in_place_embd_weight_update( + self, backend: str, allow_in_place_embed_weight_update: bool + ) -> None: + + self._run_multi_process_test( + callable=_test_in_place_embd_weight_update, + output_keys=["feature_0", "feature_1"], + world_size=WORLD_SIZE, + tables=embedding_bag_config, + 
kjt_input_per_rank=kjt_input_per_rank, + kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank, + sharder=ManagedCollisionEmbeddingBagCollectionSharder(), + backend=backend, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, + ) diff --git a/torchrec/distributed/tests/test_model_parallel.py b/torchrec/distributed/tests/test_model_parallel.py deleted file mode 100644 index b2e77a3e0..000000000 --- a/torchrec/distributed/tests/test_model_parallel.py +++ /dev/null @@ -1,1028 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os -import unittest -from collections import defaultdict, OrderedDict -from typing import Any, cast, Dict, List, Optional, Tuple, Type - -import hypothesis.strategies as st -import numpy as np -import torch -import torch.distributed as dist -import torch.nn as nn -import torchrec.distributed as trec_dist -from fbgemm_gpu.split_embedding_configs import EmbOptimType -from hypothesis import assume, given, settings, Verbosity -from torchrec.distributed.embedding_types import EmbeddingComputeKernel -from torchrec.distributed.embeddingbag import ( - EmbeddingBagCollectionSharder, - EmbeddingBagSharder, - ShardedEmbeddingBagCollection, -) -from torchrec.distributed.fbgemm_qcomm_codec import CommType, QCommsConfig -from torchrec.distributed.fused_embeddingbag import ShardedFusedEmbeddingBagCollection - -from torchrec.distributed.model_parallel import ( - DistributedModelParallel, - get_default_sharders, -) -from torchrec.distributed.planner import ( - EmbeddingShardingPlanner, - ParameterConstraints, - Topology, -) -from torchrec.distributed.test_utils.test_model import ModelInput, TestSparseNN -from torchrec.distributed.test_utils.test_model_parallel import ModelParallelTestShared -from torchrec.distributed.test_utils.test_sharding import ( - create_test_sharder, - SharderType, -) -from torchrec.distributed.types import ( - ModuleSharder, - ShardedTensor, - ShardingEnv, - ShardingType, -) -from torchrec.modules.embedding_configs import EmbeddingBagConfig, PoolingType -from torchrec.modules.embedding_modules import EmbeddingBagCollection -from torchrec.modules.fused_embedding_modules import FusedEmbeddingBagCollection -from torchrec.test_utils import get_free_port, skip_if_asan_class - - -@skip_if_asan_class -class ModelParallelTest(ModelParallelTestShared): - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) - # pyre-fixme[56] - @given( - sharder_type=st.sampled_from( - [ - SharderType.EMBEDDING_BAG.value, - SharderType.EMBEDDING_BAG_COLLECTION.value, - ] - ), - sharding_type=st.sampled_from( - [ - ShardingType.ROW_WISE.value, - ] - ), - kernel_type=st.sampled_from( - [ - EmbeddingComputeKernel.DENSE.value, - EmbeddingComputeKernel.FUSED.value, - ] - ), - qcomms_config=st.sampled_from( - [ - None, - QCommsConfig( - forward_precision=CommType.FP16, backward_precision=CommType.BF16 - ), - ] - ), - apply_optimizer_in_backward_config=st.sampled_from( - [ - None, - { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), - "embeddings": (torch.optim.SGD, {"lr": 0.2}), - }, - ] - ), - variable_batch_size=st.booleans(), - ) - @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) - def test_sharding_nccl_rw( - self, - sharder_type: str, - sharding_type: str, - kernel_type: str, - 
qcomms_config: Optional[QCommsConfig], - apply_optimizer_in_backward_config: Optional[ - Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] - ], - variable_batch_size: bool, - ) -> None: - assume( - sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value - or not variable_batch_size - ) - self._test_sharding( - sharders=[ - cast( - ModuleSharder[nn.Module], - create_test_sharder( - sharder_type, - sharding_type, - kernel_type, - qcomms_config=qcomms_config, - device=torch.device("cuda"), - variable_batch_size=variable_batch_size, - ), - ), - ], - qcomms_config=qcomms_config, - backend="nccl", - apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, - variable_batch_size=variable_batch_size, - ) - - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) - # pyre-fixme[56] - @given( - sharder_type=st.sampled_from( - [ - SharderType.EMBEDDING_BAG.value, - SharderType.EMBEDDING_BAG_COLLECTION.value, - ] - ), - sharding_type=st.sampled_from( - [ - ShardingType.DATA_PARALLEL.value, - ] - ), - kernel_type=st.sampled_from( - [ - EmbeddingComputeKernel.DENSE.value, - ] - ), - apply_optimizer_in_backward_config=st.sampled_from([None]), - # TODO - need to enable optimizer overlapped behavior for data_parallel tables - ) - @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) - def test_sharding_nccl_dp( - self, - sharder_type: str, - sharding_type: str, - kernel_type: str, - apply_optimizer_in_backward_config: Optional[ - Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] - ], - ) -> None: - - self._test_sharding( - # pyre-ignore[6] - sharders=[ - create_test_sharder(sharder_type, sharding_type, kernel_type), - ], - backend="nccl", - apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, - ) - - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) - # pyre-fixme[56] - @given( - sharder_type=st.sampled_from( - [ - # SharderType.EMBEDDING_BAG.value, - SharderType.EMBEDDING_BAG_COLLECTION.value, - ] - ), - sharding_type=st.sampled_from( - [ - ShardingType.COLUMN_WISE.value, - ] - ), - kernel_type=st.sampled_from( - [ - # EmbeddingComputeKernel.DENSE.value, - EmbeddingComputeKernel.FUSED.value, - ] - ), - qcomms_config=st.sampled_from( - [ - None, - QCommsConfig( - forward_precision=CommType.FP16, backward_precision=CommType.BF16 - ), - ] - ), - apply_optimizer_in_backward_config=st.sampled_from( - [ - None, - { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), - "embeddings": (torch.optim.SGD, {"lr": 0.2}), - }, - ] - ), - variable_batch_size=st.booleans(), - ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) - def test_sharding_nccl_cw( - self, - sharder_type: str, - sharding_type: str, - kernel_type: str, - qcomms_config: Optional[QCommsConfig], - apply_optimizer_in_backward_config: Optional[ - Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] - ], - variable_batch_size: bool, - ) -> None: - assume( - sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value - or not variable_batch_size - ) - self._test_sharding( - # pyre-ignore[6] - sharders=[ - create_test_sharder( - sharder_type, - sharding_type, - kernel_type, - qcomms_config=qcomms_config, - device=torch.device("cuda"), - variable_batch_size=variable_batch_size, - ), - ], - backend="nccl", - qcomms_config=qcomms_config, - constraints={ - table.name: ParameterConstraints(min_partition=4) - for table in self.tables - }, 
- apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, - variable_batch_size=variable_batch_size, - ) - - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) - # pyre-fixme[56] - @given( - sharder_type=st.sampled_from( - [ - # SharderType.EMBEDDING_BAG.value, - SharderType.EMBEDDING_BAG_COLLECTION.value, - ] - ), - sharding_type=st.sampled_from( - [ - ShardingType.TABLE_WISE.value, - ] - ), - kernel_type=st.sampled_from( - [ - EmbeddingComputeKernel.DENSE.value, - EmbeddingComputeKernel.FUSED.value, - ] - ), - qcomms_config=st.sampled_from( - [ - # None, - QCommsConfig( - forward_precision=CommType.FP16, - backward_precision=CommType.BF16, - ), - ] - ), - apply_optimizer_in_backward_config=st.sampled_from( - [ - None, - { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), - "embeddings": (torch.optim.SGD, {"lr": 0.2}), - }, - ] - ), - variable_batch_size=st.booleans(), - ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) - def test_sharding_nccl_tw( - self, - sharder_type: str, - sharding_type: str, - kernel_type: str, - qcomms_config: Optional[QCommsConfig], - apply_optimizer_in_backward_config: Optional[ - Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] - ], - variable_batch_size: bool, - ) -> None: - assume( - sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value - or not variable_batch_size - ) - self._test_sharding( - # pyre-ignore[6] - sharders=[ - create_test_sharder( - sharder_type, - sharding_type, - kernel_type, - qcomms_config=qcomms_config, - device=torch.device("cuda"), - variable_batch_size=variable_batch_size, - ), - ], - backend="nccl", - qcomms_config=qcomms_config, - apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, - variable_batch_size=variable_batch_size, - ) - - # pyre-fixme[56] - @given( - sharder_type=st.sampled_from( - [ - # TODO: enable it with correct semantics, see T104397332 - # SharderType.EMBEDDING_BAG.value, - SharderType.EMBEDDING_BAG_COLLECTION.value, - ] - ), - sharding_type=st.sampled_from( - [ - ShardingType.TABLE_WISE.value, - ] - ), - kernel_type=st.sampled_from( - [ - EmbeddingComputeKernel.DENSE.value, - EmbeddingComputeKernel.FUSED.value, - ] - ), - qcomms_config=st.sampled_from( - [ - None, - # On gloo, BF16 is not supported as dtype. 
- QCommsConfig( - forward_precision=CommType.FP16, backward_precision=CommType.FP16 - ), - ] - ), - apply_optimizer_in_backward_config=st.sampled_from( - [ - None, - { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), - "embeddings": (torch.optim.SGD, {"lr": 0.2}), - }, - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) - def test_sharding_gloo_tw( - self, - sharder_type: str, - sharding_type: str, - kernel_type: str, - qcomms_config: Optional[QCommsConfig], - apply_optimizer_in_backward_config: Optional[ - Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] - ], - ) -> None: - self._test_sharding( - # pyre-ignore[6] - sharders=[ - create_test_sharder( - sharder_type, - sharding_type, - kernel_type, - qcomms_config=qcomms_config, - device=torch.device("cpu"), - ), - ], - qcomms_config=qcomms_config, - backend="gloo", - apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, - ) - - # pyre-fixme[56] - @given( - sharder_type=st.sampled_from( - [ - # TODO: enable it with correct semantics, see T104397332 - # SharderType.EMBEDDING_BAG.value, - SharderType.EMBEDDING_BAG_COLLECTION.value, - ] - ), - sharding_type=st.sampled_from( - [ - ShardingType.COLUMN_WISE.value, - ] - ), - kernel_type=st.sampled_from( - [ - EmbeddingComputeKernel.DENSE.value, - EmbeddingComputeKernel.FUSED.value, - ] - ), - qcomms_config=st.sampled_from( - [ - None, - # On gloo, BF16 is not supported as dtype. - QCommsConfig( - forward_precision=CommType.FP16, - backward_precision=CommType.FP16, - ), - ] - ), - apply_optimizer_in_backward_config=st.sampled_from( - [ - None, - { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), - "embeddings": (torch.optim.SGD, {"lr": 0.2}), - }, - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) - def test_sharding_gloo_cw( - self, - sharder_type: str, - sharding_type: str, - kernel_type: str, - qcomms_config: Optional[QCommsConfig], - apply_optimizer_in_backward_config: Optional[ - Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] - ], - ) -> None: - world_size = 4 - self._test_sharding( - # pyre-ignore[6] - sharders=[ - create_test_sharder( - sharder_type, - sharding_type, - kernel_type, - qcomms_config=qcomms_config, - device=torch.device("cpu"), - ), - ], - qcomms_config=qcomms_config, - backend="gloo", - world_size=world_size, - constraints={ - table.name: ParameterConstraints(min_partition=4) - for table in self.tables - }, - apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, - ) - - # pyre-fixme[56] - @given( - sharder_type=st.sampled_from( - [ - SharderType.EMBEDDING_BAG.value, - SharderType.EMBEDDING_BAG_COLLECTION.value, - ] - ), - sharding_type=st.sampled_from( - [ - ShardingType.DATA_PARALLEL.value, - ] - ), - kernel_type=st.sampled_from( - [ - EmbeddingComputeKernel.DENSE.value, - # TODO dp+batch_fused is numerically buggy in cpu - # EmbeddingComputeKernel.FUSED.value, - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) - def test_sharding_gloo_dp( - self, sharder_type: str, sharding_type: str, kernel_type: str - ) -> None: - self._test_sharding( - # pyre-ignore[6] - sharders=[ - create_test_sharder(sharder_type, sharding_type, kernel_type), - ], - backend="gloo", - ) - - -class ModelParallelSparseOnlyTest(unittest.TestCase): - def test_sharding_ebc_as_top_level(self) -> None: - os.environ["RANK"] = "0" - os.environ["WORLD_SIZE"] = "1" - os.environ["LOCAL_WORLD_SIZE"] = "1" - os.environ["MASTER_ADDR"] = 
str("localhost") - os.environ["MASTER_PORT"] = str(get_free_port()) - os.environ["NCCL_SOCKET_IFNAME"] = "lo" - - if torch.cuda.is_available(): - curr_device = torch.device("cuda:0") - torch.cuda.set_device(curr_device) - backend = "nccl" - else: - curr_device = torch.device("cpu") - backend = "gloo" - dist.init_process_group(backend=backend) - - embedding_dim = 128 - num_embeddings = 256 - ebc = EmbeddingBagCollection( - device=torch.device("meta"), - tables=[ - EmbeddingBagConfig( - name="large_table", - embedding_dim=embedding_dim, - num_embeddings=num_embeddings, - feature_names=["my_feature"], - pooling=PoolingType.SUM, - ), - ], - ) - - model = DistributedModelParallel(ebc, device=curr_device) - - self.assertTrue(isinstance(model.module, ShardedEmbeddingBagCollection)) - dist.destroy_process_group() - - def test_sharding_fused_ebc_as_top_level(self) -> None: - os.environ["RANK"] = "0" - os.environ["WORLD_SIZE"] = "1" - os.environ["LOCAL_WORLD_SIZE"] = "1" - os.environ["MASTER_ADDR"] = str("localhost") - os.environ["MASTER_PORT"] = str(get_free_port()) - os.environ["NCCL_SOCKET_IFNAME"] = "lo" - - if torch.cuda.is_available(): - curr_device = torch.device("cuda:0") - torch.cuda.set_device(curr_device) - backend = "nccl" - else: - curr_device = torch.device("cpu") - backend = "gloo" - dist.init_process_group(backend=backend) - - embedding_dim = 128 - num_embeddings = 256 - ebc = FusedEmbeddingBagCollection( - device=torch.device("meta"), - tables=[ - EmbeddingBagConfig( - name="large_table", - embedding_dim=embedding_dim, - num_embeddings=num_embeddings, - feature_names=["my_feature"], - pooling=PoolingType.SUM, - ), - ], - optimizer_type=torch.optim.SGD, - optimizer_kwargs={"lr": 0.02}, - ) - - model = DistributedModelParallel(ebc, device=curr_device) - - self.assertTrue(isinstance(model.module, ShardedFusedEmbeddingBagCollection)) - dist.destroy_process_group() - - -class ModelParallelStateDictTest(unittest.TestCase): - def setUp(self) -> None: - os.environ["RANK"] = "0" - os.environ["WORLD_SIZE"] = "1" - os.environ["LOCAL_WORLD_SIZE"] = "1" - os.environ["MASTER_ADDR"] = str("localhost") - os.environ["MASTER_PORT"] = str(get_free_port()) - os.environ["NCCL_SOCKET_IFNAME"] = "lo" - if torch.cuda.is_available(): - self.device = torch.device("cuda:0") - backend = "nccl" - torch.cuda.set_device(self.device) - else: - self.device = torch.device("cpu") - backend = "gloo" - dist.init_process_group(backend=backend) - - num_features = 4 - num_weighted_features = 2 - self.batch_size = 3 - self.num_float_features = 10 - - self.tables = [ - EmbeddingBagConfig( - num_embeddings=(i + 1) * 10, - embedding_dim=(i + 1) * 4, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - ) - for i in range(num_features) - ] - self.weighted_tables = [ - EmbeddingBagConfig( - num_embeddings=(i + 1) * 10, - embedding_dim=(i + 1) * 4, - name="weighted_table_" + str(i), - feature_names=["weighted_feature_" + str(i)], - ) - for i in range(num_weighted_features) - ] - - def tearDown(self) -> None: - dist.destroy_process_group() - del os.environ["NCCL_SOCKET_IFNAME"] - super().tearDown() - - def _generate_dmps_and_batch( - self, - sharders: Optional[List[ModuleSharder[nn.Module]]] = None, - constraints: Optional[Dict[str, trec_dist.planner.ParameterConstraints]] = None, - ) -> Tuple[List[DistributedModelParallel], ModelInput]: - - if constraints is None: - constraints = {} - if sharders is None: - sharders = get_default_sharders() - - _, local_batch = ModelInput.generate( - batch_size=self.batch_size, - 
world_size=1, - num_float_features=self.num_float_features, - tables=self.tables, - weighted_tables=self.weighted_tables, - ) - batch = local_batch[0].to(self.device) - - dmps = [] - pg = dist.GroupMember.WORLD - assert pg is not None, "Process group is not initialized" - env = ShardingEnv.from_process_group(pg) - - planner = EmbeddingShardingPlanner( - topology=Topology( - local_world_size=trec_dist.comm.get_local_size(env.world_size), - world_size=env.world_size, - compute_device=self.device.type, - ), - constraints=constraints, - ) - - for _ in range(2): - # Create two TestSparseNN modules, wrap both in DMP - m = TestSparseNN( - tables=self.tables, - num_float_features=self.num_float_features, - weighted_tables=self.weighted_tables, - dense_device=self.device, - sparse_device=torch.device("meta"), - ) - if pg is not None: - plan = planner.collective_plan(m, sharders, pg) - else: - plan = planner.plan(m, sharders) - - dmp = DistributedModelParallel( - module=m, - init_data_parallel=False, - device=self.device, - sharders=sharders, - plan=plan, - ) - - with torch.no_grad(): - dmp(batch) - dmp.init_data_parallel() - dmps.append(dmp) - return (dmps, batch) - - def test_parameter_init(self) -> None: - class MyModel(nn.Module): - def __init__(self, device: str, val: float) -> None: - super().__init__() - self.p = nn.Parameter( - torch.empty(3, dtype=torch.float32, device=device) - ) - self.val = val - self.reset_parameters() - - def reset_parameters(self) -> None: - nn.init.constant_(self.p, self.val) - - dist.destroy_process_group() - dist.init_process_group(backend="gloo") - - # Check that already allocated parameters are left 'as is'. - cpu_model = MyModel(device="cpu", val=3.2) - sharded_model = DistributedModelParallel( - cpu_model, - ) - sharded_param = next(sharded_model.parameters()) - np.testing.assert_array_equal( - np.array([3.2, 3.2, 3.2], dtype=np.float32), sharded_param.detach().numpy() - ) - - # Check that parameters over 'meta' device are allocated and initialized. - meta_model = MyModel(device="meta", val=7.5) - sharded_model = DistributedModelParallel( - meta_model, - ) - sharded_param = next(sharded_model.parameters()) - np.testing.assert_array_equal( - np.array([7.5, 7.5, 7.5], dtype=np.float32), sharded_param.detach().numpy() - ) - - def test_meta_device_dmp_state_dict(self) -> None: - # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got - # `Optional[ProcessGroup]`. 
- env = ShardingEnv.from_process_group(dist.GroupMember.WORLD) - - m1 = TestSparseNN( - tables=self.tables, - num_float_features=self.num_float_features, - weighted_tables=self.weighted_tables, - dense_device=self.device, - ) - # dmp with real device - dmp1 = DistributedModelParallel( - module=m1, - init_data_parallel=False, - init_parameters=False, - sharders=get_default_sharders(), - device=self.device, - env=env, - plan=EmbeddingShardingPlanner( - topology=Topology( - world_size=env.world_size, compute_device=self.device.type - ) - ).plan(m1, get_default_sharders()), - ) - - m2 = TestSparseNN( - tables=self.tables, - num_float_features=self.num_float_features, - weighted_tables=self.weighted_tables, - dense_device=self.device, - ) - # dmp with meta device - dmp2 = DistributedModelParallel( - module=m2, - init_data_parallel=False, - init_parameters=False, - sharders=get_default_sharders(), - device=torch.device("meta"), - env=env, - plan=EmbeddingShardingPlanner( - topology=Topology( - world_size=env.world_size, compute_device=self.device.type - ) - ).plan(m2, get_default_sharders()), - ) - - sd1 = dmp1.state_dict() - for key, v2 in dmp2.state_dict().items(): - v1 = sd1[key] - if isinstance(v2, nn.parameter.UninitializedParameter) and isinstance( - v1, nn.parameter.UninitializedParameter - ): - continue - if isinstance(v2, ShardedTensor): - self.assertTrue(isinstance(v1, ShardedTensor)) - assert len(v2.local_shards()) == 1 - dst = v2.local_shards()[0].tensor - else: - dst = v2 - if isinstance(v1, ShardedTensor): - assert len(v1.local_shards()) == 1 - src = v1.local_shards()[0].tensor - else: - src = v1 - self.assertEqual(src.size(), dst.size()) - - # pyre-ignore[56] - @given( - sharders=st.sampled_from( - [ - [EmbeddingBagCollectionSharder()], - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) - def test_load_state_dict(self, sharders: List[ModuleSharder[nn.Module]]) -> None: - models, batch = self._generate_dmps_and_batch(sharders) - m1, m2 = models - - # load the second's (m2's) with the first (m1's) state_dict - m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) - - # validate the models are equivalent - with torch.no_grad(): - loss1, pred1 = m1(batch) - loss2, pred2 = m2(batch) - self.assertTrue(torch.equal(loss1, loss2)) - self.assertTrue(torch.equal(pred1, pred2)) - sd1 = m1.state_dict() - for key, value in m2.state_dict().items(): - v2 = sd1[key] - if isinstance(value, ShardedTensor): - assert len(value.local_shards()) == 1 - dst = value.local_shards()[0].tensor - else: - dst = value - if isinstance(v2, ShardedTensor): - assert len(v2.local_shards()) == 1 - src = v2.local_shards()[0].tensor - else: - src = v2 - self.assertTrue(torch.equal(src, dst)) - - # pyre-ignore[56] - @given( - sharders=st.sampled_from( - [ - [EmbeddingBagCollectionSharder()], - [EmbeddingBagSharder()], - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) - def test_load_state_dict_prefix( - self, sharders: List[ModuleSharder[nn.Module]] - ) -> None: - (m1, m2), batch = self._generate_dmps_and_batch(sharders) - - # load the second's (m2's) with the first (m1's) state_dict - m2.load_state_dict( - cast("OrderedDict[str, torch.Tensor]", m1.state_dict(prefix="alpha")), - prefix="alpha", - ) - - # validate the models are equivalent - sd1 = m1.state_dict() - for key, value in m2.state_dict().items(): - v2 = sd1[key] - if isinstance(value, ShardedTensor): - assert len(value.local_shards()) == 1 - dst = 
value.local_shards()[0].tensor - else: - dst = value - if isinstance(v2, ShardedTensor): - assert len(v2.local_shards()) == 1 - src = v2.local_shards()[0].tensor - else: - src = v2 - self.assertTrue(torch.equal(src, dst)) - - # pyre-fixme[56] - @given( - sharder_type=st.sampled_from( - [ - SharderType.EMBEDDING_BAG_COLLECTION.value, - ] - ), - sharding_type=st.sampled_from( - [ - ShardingType.TABLE_WISE.value, - ] - ), - kernel_type=st.sampled_from( - [ - # EmbeddingComputeKernel.DENSE.value, - EmbeddingComputeKernel.FUSED.value, - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) - def test_params_and_buffers( - self, sharder_type: str, sharding_type: str, kernel_type: str - ) -> None: - sharders = [ - create_test_sharder(sharder_type, sharding_type, kernel_type), - ] - # pyre-ignore[6] - (m, _), batch = self._generate_dmps_and_batch(sharders=sharders) - print(f"Sharding Plan: {m._plan}") - state_dict_keys = set(m.state_dict().keys()) - param_keys = {key for (key, _) in m.named_parameters()} - buffer_keys = {key for (key, _) in m.named_buffers()} - self.assertEqual(state_dict_keys, {*param_keys, *buffer_keys}) - - # pyre-ignore - @given( - sharder_type=st.sampled_from( - [ - SharderType.EMBEDDING_BAG_COLLECTION.value, - ] - ), - sharding_type=st.sampled_from( - [ - ShardingType.COLUMN_WISE.value, - ] - ), - kernel_type=st.sampled_from( - [ - EmbeddingComputeKernel.FUSED.value, - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) - def test_load_state_dict_cw_multiple_shards( - self, sharder_type: str, sharding_type: str, kernel_type: str - ) -> None: - sharders = [ - cast( - ModuleSharder[nn.Module], - create_test_sharder( - sharder_type, - sharding_type, - kernel_type, - fused_params={ - "learning_rate": 0.2, - "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD, - }, - ), - ), - ] - - constraints = defaultdict(lambda: trec_dist.planner.ParameterConstraints()) - num_cw_shards_per_table = {} - for table in self.tables + self.weighted_tables: - constraints[table.name].min_partition = 4 - num_cw_shards_per_table[table.name] = table.embedding_dim // 4 - - (m1, m2), batch = self._generate_dmps_and_batch( - sharders, constraints=constraints - ) - - # load the second's (m2's) with the first (m1's) state_dict - m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) - - # load optimizer state dict - - # Check to see that we can load optimizer state - src_optimizer = m1.fused_optimizer - dst_optimizer = m2.fused_optimizer - - src_optimizer_state_dict = src_optimizer.state_dict() - dst_optimizer_state_dict = dst_optimizer.state_dict() - m2.fused_optimizer.load_state_dict(src_optimizer_state_dict) - - # validate the models are equivalent - loss1, pred1 = m1(batch) - loss2, pred2 = m2(batch) - self.assertTrue(torch.equal(loss1, loss2)) - self.assertTrue(torch.equal(pred1, pred2)) - - sd1 = m1.state_dict() - for key, value in m2.state_dict().items(): - table_name = key.split(".")[-2] - v2 = sd1[key] - if isinstance(value, ShardedTensor): - self.assertEqual( - len(value.local_shards()), num_cw_shards_per_table[table_name] - ) - dst = value.local_shards()[0].tensor - else: - dst = value - - if isinstance(v2, ShardedTensor): - self.assertEqual( - len(value.local_shards()), num_cw_shards_per_table[table_name] - ) - - for src_local_shard, dst_local_shard in zip( - value.local_shards(), v2.local_shards() - ): - self.assertTrue( - torch.equal(src_local_shard.tensor, dst_local_shard.tensor) - ) - else: - src = v2 - 
self.assertTrue(torch.equal(src, dst)) - - for param_name, dst_param_group in dst_optimizer_state_dict.items(): - src_param_group = src_optimizer_state_dict[param_name] - - for state_key, dst_opt_state in dst_param_group.items(): - table_name = state_key.split(".")[-2] - src_opt_state = src_param_group[state_key] - if isinstance(dst_opt_state, ShardedTensor): - self.assertIsInstance(src_param_group[state_key], ShardedTensor) - - self.assertEqual( - len(dst_opt_state.local_shards()), - num_cw_shards_per_table[table_name], - ) - - self.assertEqual( - len(src_opt_state.local_shards()), - num_cw_shards_per_table[table_name], - ) - - for src_local_shard, dst_local_shard in zip( - src_opt_state.local_shards(), dst_opt_state.local_shards() - ): - self.assertTrue( - torch.equal(src_local_shard.tensor, dst_local_shard.tensor) - ) - elif isinstance(dst_opt_state, torch.Tensor): - self.assertIsInstance(src_opt_state, torch.Tensor) diff --git a/torchrec/distributed/tests/test_model_parallel_gloo.py b/torchrec/distributed/tests/test_model_parallel_gloo.py new file mode 100644 index 000000000..e3c9b636b --- /dev/null +++ b/torchrec/distributed/tests/test_model_parallel_gloo.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from torchrec.distributed.test_utils.test_model_parallel import ModelParallelBase +from torchrec.distributed.test_utils.test_model_parallel_base import ( + ModelParallelSparseOnlyBase, + ModelParallelStateDictBase, +) + +# CPU tests for Gloo. + + +class ModelParallelTestGloo(ModelParallelBase): + def setUp(self, backend: str = "gloo") -> None: + super().setUp(backend=backend) + + +class ModelParallelStateDictTestGloo(ModelParallelStateDictBase): + def setUp(self, backend: str = "gloo") -> None: + super().setUp(backend=backend) + + +class ModelParallelSparseOnlyTestGloo(ModelParallelSparseOnlyBase): + def setUp(self, backend: str = "gloo") -> None: + super().setUp(backend=backend) diff --git a/torchrec/distributed/tests/test_model_parallel_gloo_gpu.py b/torchrec/distributed/tests/test_model_parallel_gloo_gpu.py new file mode 100644 index 000000000..70c550762 --- /dev/null +++ b/torchrec/distributed/tests/test_model_parallel_gloo_gpu.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from torchrec.distributed.test_utils.test_model_parallel import ModelParallelBase + +# GPU tests for Gloo. + + +class ModelParallelTestGloo(ModelParallelBase): + def setUp(self, backend: str = "gloo") -> None: + super().setUp(backend=backend) diff --git a/torchrec/distributed/tests/test_model_parallel_gloo_gpu_single_rank.py b/torchrec/distributed/tests/test_model_parallel_gloo_gpu_single_rank.py new file mode 100644 index 000000000..2a30821f0 --- /dev/null +++ b/torchrec/distributed/tests/test_model_parallel_gloo_gpu_single_rank.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
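
The Gloo test modules added above are thin wrappers: they subclass the shared ModelParallel* base classes and only pin the backend in setUp. The single-rank bootstrap those bases rely on reduces to the environment-variable and process-group setup seen earlier in this diff. A minimal standalone sketch of that pattern (init_single_rank is an illustrative helper, not a TorchRec API; the real tests pick a free port via get_free_port()):

import os

import torch
import torch.distributed as dist


def init_single_rank(port: int = 29500) -> torch.device:
    # Single-rank "cluster": rank 0 of a world of size 1.
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["LOCAL_WORLD_SIZE"] = "1"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
        backend = "nccl"
    else:
        device = torch.device("cpu")
        backend = "gloo"
    dist.init_process_group(backend=backend)
    return device


# Teardown mirrors setUp: the test bases call dist.destroy_process_group() in tearDown.
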
+ +# pyre-strict + +from torchrec.distributed.test_utils.test_model_parallel_base import ( + ModelParallelSparseOnlyBase, + ModelParallelStateDictBase, +) + +# Single rank GPU tests for Gloo. + + +class ModelParallelStateDictTestGloo(ModelParallelStateDictBase): + def setUp(self, backend: str = "gloo") -> None: + super().setUp(backend=backend) + + +class ModelParallelSparseOnlyTestGloo(ModelParallelSparseOnlyBase): + def setUp(self, backend: str = "gloo") -> None: + super().setUp(backend=backend) diff --git a/torchrec/distributed/tests/test_model_parallel_hierarchical.py b/torchrec/distributed/tests/test_model_parallel_hierarchical.py index c364c12bb..9d2f40be5 100644 --- a/torchrec/distributed/tests/test_model_parallel_hierarchical.py +++ b/torchrec/distributed/tests/test_model_parallel_hierarchical.py @@ -5,15 +5,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + +import os import unittest from typing import Any, Dict, Optional, Tuple, Type import torch +from fbgemm_gpu.split_embedding_configs import EmbOptimType from hypothesis import assume, given, settings, strategies as st, Verbosity from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.fbgemm_qcomm_codec import CommType, QCommsConfig from torchrec.distributed.planner import ParameterConstraints from torchrec.distributed.test_utils.test_model import ( + TestSparseNN, TestTowerCollectionSparseNN, TestTowerSparseNN, ) @@ -21,8 +26,10 @@ from torchrec.distributed.test_utils.test_sharding import ( create_test_sharder, SharderType, + sharding_single_rank_test, ) from torchrec.distributed.types import ShardingType +from torchrec.modules.embedding_configs import PoolingType from torchrec.test_utils import skip_if_asan_class @@ -43,14 +50,13 @@ class ModelParallelHierarchicalTest(ModelParallelTestShared): @given( sharder_type=st.sampled_from( [ - SharderType.EMBEDDING_BAG.value, + # SharderType.EMBEDDING_BAG.value, SharderType.EMBEDDING_BAG_COLLECTION.value, ] ), sharding_type=st.just(ShardingType.TABLE_ROW_WISE.value), kernel_type=st.sampled_from( [ - EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.FUSED.value, ] ), @@ -67,14 +73,15 @@ class ModelParallelHierarchicalTest(ModelParallelTestShared): [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] ), variable_batch_size=st.booleans(), + pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), ) - @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None) def test_sharding_nccl_twrw( self, sharder_type: str, @@ -86,6 +93,7 @@ def test_sharding_nccl_twrw( Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], variable_batch_size: bool, + pooling: PoolingType, ) -> None: # Dense kernels do not have overlapped optimizer behavior yet assume( @@ -96,6 +104,9 @@ def test_sharding_nccl_twrw( sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value or not variable_batch_size ) + # Make sure detail debug will work with non-even collective + os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" + self._test_sharding( # pyre-ignore[6] sharders=[ @@ -105,7 +116,6 @@ def test_sharding_nccl_twrw( kernel_type, qcomms_config=qcomms_config, device=torch.device("cuda"), - variable_batch_size=variable_batch_size, ), ], backend="nccl", @@ -114,6 +124,7 
@@ def test_sharding_nccl_twrw( qcomms_config=qcomms_config, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, variable_batch_size=variable_batch_size, + pooling=pooling, ) @unittest.skipIf( @@ -124,7 +135,7 @@ def test_sharding_nccl_twrw( @given( sharder_type=st.sampled_from( [ - SharderType.EMBEDDING_BAG.value, + # SharderType.EMBEDDING_BAG.value, SharderType.EMBEDDING_BAG_COLLECTION.value, ] ), @@ -135,7 +146,6 @@ def test_sharding_nccl_twrw( ), kernel_type=st.sampled_from( [ - EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.FUSED.value, ] ), @@ -152,13 +162,14 @@ def test_sharding_nccl_twrw( [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] ), + variable_batch_size=st.booleans(), ) - @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None) def test_sharding_nccl_twcw( self, sharder_type: str, @@ -169,12 +180,17 @@ def test_sharding_nccl_twcw( apply_optimizer_in_backward_config: Optional[ Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], + variable_batch_size: bool, ) -> None: # Dense kernels do not have overlapped optimizer behavior yet assume( apply_optimizer_in_backward_config is None or kernel_type != EmbeddingComputeKernel.DENSE.value ) + assume( + sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value + or not variable_batch_size + ) world_size = 4 self._test_sharding( # pyre-ignore[6] @@ -196,6 +212,50 @@ def test_sharding_nccl_twcw( }, qcomms_config=qcomms_config, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs, this test requires at least three GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_ROW_WISE.value, + ] + ), + variable_batch_per_feature=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + def test_sharding_empty_rank( + self, sharding_type: str, variable_batch_per_feature: bool + ) -> None: + self._build_tables_and_groups() + table = self.tables[0] + embedding_groups = {"group_0": table.feature_names} + self._run_multi_process_test( + callable=sharding_single_rank_test, + world_size=4, + local_size=2, + model_class=TestSparseNN, + tables=[table], + embedding_groups=embedding_groups, + sharders=[ + create_test_sharder( + SharderType.EMBEDDING_BAG_COLLECTION.value, + sharding_type, + EmbeddingComputeKernel.FUSED.value, + device=torch.device("cuda"), + ) + ], + optim=EmbOptimType.EXACT_SGD, + backend="nccl", + constraints={table.name: ParameterConstraints(min_partition=4)}, + variable_batch_size=True, + variable_batch_per_feature=variable_batch_per_feature, + weighted_tables=None, ) @unittest.skipIf( @@ -212,7 +272,6 @@ def test_sharding_nccl_twcw( ), kernel_type=st.sampled_from( [ - EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.FUSED.value, ] ), @@ -228,7 +287,7 @@ def test_sharding_nccl_twcw( [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] @@ -282,7 +341,6 @@ def test_embedding_tower_nccl( ), kernel_type=st.sampled_from( [ - EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.FUSED.value, ] ), @@ -298,7 +356,7 @@ def test_embedding_tower_nccl( [ None, { - 
"embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] @@ -337,3 +395,46 @@ def test_embedding_tower_collection_nccl( qcomms_config=qcomms_config, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, ) + + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + local_size=st.sampled_from([2]), + global_constant_batch=st.booleans(), + pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + def test_sharding_variable_batch( + self, + sharding_type: str, + local_size: int, + global_constant_batch: bool, + pooling: PoolingType, + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + SharderType.EMBEDDING_BAG_COLLECTION.value, + sharding_type, + EmbeddingComputeKernel.FUSED.value, + device=torch.device("cuda"), + ), + ], + backend="nccl", + world_size=4, + local_size=local_size, + variable_batch_per_feature=True, + has_weighted_tables=False, + global_constant_batch=global_constant_batch, + pooling=pooling, + ) diff --git a/examples/datasets/__init__.py b/torchrec/distributed/tests/test_model_parallel_nccl.py similarity index 59% rename from examples/datasets/__init__.py rename to torchrec/distributed/tests/test_model_parallel_nccl.py index 75e5c37d1..7e1e480a5 100644 --- a/examples/datasets/__init__.py +++ b/torchrec/distributed/tests/test_model_parallel_nccl.py @@ -5,4 +5,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from . import criteo_dataframes # noqa +# pyre-strict + +from torchrec.distributed.test_utils.test_model_parallel import ModelParallelBase + + +class ModelParallelTestNccl(ModelParallelBase): + pass diff --git a/torchrec/distributed/tests/test_model_parallel_nccl_single_rank.py b/torchrec/distributed/tests/test_model_parallel_nccl_single_rank.py new file mode 100644 index 000000000..0ea359f89 --- /dev/null +++ b/torchrec/distributed/tests/test_model_parallel_nccl_single_rank.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from torchrec.distributed.test_utils.test_model_parallel_base import ( + ModelParallelSparseOnlyBase, + ModelParallelStateDictBase, +) + + +class ModelParallelStateDictTestNccl(ModelParallelStateDictBase): + pass + + +class ModelParallelSparseOnlyTestNccl(ModelParallelSparseOnlyBase): + pass diff --git a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py new file mode 100644 index 000000000..ccff007d1 --- /dev/null +++ b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py @@ -0,0 +1,777 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import unittest +from typing import cast, List, OrderedDict, Union + +import torch +import torch.nn as nn +from fbgemm_gpu.split_embedding_configs import EmbOptimType +from hypothesis import given, settings, strategies as st, Verbosity +from torchrec.distributed.batched_embedding_kernel import ( + KeyValueEmbedding, + KeyValueEmbeddingBag, +) +from torchrec.distributed.embedding_types import ( + EmbeddingComputeKernel, + ShardedEmbeddingTable, +) +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.planner import ParameterConstraints +from torchrec.distributed.test_utils.test_model_parallel_base import ( + ModelParallelSingleRankBase, +) +from torchrec.distributed.test_utils.test_sharding import ( + copy_state_dict, + create_test_sharder, + SharderType, +) +from torchrec.distributed.tests.test_sequence_model import ( + TestEmbeddingCollectionSharder, + TestSequenceSparseNN, +) +from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.modules.embedding_configs import ( + DataType, + EmbeddingBagConfig, + EmbeddingConfig, +) + + +def _load_split_embedding_weights( + emb_module: Union[KeyValueEmbedding, KeyValueEmbeddingBag], + weights: List[torch.Tensor], +) -> None: + """ + Util function to set the weights of SSD TBE. + """ + embedding_tables: List[ShardedEmbeddingTable] = emb_module.config.embedding_tables + + assert len(weights) == len( + embedding_tables + ), "Expect length of weights to be equal to number of embedding tables. " + + cum_sum = 0 + for table_id, (table, weight) in enumerate(zip(embedding_tables, weights)): + # load weights for SSD TBE + height = weight.shape[0] + shard_shape = table.local_rows, table.local_cols + assert shard_shape == weight.shape, "Expect shard shape to match tensor shape." + assert weight.device == torch.device("cpu"), "Weight has to be on CPU." + emb_module.emb_module.ssd_db.set_cuda( + torch.arange(cum_sum, cum_sum + height, dtype=torch.int64), + weight, + torch.as_tensor([height]), + table_id, + ) + cum_sum += height + + +class KeyValueModelParallelTest(ModelParallelSingleRankBase): + def _create_tables(self) -> None: + num_features = 4 + self.tables += [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 1000, + embedding_dim=256, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + + @staticmethod + def _copy_ssd_emb_modules( + m1: DistributedModelParallel, m2: DistributedModelParallel + ) -> None: + """ + Util function to copy and set the SSD TBE modules of two models. It + requires both DMP modules to have the same sharding plan. + """ + for lookup1, lookup2 in zip( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + m1.module.sparse.ebc._lookups, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + m2.module.sparse.ebc._lookups, + ): + for emb_module1, emb_module2 in zip( + lookup1._emb_modules, lookup2._emb_modules + ): + ssd_emb_modules = {KeyValueEmbeddingBag, KeyValueEmbedding} + if type(emb_module1) in ssd_emb_modules: + assert type(emb_module1) is type(emb_module2), ( + "Expect two emb_modules to be of the same type, either both " + "SSDEmbeddingBag or SSDEmbeddingBag." 
+ ) + + emb1_kv = dict( + emb_module1.get_named_split_embedding_weights_snapshot() + ) + for ( + k, + v, + ) in emb_module2.get_named_split_embedding_weights_snapshot(): + v1 = emb1_kv.get(k) + v1_full_tensor = v1.full_tensor() + + # write value into ssd for both emb module for later comparison + v.wrapped.set_range( + 0, 0, v1_full_tensor.size(0), v1_full_tensor + ) + v1.wrapped.set_range( + 0, 0, v1_full_tensor.size(0), v1_full_tensor + ) + + # purge after loading. This is needed, since we pass a batch + # through dmp when instantiating them. + emb_module1.purge() + emb_module2.purge() + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.KEY_VALUE.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + # TODO: uncomment when ssd ckpt support cw sharding + # ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + # TODO: uncomment when ssd ckpt support cw sharding + # ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_ssd_load_state_dict( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + This test checks that if SSD TBE is deterministic. That is, if two SSD + TBEs start with the same state, they would produce the same output. 
+ """ + self._set_table_weights_precision(dtype) + + fused_params = { + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + } + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + sharders = [ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params=fused_params, + ), + ] + + # pyre-ignore + models, batch = self._generate_dmps_and_batch(sharders, constraints=constraints) + m1, m2 = models + + # load state dict for dense modules + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + self._copy_ssd_emb_modules(m1, m2) + + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch, is_deterministic=is_deterministic) + self._compare_models(m1, m2, is_deterministic=is_deterministic) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.KEY_VALUE.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + # TODO: uncomment when ssd ckpt support cw sharding + # ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + # TODO: uncomment when ssd ckpt support cw sharding + # ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_ssd_tbe_numerical_accuracy( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + Make sure it produces same numbers as normal TBE. + """ + self._set_table_weights_precision(dtype) + + base_kernel_type = EmbeddingComputeKernel.FUSED.value + learning_rate = 0.1 + fused_params = { + "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD, + "learning_rate": learning_rate, + "stochastic_rounding": stochastic_rounding, + } + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + fused_sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + base_kernel_type, # base kernel type + fused_params=fused_params, + ), + ), + ] + ssd_sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params=fused_params, + ), + ), + ] + ssd_constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + (fused_model, _), _ = self._generate_dmps_and_batch(fused_sharders) + (ssd_model, _), batch = self._generate_dmps_and_batch( + ssd_sharders, constraints=ssd_constraints + ) + + # load state dict for dense modules + copy_state_dict( + ssd_model.state_dict(), fused_model.state_dict(), exclude_predfix="sparse" + ) + + # for this to work, we expect the order of lookups to be the same + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. 
+ assert len(fused_model.module.sparse.ebc._lookups) == len( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + ssd_model.module.sparse.ebc._lookups + ), "Expect same number of lookups" + + for fused_lookup, ssd_lookup in zip( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + fused_model.module.sparse.ebc._lookups, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + ssd_model.module.sparse.ebc._lookups, + ): + assert len(fused_lookup._emb_modules) == len( + ssd_lookup._emb_modules + ), "Expect same number of emb modules" + for fused_emb_module, ssd_emb_module in zip( + fused_lookup._emb_modules, ssd_lookup._emb_modules + ): + weights = fused_emb_module.split_embedding_weights() + weights = [weight.to("cpu") for weight in weights] + _load_split_embedding_weights(ssd_emb_module, weights) + + # purge after loading. This is needed, since we pass a batch + # through dmp when instantiating them. + ssd_emb_module.purge() + + if is_training: + self._train_models(fused_model, ssd_model, batch) + self._eval_models( + fused_model, ssd_model, batch, is_deterministic=is_deterministic + ) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.KEY_VALUE.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + # TODO: uncomment when ssd ckpt support cw sharding + # ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + # TODO: uncomment when ssd ckpt support cw sharding + # ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_ssd_fused_optimizer( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + Purpose of this test is to make sure it works with warm up policy. 
+ """ + self._set_table_weights_precision(dtype) + + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + + constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + + base_sharders = [ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params={ + "learning_rate": 0.2, + "stochastic_rounding": stochastic_rounding, + }, + ), + ] + models, batch = self._generate_dmps_and_batch( + base_sharders, # pyre-ignore + constraints=constraints, + ) + base_model, _ = models + + test_sharders = [ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params={ + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + }, + ), + ] + models, _ = self._generate_dmps_and_batch( + test_sharders, # pyre-ignore + constraints=constraints, + ) + test_model, _ = models + + # load state dict for dense modules + test_model.load_state_dict( + cast("OrderedDict[str, torch.Tensor]", base_model.state_dict()) + ) + self._copy_ssd_emb_modules(base_model, test_model) + + self._eval_models( + base_model, test_model, batch, is_deterministic=is_deterministic + ) + + # change learning rate for test_model + fused_opt = test_model.fused_optimizer + # pyre-ignore + fused_opt.param_groups[0]["lr"] = 0.2 + fused_opt.zero_grad() + + if is_training: + self._train_models(base_model, test_model, batch) + self._eval_models( + base_model, test_model, batch, is_deterministic=is_deterministic + ) + self._compare_models(base_model, test_model, is_deterministic=is_deterministic) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.KEY_VALUE.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + # TODO: uncomment when ssd ckpt support cw sharding + # ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + # TODO: uncomment when ssd ckpt support cw sharding + # ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + fused_first=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_ssd_mixed_kernels( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + fused_first: bool, + ) -> None: + """ + Purpose of this test is to make sure it works with warm up policy. 
+ """ + self._set_table_weights_precision(dtype) + + base_kernel_type = EmbeddingComputeKernel.FUSED.value + + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + + constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=( + [base_kernel_type] if i % 2 == fused_first else [kernel_type] + ), + ) + for i, table in enumerate(self.tables) + } + + fused_params = { + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + } + sharders = [ + EmbeddingBagCollectionSharder(fused_params=fused_params), + ] + + # pyre-ignore + models, batch = self._generate_dmps_and_batch(sharders, constraints=constraints) + m1, m2 = models + + # load state dict for dense modules + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + self._copy_ssd_emb_modules(m1, m2) + + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch, is_deterministic=is_deterministic) + self._compare_models(m1, m2, is_deterministic=is_deterministic) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.KEY_VALUE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + table_wise_first=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_ssd_mixed_sharding_types( + self, + sharder_type: str, + kernel_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + table_wise_first: bool, + ) -> None: + """ + Purpose of this test is to make sure it works with warm up policy. 
+ """ + self._set_table_weights_precision(dtype) + + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + + constraints = { + table.name: ParameterConstraints( + sharding_types=( + [ShardingType.TABLE_WISE.value] + if i % 2 == table_wise_first + else [ShardingType.ROW_WISE.value] + ), + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + + fused_params = { + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + } + sharders = [ + EmbeddingBagCollectionSharder(fused_params=fused_params), + ] + + # pyre-ignore + models, batch = self._generate_dmps_and_batch(sharders, constraints=constraints) + m1, m2 = models + + # load state dict for dense modules + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + self._copy_ssd_emb_modules(m1, m2) + + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch, is_deterministic=is_deterministic) + self._compare_models(m1, m2, is_deterministic=is_deterministic) + + +class KeyValueSequenceModelParallelStateDictTest(ModelParallelSingleRankBase): + def setUp(self, backend: str = "nccl") -> None: + self.shared_features = [] + self.embedding_groups = {} + + super().setUp(backend=backend) + + def _create_tables(self) -> None: + num_features = 4 + shared_features = 2 + + initial_tables = [ + EmbeddingConfig( + num_embeddings=(i + 1) * 1000, + embedding_dim=16, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + + shared_features_tables = [ + EmbeddingConfig( + num_embeddings=(i + 1) * 11, + embedding_dim=16, + name="table_" + str(i + num_features), + feature_names=["feature_" + str(i)], + ) + for i in range(shared_features) + ] + + self.tables += initial_tables + shared_features_tables + self.shared_features += [f"feature_{i}" for i in range(shared_features)] + + self.embedding_groups["group_0"] = [ + (f"{feature}@{table.name}" if feature in self.shared_features else feature) + for table in self.tables + for feature in table.feature_names + ] + + def _create_model(self) -> nn.Module: + return TestSequenceSparseNN( + tables=self.tables, + num_float_features=self.num_float_features, + embedding_groups=self.embedding_groups, + dense_device=self.device, + sparse_device=torch.device("meta"), + ) + + @staticmethod + def _copy_ssd_emb_modules( + m1: DistributedModelParallel, m2: DistributedModelParallel + ) -> None: + """ + Util function to copy and set the SSD TBE modules of two models. It + requires both DMP modules to have the same sharding plan. + """ + for lookup1, lookup2 in zip( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. + m1.module.sparse.ec._lookups, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. + m2.module.sparse.ec._lookups, + ): + for emb_module1, emb_module2 in zip( + lookup1._emb_modules, lookup2._emb_modules + ): + ssd_emb_modules = {KeyValueEmbeddingBag, KeyValueEmbedding} + if type(emb_module1) in ssd_emb_modules: + assert type(emb_module1) is type(emb_module2), ( + "Expect two emb_modules to be of the same type, either both " + "SSDEmbeddingBag or SSDEmbeddingBag." + ) + + weights = emb_module1.emb_module.debug_split_embedding_weights() + # need to set emb_module1 as well, since otherwise emb_module1 would + # produce a random debug_split_embedding_weights everytime + _load_split_embedding_weights(emb_module1, weights) + _load_split_embedding_weights(emb_module2, weights) + + # purge after loading. 
This is needed, since we pass a batch + # through dmp when instantiating them. + emb_module1.purge() + emb_module2.purge() + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + # TODO: uncomment when ssd ckpt support cw sharding + # ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.KEY_VALUE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_ssd_load_state_dict( + self, + sharding_type: str, + kernel_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + This test checks that if SSD TBE is deterministic. That is, if two SSD + TBEs start with the same state, they would produce the same output. + """ + self._set_table_weights_precision(dtype) + + fused_params = { + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + } + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + sharders = [ + cast( + ModuleSharder[nn.Module], + TestEmbeddingCollectionSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + fused_params=fused_params, + ), + ), + ] + + constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + + models, batch = self._generate_dmps_and_batch(sharders, constraints=constraints) + m1, m2 = models + + # load state dict for dense modules + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + self._copy_ssd_emb_modules(m1, m2) + + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch, is_deterministic=is_deterministic) + self._compare_models(m1, m2, is_deterministic=is_deterministic) + + +# TODO: remove after development is done +def main() -> None: + unittest.main() + + +if __name__ == "__main__": + main() diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py new file mode 100644 index 000000000..e1738b598 --- /dev/null +++ b/torchrec/distributed/tests/test_pt2.py @@ -0,0 +1,1050 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
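
The test_pt2.py file added below repeatedly checks torch.compile output (and gradients) against eager execution for KJT-based modules, with capture_scalar_outputs and capture_dynamic_output_shape_ops enabled to handle data-dependent shapes. A condensed, framework-free sketch of that forward/backward comparison, using a toy function in place of the TorchRec models:

import torch


def toy_fn(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the modules compiled in these tests.
    return (x * x).sum()


eager_x = torch.randn(4, requires_grad=True)
compiled_x = eager_x.detach().clone().requires_grad_(True)

eager_out = toy_fn(eager_x)
eager_out.backward()

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
compiled_fn = torch.compile(toy_fn, backend="eager", fullgraph=True)
compiled_out = compiled_fn(compiled_x)
compiled_out.backward()

# Both the forward output and the input gradients should match eager execution.
torch.testing.assert_close(compiled_out, eager_out)
torch.testing.assert_close(compiled_x.grad, eager_x.grad)
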
+ +# pyre-ignore-all-errors + +import copy +import itertools +import sys +import unittest +from enum import auto, Enum +from typing import Any, Dict, List, Tuple + +import torch +from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings +from fbgemm_gpu.permute_pooled_embedding_modules_split import ( + PermutePooledEmbeddingsSplit, +) +from fbgemm_gpu.split_embedding_utils import get_table_batched_offsets_from_dense +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + ComputeDevice, + EmbeddingLocation, + SplitTableBatchedEmbeddingBagsCodegen, +) +from hypothesis import given, settings, strategies as st +from torch._dynamo.testing import reduce_to_scalar_loss +from torchrec.distributed.test_utils.infer_utils import ( + KJTInputExportDynamicShapeWrapper, + KJTInputExportWrapperWithStrides, + TestQuantFPEBCSharder, +) +from torchrec.pt2.utils import ( + deregister_fake_classes, + kjt_for_pt2_tracing, + register_fake_classes, +) +from torchrec.sparse.jagged_tensor import _kt_regroup_arguments + +try: + # pyre-ignore + from caffe2.test.inductor.test_aot_inductor import AOTIRunnerUtil +except (unittest.SkipTest, ImportError): + if __name__ == "__main__": + sys.exit(0) + + +from fbgemm_gpu import sparse_ops # noqa: F401, E402 +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.fused_params import FUSED_PARAM_BOUNDS_CHECK_MODE +from torchrec.distributed.shard import _shard_modules +from torchrec.distributed.test_utils.infer_utils import ( + assert_close, + create_test_model_ebc_only, + KJTInputExportWrapper, + prep_inputs, + replace_registered_tbes_with_mock_tbes, + replace_sharded_quant_modules_tbes_with_mock_tbes, + TestQuantEBCSharder, +) +from torchrec.distributed.types import BoundsCheckMode, ShardingEnv, ShardingType +from torchrec.sparse.jagged_tensor import ( + ComputeKJTToJTDict, + JaggedTensor, + KeyedJaggedTensor, + KeyedTensor, +) + + +def make_kjt( + values: List[int], lengths: List[int], device: str = "cpu" +) -> KeyedJaggedTensor: + values_tensor = torch.tensor(values, dtype=torch.int32, device=device) + lengths_tensor = torch.tensor(lengths, dtype=torch.int32, device=device) + weights_tensor = torch.randn(len(values), dtype=torch.float32, device=device) + torch._check(torch.sum(lengths_tensor).item() == values_tensor.size(0)) + kjt = KeyedJaggedTensor( + keys=[f"key{i}" for i in range(len(lengths))], + values=values_tensor, + lengths=lengths_tensor, + weights=weights_tensor, + ) + return kjt + + +def kjt_module_kjt_inputs_with_strides(kjt: KeyedJaggedTensor) -> Tuple: + return ( + kjt._values, + kjt._lengths, + kjt._stride_per_key_per_rank, + ) + + +def _sharded_quant_ebc_model( + local_device: str = "cuda", + compute_device: str = "cuda", + feature_processor: bool = False, +) -> Tuple[torch.nn.Module, List[KeyedJaggedTensor]]: + num_embeddings = 256 + emb_dim = 12 + world_size = 2 + batch_size = 4 + + local_device = torch.device(local_device) + mi = create_test_model_ebc_only( + num_embeddings, + emb_dim, + world_size, + batch_size, + num_features=2, + num_weighted_features=1, + dense_device=local_device, + sparse_device=local_device, + quant_state_dict_split_scale_bias=True, + compute_device=compute_device, + feature_processor=feature_processor, + ) + input_kjts = [ + inp.to(local_device).idlist_features + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + + sharding_type: ShardingType = ShardingType.TABLE_WISE + + fused_params = { + FUSED_PARAM_BOUNDS_CHECK_MODE: 
BoundsCheckMode.NONE, + } + if feature_processor: + sharder = TestQuantFPEBCSharder( + sharding_type=sharding_type.value, + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in mi.tables], + fused_params=fused_params, + ) + else: + sharder = TestQuantEBCSharder( + sharding_type=sharding_type.value, + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in mi.tables], + fused_params=fused_params, + ) + # pyre-ignore + plan = mi.planner.plan( + mi.quant_model, + [sharder], + ) + + sharded_model = _shard_modules( + module=mi.quant_model, + # pyre-ignore + sharders=[sharder], + # Always shard on meta + device=torch.device("meta"), + plan=plan, + # pyre-ignore + env=ShardingEnv.from_local(world_size=mi.topology.world_size, rank=0), + ) + + model: torch.nn.Module = KJTInputExportWrapper(sharded_model, input_kjts[0].keys()) + return model, input_kjts + + +class _TestType(Enum): + EXPORT = auto() + DYNAMO_COMPILE = auto() + + +# pyre-ignore +def _copy_input_tensors(t, device): + if isinstance(t, torch.Tensor): + ret = t.detach().clone().to(device) + if ret.dtype in [torch.float, torch.double]: + ret.requires_grad = True + ret.retain_grad() + return ret + elif isinstance(t, (list, tuple)): + return [_copy_input_tensors(_t, device) for _t in t] + elif isinstance(t, int): + return t + else: + raise ValueError(f"Unsupported type {type(t)}") + + +# pyre-ignore +def _grad_detach_clone(t): + if isinstance(t, torch.Tensor): + # pyre-ignore + if t.grad is None: + return None + return t.grad.detach().clone() + elif isinstance(t, (list, tuple)): + return [_grad_detach_clone(_t) for _t in t] + elif isinstance(t, int): + return t + else: + raise ValueError(f"Unsupported type {type(t)}") + + +# pyre-ignore +def _assert_close(actual, expected) -> None: + if actual is None and expected is None: + return + + if isinstance(expected, torch.Tensor): + assert isinstance(actual, torch.Tensor) + torch.testing.assert_close(actual, expected, atol=1e-3, rtol=1e-3) + elif isinstance(expected, (list, tuple)): + assert type(expected) is type(actual) + for _a, _e in zip(actual, expected): + _assert_close(_a, _e) + elif isinstance(expected, int): + assert type(expected) is type(actual) + assert expected == actual + else: + raise ValueError(f"Unsupported type {type(expected)}") + + +def _test_compile_fwd_bwd( + fn, + inp, + device: torch.device, + unpack_inp: bool = False, + backend: str = "inductor", + fullgraph: bool = True, + skip_backward: bool = False, + *args, + **kwargs, +): + eager_input = _copy_input_tensors(inp, device) + compile_input = _copy_input_tensors(inp, device) + + if unpack_inp: + eager_out = fn(*eager_input, *args, **kwargs) + else: + eager_out = fn(eager_input, *args, **kwargs) + + if not skip_backward: + eager_loss = reduce_to_scalar_loss(eager_out) + eager_loss.backward() + eager_bwd_out = _grad_detach_clone(eager_input) + + with unittest.mock.patch( + "torch._dynamo.config.skip_torchrec", + False, + ): + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + + if unpack_inp: + compile_out = torch.compile(fn, backend=backend, fullgraph=fullgraph)( + *compile_input + ) + else: + compile_out = torch.compile(fn, backend=backend, fullgraph=fullgraph)( + compile_input + ) + + if not skip_backward: + reduce_to_scalar_loss(compile_out).backward() + compile_bwd_out = _grad_detach_clone(compile_input) + + _assert_close(compile_out, eager_out) + if not skip_backward: + 
_assert_close(compile_bwd_out, eager_bwd_out) + + +class TestPt2(unittest.TestCase): + def setUp(self): + super().setUp() + register_fake_classes() + + def tearDown(self): + deregister_fake_classes() + super().tearDown() + + def _test_kjt_input_module( + self, + kjt_input_module: torch.nn.Module, + kjt: KeyedJaggedTensor, + inputs: Tuple[Any], + test_dynamo: bool = True, + test_aot_inductor: bool = True, + test_pt2_ir_export: bool = False, + ) -> None: + with unittest.mock.patch( + "torch._dynamo.config.skip_torchrec", + False, + ): + EM: torch.nn.Module = KJTInputExportWrapper(kjt_input_module, kjt.keys()) + em_inputs = (kjt.values(), kjt.lengths(), kjt.weights_or_none(), *inputs) + eager_output = EM(*em_inputs) + if test_dynamo: + x = torch._dynamo.export(EM, same_signature=True)(*em_inputs) + + export_gm = x.graph_module + export_gm_output = export_gm(*em_inputs) + + assert_close(eager_output, export_gm_output) + + if test_aot_inductor: + # pyre-ignore + so_path: str = AOTIRunnerUtil.compile( + EM, + inputs, + ) + device = "cuda" + # pyre-ignore + aot_inductor_module = AOTIRunnerUtil.load(device, so_path) + aot_actual_output = aot_inductor_module(*em_inputs) + assert_close(eager_output, aot_actual_output) + + if test_pt2_ir_export: + symint_wrapper = KJTInputExportDynamicShapeWrapper(EM) + + # KJTInputExportDynamicShapeWrapper represents sizes of values/weights + # from first element of values/weights respectively (simulate symint) + # Need to set as size in order to run a proper forward + em_inputs[0][0] = kjt.values().size(0) + em_inputs[2][0] = kjt.weights().size(0) + + if not kjt.values().is_meta: + eager_output = symint_wrapper(*em_inputs) + + pt2_ir = torch.export.export( + symint_wrapper, em_inputs, {}, strict=False + ) + + pt2_ir_output = pt2_ir.module()(*em_inputs) + assert_close(eager_output, pt2_ir_output) + + # Separate test for Dynamo, as it fallbacks on VB path. + # Torchrec has lazy init modules, depending on the first input => we need to run eager with tracing inputs. + # But other test cases do not need to go VB. 
+ def _test_kjt_input_module_dynamo_compile( + self, + kjt_input_module: torch.nn.Module, + kjt_keys: List[str], + # pyre-ignore + inputs, + backend: str = "eager", + ) -> None: + with unittest.mock.patch( + "torch._dynamo.config.skip_torchrec", + False, + ): + EM: torch.nn.Module = KJTInputExportWrapperWithStrides( + kjt_input_module, kjt_keys + ) + eager_output = EM(*inputs) + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + + dynamo_eager_out = torch.compile(EM, backend=backend, fullgraph=True)( + *inputs + ) + assert_close(eager_output, dynamo_eager_out) + + @given( + test_type_backend=st.sampled_from( + [(_TestType.EXPORT, ""), (_TestType.DYNAMO_COMPILE, "aot_eager")] + ) + ) + @settings(deadline=None) + def test_kjt_split(self, test_type_backend: Tuple[_TestType, str]) -> None: + test_type, backend = test_type_backend + + class M(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor): + return kjt.split([1, 2, 1]) + + kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) + if test_type == _TestType.EXPORT: + self._test_kjt_input_module( + M(), + kjt, + (), + test_aot_inductor=False, + test_dynamo=False, + test_pt2_ir_export=True, + ) + elif test_type == _TestType.DYNAMO_COMPILE: + self._test_kjt_input_module_dynamo_compile( + M(), + kjt.keys(), + kjt_module_kjt_inputs_with_strides(kjt_for_pt2_tracing(kjt)), + backend=backend, + ) + + @given( + test_type_backend=st.sampled_from( + [(_TestType.EXPORT, ""), (_TestType.DYNAMO_COMPILE, "aot_eager")] + ) + ) + @settings(deadline=None) + def test_kjt_permute(self, test_type_backend: Tuple[_TestType, str]) -> None: + test_type, backend = test_type_backend + + class M(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor, indices: List[int]): + return kjt.permute(indices) + + kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) + indices: List[int] = [1, 0, 3, 2] + + if test_type == _TestType.EXPORT: + self._test_kjt_input_module( + M(), + kjt, + (indices,), + test_aot_inductor=False, + test_pt2_ir_export=True, + ) + elif test_type == _TestType.DYNAMO_COMPILE: + + def inputs_fn(kjt): + return *kjt_module_kjt_inputs_with_strides(kjt), indices + + self._test_kjt_input_module_dynamo_compile( + M(), + kjt.keys(), + inputs_fn(kjt_for_pt2_tracing(kjt)), + backend=backend, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_kt_regroup_as_dict( + self, + ) -> None: + + class M(torch.nn.Module): + def forward(self, inputs: List[KeyedTensor]) -> Dict[str, torch.Tensor]: + groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]] + keys = ["group_0", "group_1"] + return KeyedTensor.regroup_as_dict(inputs, groups, keys) + + m = M() + + key_dim = 1 + tensor_list_1 = [torch.randn(2, 3) for i in range(3)] + keys_1 = ["dense_0", "dense_1", "dense_2"] + kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim) + tensor_list_2 = [torch.randn(2, 3) for i in range(2)] + keys_2 = ["sparse_0", "sparse_1"] + kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim) + inputs = [kt_1, kt_2] + + for t in itertools.chain(tensor_list_1, tensor_list_2): + torch._dynamo.decorators.mark_dynamic(t, 0) + torch._dynamo.decorators.mark_dynamic(t, 1) + + eager_output = m(inputs) + with unittest.mock.patch( + "torch._dynamo.config.skip_torchrec", + False, + ): + torch_compile_backend = "eager" + + torch._dynamo.config.capture_scalar_outputs = True + 
torch._dynamo.config.capture_dynamic_output_shape_ops = True + opt_fn = torch.compile( + m, backend=torch_compile_backend, fullgraph=True, dynamic=True + ) + compile_output = opt_fn(inputs) + torch.testing.assert_close(eager_output, compile_output) + + @unittest.skipIf( + torch.cuda.device_count() < 1, + "Not enough GPUs, this test requires at least one GPU", + ) + def test_kjt_permute_dynamo_compile(self) -> None: + class M(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor, indices: List[int]): + return kjt.permute(indices) + + device = "cuda" + kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1], device=device) + indices: List[int] = [1, 0, 3, 2] + # pyre-ignore + inputs_fn = lambda kjt: ( + *kjt_module_kjt_inputs_with_strides(kjt), + indices, + ) + self._test_kjt_input_module_dynamo_compile( + M(), + kjt.keys(), + inputs_fn(kjt_for_pt2_tracing(kjt)), + backend="inductor", + ) + + def test_kjt_length_per_key(self) -> None: + class M(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor): + return kjt.length_per_key() + + kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) + + self._test_kjt_input_module( + M(), + kjt, + (), + test_aot_inductor=False, + test_pt2_ir_export=True, + ) + + def test_kjt_length_per_key_meta(self) -> None: + class M(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor): + return kjt.length_per_key() + + kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) + kjt = kjt.to("meta") + + # calling forward on meta inputs once traced should error out + # as calculating length_per_key requires a .tolist() call of lengths + self.assertRaisesRegex( + RuntimeError, + r".*Tensor\.item\(\) cannot be called on meta tensors.*", + lambda: self._test_kjt_input_module( + M(), + kjt, + (), + test_aot_inductor=False, + test_pt2_ir_export=True, + ), + ) + + def test_kjt_offset_per_key(self) -> None: + class M(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor): + return kjt.offset_per_key() + + kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) + + self._test_kjt_input_module( + M(), + kjt, + (), + test_aot_inductor=False, + test_pt2_ir_export=True, + ) + + @given( + test_type_backend=st.sampled_from( + [(_TestType.EXPORT, ""), (_TestType.DYNAMO_COMPILE, "aot_eager")] + ) + ) + @settings(deadline=None) + def test_kjt__getitem__(self, test_type_backend: Tuple[_TestType, str]) -> None: + test_type, backend = test_type_backend + + class M(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor): + out0 = kjt["key0"] + out1 = kjt["key1"] + + return out0, out1 + + # First element represents symint for values and weights shape + kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) + + if test_type == _TestType.EXPORT: + self._test_kjt_input_module( + M(), + kjt, + (), + test_dynamo=False, + test_aot_inductor=False, + test_pt2_ir_export=True, + ) + elif test_type == _TestType.DYNAMO_COMPILE: + self._test_kjt_input_module_dynamo_compile( + M(), + kjt.keys(), + kjt_module_kjt_inputs_with_strides(kjt_for_pt2_tracing(kjt)), + backend=backend, + ) + + def test_kjt_to_dict_with_strides_dynamo(self) -> None: + class M(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor) -> Dict[str, JaggedTensor]: + return kjt.to_dict() + + kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) + + self._test_kjt_input_module_dynamo_compile( + M(), + kjt.keys(), + kjt_module_kjt_inputs_with_strides(kjt_for_pt2_tracing(kjt)), + ) + + # pyre-ignores + @unittest.skipIf( + True or 
torch.cuda.device_count() <= 1, + "Test fails all the time, skip it for now\n Not enough GPUs available", + ) + def test_sharded_quant_ebc_dynamo_export_aot_inductor(self) -> None: + sharded_model, input_kjts = _sharded_quant_ebc_model() + kjt = input_kjts[0] + sharded_model(kjt.values(), kjt.lengths()) + + model: torch.nn.Module = sharded_model + model.training = False + replace_registered_tbes_with_mock_tbes(model) + replace_sharded_quant_modules_tbes_with_mock_tbes(model) + + example_inputs = (kjt.values(), kjt.lengths()) + + # pyre-ignore + def kjt_to_inputs(kjt): + return (kjt.values(), kjt.lengths()) + + expected_outputs = [model(*kjt_to_inputs(kjt)) for kjt in input_kjts[1:]] + + device: str = "cuda" + + with unittest.mock.patch( + "torch._dynamo.config.skip_torchrec", + False, + ): + tracing_values = kjt.values() + tracing_lengths = kjt.lengths() + torch._dynamo.mark_dynamic(tracing_values, 0) + dynamo_gm, guard = torch._dynamo.export(model, same_signature=False)( + tracing_values, tracing_lengths + ) + dynamo_gm.print_readable() + dynamo_actual_outputs = [ # noqa + dynamo_gm(*kjt_to_inputs(kjt)) for kjt in input_kjts[1:] + ] + # TODO(ivankobzarev): Why dynamo outputs are different than expected, but aot outputs are correct. + # assert_close(expected_outputs, dynamo_actual_outputs) + + # pyre-ignore + so_path: str = AOTIRunnerUtil.compile( + model, + example_inputs, + ) + # pyre-ignore + aot_inductor_module = AOTIRunnerUtil.load(device, so_path) + aot_inductor_module(*example_inputs) + + aot_actual_outputs = [ + aot_inductor_module(*kjt_to_inputs(kjt)) for kjt in input_kjts[1:] + ] + assert_close(expected_outputs, aot_actual_outputs) + + def test_sharded_quant_ebc_non_strict_export(self) -> None: + sharded_model, input_kjts = _sharded_quant_ebc_model( + local_device="cpu", compute_device="cpu" + ) + kjt = input_kjts[0] + sharded_model(kjt.values(), kjt.lengths()) + + from torch.export import _trace + + ep = _trace._export( + sharded_model, + ( + kjt.values(), + kjt.lengths(), + ), + {}, + strict=False, + pre_dispatch=True, + ) + + ep.module()(kjt.values(), kjt.lengths()) + + # PT2 IR autofunctionalizes mutation funcs (bounds_check_indices) + # ensure such node isn't present, as it causes issues with IR + for n in ep.graph_module.graph.nodes: + self.assertFalse("auto_functionalized" in str(n.name)) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + def test_sharded_quant_fpebc_non_strict_export(self) -> None: + sharded_model, input_kjts = _sharded_quant_ebc_model( + local_device="cpu", compute_device="cpu", feature_processor=True + ) + kjt = input_kjts[0] + + sharded_model(kjt.values(), kjt.lengths()) + + from torch.export import _trace + + ep = _trace._export( + sharded_model, + ( + kjt.values(), + kjt.lengths(), + ), + {}, + strict=False, + pre_dispatch=True, + ) + ep.module()(kjt.values(), kjt.lengths()) + + # PT2 IR autofunctionalizes mutation funcs (bounds_check_indices) + # ensure such node isn't present, as it causes issues with IR + for n in ep.graph_module.graph.nodes: + self.assertFalse("auto_functionalized" in str(n.name)) + + def test_maybe_compute_kjt_to_jt_dict(self) -> None: + kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) + self._test_kjt_input_module( + ComputeKJTToJTDict(), + kjt, + (), + # TODO: turn on AOT Inductor test once the support is ready + test_aot_inductor=False, + ) + + def test_kjt_values_specialization(self): + with unittest.mock.patch( + "torch._dynamo.config.skip_torchrec", 
+ False, + ): + from torch._dynamo.testing import CompileCounter + + kjt0 = KeyedJaggedTensor( + values=torch.tensor([3, 4, 5, 6, 7, 8], dtype=torch.int64), + keys=["f0", "f1", "f2"], + lengths=torch.tensor([0, 0, 1, 1, 2, 2]), + stride=2, + ) + torch._dynamo.decorators.mark_unbacked(kjt0._values, 0) + + counter = CompileCounter() + + @torch._dynamo.optimize(counter, nopython=True) + def f(kjt): + l: List[KeyedJaggedTensor] = kjt.split([1, 1, 1]) + return l[0].values().sum() + l[1].values().sum() + l[2].values().sum() + + f(kjt0) + self.assertEqual(counter.frame_count, 1) + + kjt1 = KeyedJaggedTensor( + values=torch.tensor([], dtype=torch.int64), + keys=["f0", "f1", "f2"], + lengths=torch.tensor([0, 0, 0, 0, 0, 0]), + stride=2, + ) + torch._dynamo.decorators.mark_unbacked(kjt1._values, 0) + f(kjt1) + self.assertEqual(counter.frame_count, 1) + + def test_kjt_values_specialization_utils(self): + with unittest.mock.patch( + "torch._dynamo.config.skip_torchrec", + False, + ): + from torch._dynamo.testing import CompileCounter + + kjt0 = KeyedJaggedTensor( + values=torch.tensor([3, 4, 5, 6, 7, 8], dtype=torch.int64), + keys=["f0", "f1", "f2"], + lengths=torch.tensor([0, 0, 1, 1, 2, 2]), + stride=2, + ).sync() + + counter = CompileCounter() + + @torch._dynamo.optimize(counter, nopython=True) + def f(kjt): + l: List[KeyedJaggedTensor] = kjt.split([1, 1, 1]) + return l[0].values().sum() + l[1].values().sum() + l[2].values().sum() + + f(kjt_for_pt2_tracing(kjt0)) + self.assertEqual(counter.frame_count, 1) + + kjt1 = KeyedJaggedTensor( + values=torch.tensor([], dtype=torch.int64), + keys=["f0", "f1", "f2"], + lengths=torch.tensor([0, 0, 0, 0, 0, 0]), + stride=2, + ).sync() + f(kjt_for_pt2_tracing(kjt1)) + self.assertEqual(counter.frame_count, 1) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + def test_ebc_vb_reindex(self) -> None: + device = "cuda" + + def fn( + embs: torch.Tensor, + indices: torch.Tensor, + input_num_indices: List[int], + input_rows: List[int], + input_columns: List[int], + ): + reindex_output = torch.ops.fbgemm.batch_index_select_dim0_tensor( + inputs=embs, + indices=indices.view(-1), + input_num_indices=torch.tensor(input_num_indices, dtype=torch.int64), + input_rows=torch.tensor(input_rows, dtype=torch.int64), + input_columns=torch.tensor(input_columns, dtype=torch.int64), + permute_output_dim_0_1=True, + ) + return reindex_output + + N = 5 + batch_size = 10 + emb_dim = 12 + embs: torch.Tensor = torch.randn( + [N * batch_size * emb_dim], device=device, requires_grad=True + ) + torch._dynamo.mark_dynamic(embs, 0) + input_num_indices = [batch_size] * N + input_rows = [batch_size] * N + input_columns = [emb_dim] * N + indices: torch.Tensor = ( + torch.arange(batch_size) + .expand(N, batch_size) + .contiguous() + .to(device=device) + ) + + ins = (embs, indices, input_num_indices, input_rows, input_columns) + _test_compile_fwd_bwd(fn, ins, device, unpack_inp=True) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_permute_pooled_embs(self) -> None: + device = "cuda" + m = PermutePooledEmbeddings( + embs_dims=[12, 12, 12], + permute=[2, 1, 0], + ) + inp = torch.randn(12, 3) + _test_compile_fwd_bwd(m, inp, device, backend="aot_eager") + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_permute_pooled_embs_split(self) -> None: + device = "cuda" + m = PermutePooledEmbeddingsSplit( + 
embs_dims=[12, 12, 12], + permute=[2, 1, 0], + ) + inp = torch.randn(12, 3) + _test_compile_fwd_bwd(m, inp, device) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_permute_multi_embedding(self) -> None: + device = "cuda" + batch_size = 16 + + def func(values, permutes, in_shapes, out_shapes, out_lengths): + return torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_shapes, out_shapes, out_lengths.tolist() + ) + + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [torch.randn(batch_size, sum(L), device=device) for L in lengths] + for embs in values: + torch._dynamo.mark_dynamic(embs, 0) + torch._dynamo.mark_dynamic(embs, 1) + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + values[0], keys, lengths, groups + ) + out_lengths = torch.tensor(out_lengths, device=device, dtype=torch.int32) + inp = (values, permutes, in_shapes, out_shapes, out_lengths) + _test_compile_fwd_bwd(func, inp, device, unpack_inp=True) + + @unittest.skipIf( + torch.cuda.device_count() < 1, + "Not enough GPUs, this test requires at least one GPU", + ) + def test_tbe_compile(self) -> None: + D = 4 + T = 2 + E = 10 + Ds = [D] * T + Es = [E] * T + + device = "cuda" + tbe = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + E, + D, + ( + EmbeddingLocation.MANAGED + if device == "cuda" + else EmbeddingLocation.HOST + ), + ComputeDevice.CUDA if device == "cuda" else ComputeDevice.CPU, + ) + for (E, D) in zip(Es, Ds) + ], + ) + tbe.init_embedding_weights_uniform(0, 1) + + class M(torch.nn.Module): + def __init__(self, tbe) -> None: + super().__init__() + self.tbe = tbe + + def forward(self, indices, offsets, f) -> torch.Tensor: + e = self.tbe(indices, offsets) + return torch.mul(torch.mean(e, dim=1), f) + + m = M(tbe) + m.train(True) + m_compile = copy.deepcopy(m) + m_compile.train(True) + + def get_weights(m): + return m.tbe.weights_uvm.clone().detach() + + original_weights = get_weights(m) + + x = torch.Tensor( + [ + [ + [1], + [1], + ], + [[3], [4]], + ] + ).to(dtype=torch.int64, device=device) + (indices, offsets) = get_table_batched_offsets_from_dense( + x, use_cpu=device == "cpu" + ) + inp_f = torch.randn(T, requires_grad=True, device=device) + + # EAGER + out = m(indices, offsets, inp_f.clone()) + reduce_to_scalar_loss(out).backward() + eager_weights_diff = get_weights(m) - original_weights + + # COMPILE + orig_compile_weights = get_weights(m_compile) + with unittest.mock.patch( + "torch._dynamo.config.skip_torchrec", + False, + ): + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + + compile_out = torch.compile(m_compile, backend="aot_eager", fullgraph=True)( + indices, offsets, inp_f.clone() + ) + reduce_to_scalar_loss(compile_out).backward() + compile_weights_diff = get_weights(m_compile) - orig_compile_weights + + assert_close(eager_weights_diff, compile_weights_diff) + + @unittest.skipIf( + torch.cuda.device_count() < 1, + "Not enough GPUs, this test requires at least one GPU", + ) + def test_tbe_compile_vb(self) -> None: + D = 4 + T = 2 + E = 10 + Ds = [D] * T + Es = [E] * T + + device = "cuda" + tbe = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + E, + D, + ( + EmbeddingLocation.MANAGED + if device == "cuda" + else EmbeddingLocation.HOST + ), + ComputeDevice.CUDA if device == "cuda" else 
ComputeDevice.CPU, + ) + for (E, D) in zip(Es, Ds) + ], + ) + tbe.init_embedding_weights_uniform(0, 1) + + class M(torch.nn.Module): + def __init__(self, tbe) -> None: + super().__init__() + self.tbe = tbe + + def forward( + self, indices, offsets, batch_size_per_feature_per_rank, f + ) -> torch.Tensor: + e = self.tbe( + indices, + offsets, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + return torch.mul(torch.mean(e, dim=0), f) + + m = M(tbe) + m.train(True) + m_compile = copy.deepcopy(m) + m_compile.train(True) + + def get_weights(m): + return m.tbe.weights_uvm.clone().detach() + + original_weights = get_weights(m) + + indices = torch.Tensor([1, 2, 0, 1, 2]).to(dtype=torch.int64, device=device) + lengths = torch.Tensor([2, 3]).to(dtype=torch.int64, device=device) + offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + batch_size_per_feature_per_rank = [[1], [2]] + inp_f = torch.randn(1, requires_grad=True, device=device) + + # EAGER + out = m(indices, offsets, batch_size_per_feature_per_rank, inp_f.clone()) + reduce_to_scalar_loss(out).backward() + eager_weights_diff = get_weights(m) - original_weights + + # COMPILE + orig_compile_weights = get_weights(m_compile) + with unittest.mock.patch( + "torch._dynamo.config.skip_torchrec", + False, + ): + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + + compile_out = torch.compile(m_compile, backend="aot_eager", fullgraph=True)( + indices, offsets, batch_size_per_feature_per_rank, inp_f.clone() + ) + reduce_to_scalar_loss(compile_out).backward() + compile_weights_diff = get_weights(m_compile) - orig_compile_weights + + assert_close(eager_weights_diff, compile_weights_diff) diff --git a/torchrec/distributed/tests/test_pt2_multiprocess.py b/torchrec/distributed/tests/test_pt2_multiprocess.py new file mode 100644 index 000000000..4835f22e8 --- /dev/null +++ b/torchrec/distributed/tests/test_pt2_multiprocess.py @@ -0,0 +1,828 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
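Both the TBE compile tests above and the multiprocess tests added below rely on the same parity pattern: run one training step eagerly and one through torch.compile on an identical copy, then compare the resulting updates. Below is a minimal sketch of that pattern; `make_model` and `example_inputs` are hypothetical stand-ins for the TBE-backed modules and inputs built in the tests, and gradients are compared here because a plain module, unlike TBE with its fused optimizer, does not update weights inside backward().

import copy
import unittest.mock

import torch
from torch._dynamo.testing import reduce_to_scalar_loss


def check_compile_parity(make_model, example_inputs, backend="aot_eager"):
    # `make_model` returns a fresh module; deepcopy gives the compiled twin.
    m_eager = make_model()
    m_compiled = copy.deepcopy(m_eager)

    reduce_to_scalar_loss(m_eager(*example_inputs)).backward()

    with unittest.mock.patch("torch._dynamo.config.skip_torchrec", False):
        torch._dynamo.config.capture_scalar_outputs = True
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        opt = torch.compile(m_compiled, backend=backend, fullgraph=True)
        reduce_to_scalar_loss(opt(*example_inputs)).backward()

    # The tests diff TBE weights because the fused optimizer updates them during
    # backward(); for a plain module, gradients are the equivalent signal.
    for p_eager, p_compiled in zip(m_eager.parameters(), m_compiled.parameters()):
        torch.testing.assert_close(p_eager.grad, p_compiled.grad)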
+ +# pyre-strict + +#!/usr/bin/env python3 + +import timeit +import unittest +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +import click +import fbgemm_gpu.sparse_ops # noqa: F401, E402 +import torch +import torchrec +import torchrec.pt2.checks +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + SplitTableBatchedEmbeddingBagsCodegen, +) +from hypothesis import given, settings, strategies as st, Verbosity +from torch import distributed as dist +from torch._dynamo.testing import reduce_to_scalar_loss +from torch.distributed import ProcessGroup +from torch.testing._internal.distributed.fake_pg import FakeStore +from torchrec.distributed.embedding import EmbeddingCollectionSharder +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.fbgemm_qcomm_codec import QCommsConfig + +from torchrec.distributed.model_parallel import DistributedModelParallel + +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology +from torchrec.distributed.planner.enumerators import EmbeddingEnumerator +from torchrec.distributed.planner.shard_estimators import ( + EmbeddingPerfEstimator, + EmbeddingStorageEstimator, +) +from torchrec.distributed.planner.types import ShardingPlan +from torchrec.distributed.sharding_plan import EmbeddingBagCollectionSharder + +from torchrec.distributed.test_utils.infer_utils import TestModelInfo + +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.test_utils.test_model import ModelInput +from torchrec.distributed.types import ShardingEnv, ShardingType +from torchrec.modules.embedding_configs import EmbeddingConfig +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingBagConfig, + EmbeddingCollection, +) +from torchrec.modules.feature_processor_ import FeatureProcessorsCollection +from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection +from torchrec.pt2.utils import kjt_for_pt2_tracing +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor + + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu") + + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training" + ) + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu_training" + ) + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings" + ) + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu") +except OSError: + pass + + +class NoOpFPC(FeatureProcessorsCollection): + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedJaggedTensor: + return features + + +class _ModelType(Enum): + EBC = 1 + EC = 2 + FPEBC = 3 + + +class _InputType(Enum): + SINGLE_BATCH = 1 + VARIABLE_BATCH = 2 + + +class _ConvertToVariableBatch(Enum): + FALSE = 0 + TRUE = 1 + + +@dataclass +class _TestConfig: + n_extra_numerics_checks_inputs: int = 1 + + +class EBCSharderFixedShardingType(EmbeddingBagCollectionSharder): + def __init__( + self, + sharding_type: str, + fused_params: Optional[Dict[str, Any]] = None, + qcomms_config: 
Optional[QCommsConfig] = None, + ) -> None: + if fused_params is None: + fused_params = {} + if "learning_rate" not in fused_params: + fused_params["learning_rate"] = 0.1 + + self._sharding_type = sharding_type + super().__init__(fused_params=fused_params) + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [self._sharding_type] + + +class ECSharderFixedShardingType(EmbeddingCollectionSharder): + def __init__( + self, + sharding_type: str, + fused_params: Optional[Dict[str, Any]] = None, + qcomms_config: Optional[QCommsConfig] = None, + ) -> None: + if fused_params is None: + fused_params = {} + if "learning_rate" not in fused_params: + fused_params["learning_rate"] = 0.1 + + self._sharding_type = sharding_type + super().__init__(fused_params=fused_params) + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [self._sharding_type] + + +def _gen_model(test_model_type: _ModelType, mi: TestModelInfo) -> torch.nn.Module: + emb_dim: int = max(t.embedding_dim for t in mi.tables) + if test_model_type == _ModelType.EBC: + + class M_ebc(torch.nn.Module): + def __init__(self, ebc: EmbeddingBagCollection) -> None: + super().__init__() + self._ebc = ebc + self._linear = torch.nn.Linear( + mi.num_float_features, emb_dim, device=mi.dense_device + ) + + def forward(self, x: KeyedJaggedTensor, y: torch.Tensor) -> torch.Tensor: + kt: KeyedTensor = self._ebc(x) + v = kt.values() + y = self._linear(y) + return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1)) + + return M_ebc( + EmbeddingBagCollection( + # pyre-ignore + tables=mi.tables, + device=mi.sparse_device, + ) + ) + if test_model_type == _ModelType.FPEBC: + + class M_fpebc(torch.nn.Module): + def __init__(self, fpebc: FeatureProcessedEmbeddingBagCollection) -> None: + super().__init__() + self._fpebc = fpebc + self._linear = torch.nn.Linear( + mi.num_float_features, emb_dim, device=mi.dense_device + ) + + def forward(self, x: KeyedJaggedTensor, y: torch.Tensor) -> torch.Tensor: + kt: KeyedTensor = self._fpebc(x) + v = kt.values() + y = self._linear(y) + return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1)) + + return M_fpebc( + FeatureProcessedEmbeddingBagCollection( + embedding_bag_collection=EmbeddingBagCollection( + # pyre-ignore + tables=mi.tables, + device=mi.sparse_device, + is_weighted=True, + ), + feature_processors=NoOpFPC(), + ) + ) + elif test_model_type == _ModelType.EC: + + class M_ec(torch.nn.Module): + def __init__(self, ec: EmbeddingCollection) -> None: + super().__init__() + self._ec = ec + + def forward( + self, x: KeyedJaggedTensor, y: torch.Tensor + ) -> List[JaggedTensor]: + d: Dict[str, JaggedTensor] = self._ec(x) + v = torch.stack(d.values(), dim=0).sum(dim=0) + y = self._linear(y) + return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1)) + + return M_ec( + EmbeddingCollection( + # pyre-ignore + tables=mi.tables, + device=mi.sparse_device, + ) + ) + else: + raise RuntimeError(f"Unsupported test_model_type:{test_model_type}") + + +def _test_compile_rank_fn( + test_model_type: _ModelType, + rank: int, + world_size: int, + backend: str, + sharding_type: str, + kernel_type: str, + input_type: _InputType, + convert_to_vb: bool, + config: _TestConfig, + torch_compile_backend: Optional[str] = None, + local_size: Optional[int] = None, +) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + num_embeddings = 256 + # emb_dim must be % 4 == 0 for fbgemm + emb_dim = 12 + batch_size = 10 + num_features: int = 5 + + num_float_features: int = 8 + 
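+        # These counts size the synthetic embedding tables and the per-rank
+        # ModelInput batches generated below.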
num_weighted_features: int = 1 + + # pyre-ignore + device: torch.Device = torch.device("cuda") + pg: Optional[dist.ProcessGroup] = ctx.pg + assert pg is not None + + topology: Topology = Topology(world_size=world_size, compute_device="cuda") + mi = TestModelInfo( + dense_device=device, + sparse_device=device, + num_features=num_features, + num_float_features=num_float_features, + num_weighted_features=num_weighted_features, + topology=topology, + ) + + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + config_type = ( + EmbeddingBagConfig + if test_model_type == _ModelType.EBC or test_model_type == _ModelType.FPEBC + else EmbeddingConfig + ) + + # pyre-ignore + mi.tables = [ + config_type( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(mi.num_features) + ] + + # pyre-ignore + mi.weighted_tables = [ + config_type( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(mi.num_weighted_features) + ] + + mi.model = _gen_model(test_model_type, mi) + mi.model.training = True + + model = mi.model + + planner = EmbeddingShardingPlanner( + topology=Topology(world_size, device.type, local_world_size=local_size), + constraints=None, + ) + + sharders = [ + EBCSharderFixedShardingType(sharding_type), + ECSharderFixedShardingType(sharding_type), + ] + + plan: ShardingPlan = planner.collective_plan( + model, + # pyre-ignore + sharders, + pg, + ) + + # pyre-ignore + def _dmp(m: torch.nn.Module) -> DistributedModelParallel: + return DistributedModelParallel( + m, + # pyre-ignore + env=ShardingEnv.from_process_group(pg), + plan=plan, + sharders=sharders, + device=device, + init_data_parallel=False, + ) + + dmp = _dmp(model) + dmp_compile = _dmp(model) + + # TODO: Fix some data dependent failures on subsequent inputs + n_extra_numerics_checks = config.n_extra_numerics_checks_inputs + ins = [] + + for _ in range(1 + n_extra_numerics_checks): + if input_type == _InputType.VARIABLE_BATCH: + ( + _, + local_model_inputs, + ) = ModelInput.generate_variable_batch_input( + average_batch_size=batch_size, + world_size=world_size, + num_float_features=num_float_features, + tables=mi.tables, + ) + else: + ( + _, + local_model_inputs, + ) = ModelInput.generate( + batch_size=batch_size, + world_size=world_size, + num_float_features=num_float_features, + tables=mi.tables, + weighted_tables=mi.weighted_tables, + variable_batch_size=False, + ) + ins.append(local_model_inputs) + + local_model_input = ins[0][rank].to(device) + + kjt = local_model_input.idlist_features + ff = local_model_input.float_features + ff.requires_grad = True + kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb) + + compile_input_ff = ff.clone().detach() + compile_input_ff.requires_grad = True + + torchrec.distributed.comm_ops.set_use_sync_collectives(True) + torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True) + + dmp.train(True) + dmp_compile.train(True) + + def get_weights(dmp: DistributedModelParallel) -> torch.Tensor: + tbe = None + if test_model_type == _ModelType.EBC: + tbe = ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `_lookups`. 
+ dmp._dmp_wrapped_module._ebc._lookups[0] + ._emb_modules[0] + ._emb_module + ) + elif test_model_type == _ModelType.FPEBC: + tbe = ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `_lookups`. + dmp._dmp_wrapped_module._fpebc._lookups[0] + ._emb_modules[0] + ._emb_module + ) + elif test_model_type == _ModelType.EC: + tbe = ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `_lookups`. + dmp._dmp_wrapped_module._ec._lookups[0] + ._emb_modules[0] + ._emb_module + ) + assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen) + + # pyre-fixme[29]: `Union[(self: TensorBase, memory_format: + # Optional[memory_format] = ...) -> Tensor, Tensor, Module]` is not a + # function. + return tbe.weights_dev.clone().detach() + + original_weights = get_weights(dmp) + original_weights.zero_() + original_compile_weights = get_weights(dmp_compile) + original_compile_weights.zero_() + + eager_out = dmp(kjt_ft, ff) + + reduce_to_scalar_loss(eager_out).backward() + eager_weights_diff = get_weights(dmp) - original_weights + + if torch_compile_backend is None: + return + + ##### COMPILE ##### + run_compile_backward: bool = torch_compile_backend in ["aot_eager", "inductor"] + with unittest.mock.patch( + "torch._dynamo.config.skip_torchrec", + False, + ): + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt = ( + True + ) + + opt_fn = torch.compile( + dmp_compile, + backend=torch_compile_backend, + fullgraph=True, + ) + compile_out = opt_fn( + kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb), compile_input_ff + ) + torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3) + if run_compile_backward: + reduce_to_scalar_loss(compile_out).backward() + compile_weights_diff = ( + get_weights(dmp_compile) - original_compile_weights + ) + # Checking TBE weights updated inplace + torch.testing.assert_close( + eager_weights_diff, compile_weights_diff, atol=1e-3, rtol=1e-3 + ) + # Check float inputs gradients + torch.testing.assert_close( + ff.grad, compile_input_ff.grad, atol=1e-3, rtol=1e-3 + ) + + ##### COMPILE END ##### + + ##### NUMERIC CHECK ##### + with unittest.mock.patch( + "torch._dynamo.config.skip_torchrec", + False, + ): + for i in range(n_extra_numerics_checks): + local_model_input = ins[1 + i][rank].to(device) + kjt = local_model_input.idlist_features + kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb) + ff = local_model_input.float_features + ff.requires_grad = True + eager_out_i = dmp(kjt_ft, ff) + reduce_to_scalar_loss(eager_out_i).backward() + eager_weights_diff = get_weights(dmp) - original_weights + + compile_input_ff = ff.clone().detach() + compile_input_ff.requires_grad = True + compile_out_i = opt_fn(kjt_ft, compile_input_ff) + torch.testing.assert_close( + eager_out_i, compile_out_i, atol=1e-3, rtol=1e-3 + ) + if run_compile_backward: + torch._dynamo.testing.reduce_to_scalar_loss( + compile_out_i + ).backward() + compile_weights_diff = ( + get_weights(dmp_compile) - original_compile_weights + ) + # Checking TBE weights updated inplace + torch.testing.assert_close( + eager_weights_diff, + compile_weights_diff, + atol=1e-3, + rtol=1e-3, + ) + # Check float inputs gradients + torch.testing.assert_close( + ff.grad, compile_input_ff.grad, atol=1e-3, rtol=1e-3 + ) + + ##### NUMERIC CHECK END ##### + + +def _test_compile_fake_pg_fn( + rank: int, + world_size: int, + num_features: 
int = 5, + batch_size: int = 10, + num_embeddings: int = 256, +) -> None: + sharding_type = ShardingType.TABLE_WISE.value + torch_compile_backend = "eager" + config = _TestConfig() + # emb_dim must be % 4 == 0 for fbgemm + emb_dim = 12 + + num_float_features: int = 8 + num_weighted_features: int = 1 + + device: torch.Device = torch.device("cuda") + store = FakeStore() + dist.init_process_group(backend="fake", rank=rank, world_size=2, store=store) + pg: ProcessGroup = dist.distributed_c10d._get_default_group() + + topology: Topology = Topology(world_size=world_size, compute_device="cuda") + mi = TestModelInfo( + dense_device=device, + sparse_device=device, + num_features=num_features, + num_float_features=num_float_features, + num_weighted_features=num_weighted_features, + topology=topology, + ) + + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + mi.tables = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(mi.num_features) + ] + + mi.weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(mi.num_weighted_features) + ] + + mi.model = _gen_model(_ModelType.EBC, mi) + mi.model.training = True + + model = mi.model + + planner = EmbeddingShardingPlanner( + topology=Topology(world_size, device.type), + constraints=None, + ) + + sharders = [ + EBCSharderFixedShardingType(sharding_type), + ECSharderFixedShardingType(sharding_type), + ] + + plan: ShardingPlan = planner.plan(model, sharders) # pyre-ignore + + def _dmp(m: torch.nn.Module) -> DistributedModelParallel: # pyre-ignore + return DistributedModelParallel( + m, + env=ShardingEnv(world_size, rank, pg), + plan=plan, + sharders=sharders, + device=device, + init_data_parallel=False, + ) + + dmp = _dmp(model) + dmp_compile = _dmp(model) + + # TODO: Fix some data dependent failures on subsequent inputs + n_extra_numerics_checks = config.n_extra_numerics_checks_inputs + ins = [] + + for _ in range(1 + n_extra_numerics_checks): + ( + _, + local_model_inputs, + ) = ModelInput.generate( + batch_size=batch_size, + world_size=world_size, + num_float_features=num_float_features, + tables=mi.tables, + weighted_tables=mi.weighted_tables, + variable_batch_size=False, + ) + ins.append(local_model_inputs) + + local_model_input = ins[0][rank].to(device) + + kjt = local_model_input.idlist_features + ff = local_model_input.float_features + ff.requires_grad = True + kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=True) + + compile_input_ff = ff.clone().detach() + compile_input_ff.requires_grad = True + + torchrec.distributed.comm_ops.set_use_sync_collectives(True) + torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True) + + dmp.train(True) + dmp_compile.train(True) + + def get_weights(dmp: DistributedModelParallel) -> torch.Tensor: + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_lookups`. + tbe = dmp._dmp_wrapped_module._ebc._lookups[0]._emb_modules[0]._emb_module + assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen) + # pyre-fixme[29]: `Union[(self: TensorBase, memory_format: + # Optional[memory_format] = ...) 
-> Tensor, Tensor, Module]` is not a + # function. + return tbe.weights_dev.clone().detach() + + original_weights = get_weights(dmp) + original_weights.zero_() + original_compile_weights = get_weights(dmp_compile) + original_compile_weights.zero_() + + eager_out = dmp(kjt_ft, ff) + reduce_to_scalar_loss(eager_out).backward() + + ##### COMPILE ##### + with unittest.mock.patch( + "torch._dynamo.config.skip_torchrec", + False, + ): + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt = True + + opt_fn = torch.compile( + dmp_compile, + backend=torch_compile_backend, + fullgraph=True, + ) + compile_out = opt_fn( + kjt_for_pt2_tracing(kjt, convert_to_vb=True), compile_input_ff + ) + torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3) + ##### COMPILE END ##### + + +class TestPt2Train(MultiProcessTestBase): + def disable_cuda_tf32(self) -> bool: + return True + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given( + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.FUSED.value, + ], + ), + given_config_tuple=st.sampled_from( + [ + ( + _ModelType.EBC, + ShardingType.TABLE_WISE.value, + _InputType.SINGLE_BATCH, + _ConvertToVariableBatch.TRUE, + "inductor", + _TestConfig(), + ), + ( + _ModelType.EBC, + ShardingType.COLUMN_WISE.value, + _InputType.SINGLE_BATCH, + _ConvertToVariableBatch.TRUE, + "inductor", + _TestConfig(), + ), + ( + _ModelType.EBC, + ShardingType.TABLE_WISE.value, + _InputType.SINGLE_BATCH, + _ConvertToVariableBatch.FALSE, + "inductor", + _TestConfig(), + ), + ( + _ModelType.EBC, + ShardingType.COLUMN_WISE.value, + _InputType.SINGLE_BATCH, + _ConvertToVariableBatch.FALSE, + "inductor", + _TestConfig(), + ), + ] + ), + ) + @settings(verbosity=Verbosity.verbose, deadline=None) + def test_compile_multiprocess( + self, + kernel_type: str, + given_config_tuple: Tuple[ + _ModelType, + str, + _InputType, + _ConvertToVariableBatch, + Optional[str], + _TestConfig, + ], + ) -> None: + model_type, sharding_type, input_type, tovb, compile_backend, config = ( + given_config_tuple + ) + self._run_multi_process_test( + callable=_test_compile_rank_fn, + test_model_type=model_type, + world_size=2, + backend="nccl", + sharding_type=sharding_type, + kernel_type=kernel_type, + input_type=input_type, + convert_to_vb=tovb == _ConvertToVariableBatch.TRUE, + config=config, + torch_compile_backend=compile_backend, + ) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() < 1, + "Not enough GPUs, this test requires one GPU", + ) + def test_compile_multiprocess_fake_pg( + self, + ) -> None: + _test_compile_fake_pg_fn( + rank=0, + world_size=2, + ) + + +@click.command() +@click.option( + "--repeat", + type=int, + default=1, + help="repeat times", +) +@click.option( + "--rank", + type=int, + default=0, + help="rank in the test", +) +@click.option( + "--world-size", + type=int, + default=2, + help="world_size in the test", +) +@click.option( + "--num-features", + type=int, + default=5, + help="num_features in the test", +) +@click.option( + "--batch-size", + type=int, + default=10, + help="batch_size in the test", +) +def compile_benchmark( + rank: int, world_size: int, num_features: int, batch_size: int, repeat: int +) -> None: + run: str = ( + f"_test_compile_fake_pg_fn(rank={rank}, world_size={world_size}, " + 
f"num_features={num_features}, batch_size={batch_size})" + ) + print("*" * 20 + " compile_benchmark started " + "*" * 20) + t = timeit.timeit(stmt=run, number=repeat, globals=globals()) + print("*" * 20 + " compile_benchmark completed " + "*" * 20) + print( + f"rank: {rank}, world_size: {world_size}, " + f"num_features: {num_features}, batch_size: {batch_size}, time: {t:.2f}s" + ) + + +if __name__ == "__main__": + compile_benchmark() diff --git a/torchrec/distributed/tests/test_qcomms_embedding_modules.py b/torchrec/distributed/tests/test_qcomms_embedding_modules.py new file mode 100644 index 000000000..919c2870d --- /dev/null +++ b/torchrec/distributed/tests/test_qcomms_embedding_modules.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import copy +import unittest +from typing import Any, Dict, List, Optional + +import hypothesis.strategies as st +import torch +import torch.nn as nn +import torchrec.distributed as trec_dist +from hypothesis import given, settings, Verbosity +from torch.distributed.optim import ( + _apply_optimizer_in_backward as apply_optimizer_in_backward, +) +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.fbgemm_qcomm_codec import ( + CommType, + get_qcomm_codecs_registry, + QCommsConfig, +) + +from torchrec.distributed.sharding_plan import ( + column_wise, + construct_module_sharding_plan, + ParameterShardingGenerator, + row_wise, + table_wise, +) + +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.test_utils.test_sharding import copy_state_dict +from torchrec.distributed.types import ModuleSharder, ParameterSharding, ShardingEnv +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection + +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.test_utils import skip_if_asan_class + + +def _test_sharding( + tables: List[EmbeddingBagConfig], + initial_state_dict: Dict[str, Any], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + backend: str, + parameter_sharding_plan: Dict[str, ParameterSharding], + sharder: ModuleSharder[nn.Module], + local_size: Optional[int] = None, +) -> None: + trec_dist.comm_ops.set_gradient_division(False) + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + kjt_input_per_rank = [kjt.to(ctx.device) for kjt in kjt_input_per_rank] + initial_state_dict = { + fqn: tensor.to(ctx.device) for fqn, tensor in initial_state_dict.items() + } + + model = EmbeddingBagCollection( + tables=tables, + device=ctx.device, + ) + + apply_optimizer_in_backward( + torch.optim.SGD, + model.parameters(), + {"lr": 1.0}, + ) + + unsharded_model = model + sharded_model = sharder.shard( + module=model, + params=parameter_sharding_plan, + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. 
+ env=ShardingEnv.from_process_group(ctx.pg), + device=ctx.device, + ) + + unsharded_model.load_state_dict(copy.deepcopy(initial_state_dict)) + copy_state_dict(sharded_model.state_dict(), copy.deepcopy(initial_state_dict)) + + feature_keys = [] + for table in tables: + feature_keys.extend(table.feature_names) + + # each rank gets a subbatch + sharded_model_pred_kt = sharded_model(kjt_input_per_rank[ctx.rank]).to_dict() + _sharded_model_pred = torch.stack( # noqa + [sharded_model_pred_kt[feature] for feature in feature_keys] + ) + + for _it in range(1): + unsharded_model_pred_kt = [] + for rank in range(ctx.world_size): + # simulate the unsharded model run on the entire batch + unsharded_model_pred_kt.append( + unsharded_model(kjt_input_per_rank[rank]) + ) + + all_unsharded_preds = [] + for rank in range(ctx.world_size): + unsharded_model_pred_kt_mini_batch = unsharded_model_pred_kt[ + rank + ].to_dict() + + all_unsharded_preds.extend( + [ + unsharded_model_pred_kt_mini_batch[feature] + for feature in feature_keys + ] + ) + if rank == ctx.rank: + unsharded_model_pred = torch.stack( + [ + unsharded_model_pred_kt_mini_batch[feature] + for feature in feature_keys + ] + ) + + # sharded model + # each rank gets a subbatch + sharded_model_pred_kt = sharded_model( + kjt_input_per_rank[ctx.rank] + ).to_dict() + sharded_model_pred = torch.stack( + [sharded_model_pred_kt[feature] for feature in feature_keys] + ) + + # cast to CPU because when casting unsharded_model.to on the same module, there could some race conditions + # in normal author modelling code this won't be an issue because each rank would individually create + # their model. output from sharded_pred is correctly on the correct device. + # Compare predictions of sharded vs unsharded models. + torch.testing.assert_close( + sharded_model_pred.cpu(), + unsharded_model_pred.cpu(), + ) + + sharded_model_pred.sum().backward() + + all_unsharded_preds = torch.stack(all_unsharded_preds) + all_unsharded_preds.sum().backward() + + +@skip_if_asan_class +class ConstructParameterShardingTest(MultiProcessTestBase): + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + per_param_sharding=st.sampled_from( + [ + { + "0": table_wise(rank=0), + "1": row_wise(), + "2": column_wise(ranks=[0, 1]), + }, + ] + ), + qcomms_config=st.sampled_from( + [ + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.FP32 + ), + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) + def test_parameter_sharding_ebc( + self, + per_param_sharding: Dict[str, ParameterShardingGenerator], + qcomms_config: QCommsConfig, + ) -> None: + + WORLD_SIZE = 2 + EMBEDDING_DIM = 8 + NUM_EMBEDDINGS = 4 + + embedding_bag_config = [ + EmbeddingBagConfig( + name=str(idx), + feature_names=[f"feature_{idx}"], + embedding_dim=EMBEDDING_DIM, + num_embeddings=NUM_EMBEDDINGS, + ) + for idx in per_param_sharding + ] + + # Rank 0 + # instance 0 instance 1 instance 2 + # "feature_0" [0, 1] [] [2] + # "feature_1" [2] [2,3] [] + # "feature_2" [0,1,2,3] [0,2] [2,3] + + # Rank 1 + + # instance 0 instance 1 instance 2 + # "feature_0" [3, 2] [1,2] [0, 1,2,3] + # "feature_1" [2,3] None [2] + # "feature_2" [0, 1] None [2] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor([0, 1, 2, 2, 
2, 3, 0, 1, 2, 3, 0, 2, 2, 3]), + lengths=torch.LongTensor([2, 0, 1, 1, 2, 0, 4, 2, 2]), + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor([3, 2, 1, 2, 0, 1, 2, 3, 2, 3, 2, 0, 1, 2]), + lengths=torch.LongTensor([2, 2, 4, 2, 0, 1, 2, 0, 1]), + ), + ] + + sharder = EmbeddingBagCollectionSharder( + qcomm_codecs_registry=( + get_qcomm_codecs_registry(qcomms_config) + if qcomms_config is not None + else None + ) + ) + + ebc = EmbeddingBagCollection(tables=embedding_bag_config) + apply_optimizer_in_backward( + torch.optim.SGD, + ebc.parameters(), + {"lr": 1.0}, + ) + + parameter_sharding_plan = construct_module_sharding_plan( + module=ebc, + per_param_sharding=per_param_sharding, + local_size=2, + world_size=2, + # pyre-ignore + sharder=sharder, + ) + + self._run_multi_process_test( + callable=_test_sharding, + world_size=WORLD_SIZE, + tables=embedding_bag_config, + initial_state_dict={ + "embedding_bags.0.weight": torch.Tensor( + [[1] * EMBEDDING_DIM for val in range(NUM_EMBEDDINGS)] + ), + "embedding_bags.1.weight": torch.Tensor( + [[2] * EMBEDDING_DIM for val in range(NUM_EMBEDDINGS)] + ), + "embedding_bags.2.weight": torch.Tensor( + [[3] * EMBEDDING_DIM for val in range(NUM_EMBEDDINGS)] + ), + }, + kjt_input_per_rank=kjt_input_per_rank, + backend=( + "nccl" + if (torch.cuda.is_available() and torch.cuda.device_count() >= 2) + else "gloo" + ), + sharder=sharder, + parameter_sharding_plan=parameter_sharding_plan, + ) diff --git a/torchrec/distributed/tests/test_quant_model_parallel.py b/torchrec/distributed/tests/test_quant_model_parallel.py index 3e6615b62..131b9d3d6 100644 --- a/torchrec/distributed/tests/test_quant_model_parallel.py +++ b/torchrec/distributed/tests/test_quant_model_parallel.py @@ -5,13 +5,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
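The qcomms sharding test above samples QCommsConfig values and feeds the resulting codec registry to the sharder. A minimal sketch, assuming a standard TorchRec install, of that wiring:

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.fbgemm_qcomm_codec import (
    CommType,
    get_qcomm_codecs_registry,
    QCommsConfig,
)

# FP16 collectives in the forward pass, BF16 in the backward pass, mirroring one
# of the sampled configurations.
qcomms_config = QCommsConfig(
    forward_precision=CommType.FP16,
    backward_precision=CommType.BF16,
)
sharder = EmbeddingBagCollectionSharder(
    qcomm_codecs_registry=get_qcomm_codecs_registry(qcomms_config)
)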
+# pyre-strict + import unittest -from typing import Any, cast, Dict, List, Optional +from typing import cast, Dict, Optional, Tuple import hypothesis.strategies as st import torch from hypothesis import given, settings, Verbosity -from torch import nn, quantization as quant +from torch import nn from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ModuleSharder from torchrec.distributed.model_parallel import DistributedModelParallel from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology @@ -20,63 +22,19 @@ EmbeddingPerfEstimator, EmbeddingStorageEstimator, ) -from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder -from torchrec.distributed.shard import shard +from torchrec.distributed.shard import _shard_modules +from torchrec.distributed.test_utils.infer_utils import quantize, TestQuantEBCSharder from torchrec.distributed.test_utils.test_model import ( _get_default_rtol_and_atol, ModelInput, TestSparseNN, ) from torchrec.distributed.types import ShardedModule, ShardingEnv, ShardingType -from torchrec.inference.modules import copy_to_device +from torchrec.distributed.utils import copy_to_device from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.modules.embedding_modules import EmbeddingBagCollection -from torchrec.quant.embedding_modules import ( - EmbeddingBagCollection as QuantEmbeddingBagCollection, -) from torchrec.types import CopyMixIn -class TestQuantEBCSharder(QuantEmbeddingBagCollectionSharder): - def __init__( - self, - sharding_type: str, - kernel_type: str, - fused_params: Optional[Dict[str, Any]] = None, - shardable_params: Optional[List[str]] = None, - ) -> None: - super().__init__(fused_params=fused_params, shardable_params=shardable_params) - self._sharding_type = sharding_type - self._kernel_type = kernel_type - - def sharding_types(self, compute_device_type: str) -> List[str]: - return [self._sharding_type] - - def compute_kernels( - self, sharding_type: str, compute_device_type: str - ) -> List[str]: - return [self._kernel_type] - - -def _quantize( - module: nn.Module, inplace: bool, output_type: torch.dtype = torch.float -) -> nn.Module: - qconfig = quant.QConfig( - activation=quant.PlaceholderObserver.with_args(dtype=output_type), - weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8), - ) - return quant.quantize_dynamic( - module, - qconfig_spec={ - EmbeddingBagCollection: qconfig, - }, - mapping={ - EmbeddingBagCollection: QuantEmbeddingBagCollection, - }, - inplace=inplace, - ) - - class CopyModule(nn.Module, CopyMixIn): def __init__(self) -> None: super().__init__() @@ -126,19 +84,22 @@ def _buffer_param_check( module_copy: nn.Module, device: torch.device, device_copy: torch.device, + recurse: bool = True, ) -> None: # check all buffer/param under the module is value-identical # but device-different with the copied module. 
for (name, buffer), (name_copy, buffer_copy) in zip( - list(module.named_buffers()) + list(module.named_parameters()), - list(module_copy.named_buffers()) + list(module_copy.named_parameters()), + list(module.named_buffers(recurse=recurse)) + + list(module.named_parameters(recurse=recurse)), + list(module_copy.named_buffers(recurse=recurse)) + + list(module_copy.named_parameters(recurse=recurse)), ): - self.assertEquals(name, name_copy) + self.assertEqual(name, name_copy) actual, expected = buffer.detach().cpu(), buffer_copy.detach().cpu() rtol, atol = _get_default_rtol_and_atol(actual, expected) torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol) - self.assertEquals(buffer.detach().device, device) - self.assertEquals(buffer_copy.detach().device, device_copy) + self.assertEqual(buffer.detach().device, device) + self.assertEqual(buffer_copy.detach().device, device_copy) def _recursive_device_check( self, @@ -161,6 +122,9 @@ def _recursive_device_check( self.assertTrue(buffer.detach().is_set_to(buffer_copy.detach())) # don't go into named_children of ShardedModule return + self._buffer_param_check( + module, module_copy, device, device_copy, recurse=False + ) for name_child, name_child_copy in zip( module.named_children(), module_copy.named_children() ): @@ -187,9 +151,24 @@ def _recursive_device_check( torch.float, ] ), + sharding_type_qsplitscalebias=st.sampled_from( + [ + (ShardingType.TABLE_WISE.value, False), + (ShardingType.TABLE_WISE.value, True), + (ShardingType.ROW_WISE.value, True), + (ShardingType.COLUMN_WISE.value, True), + ] + ), ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) - def test_quant_pred(self, output_type: torch.dtype) -> None: + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + def test_quant_pred( + self, output_type: torch.dtype, sharding_type_qsplitscalebias: Tuple[str, bool] + ) -> None: + ( + sharding_type, + quant_state_dict_split_scale_bias, + ) = sharding_type_qsplitscalebias + device = torch.device("cuda:0") device_1 = torch.device("cuda:1") model = TestSparseNN( @@ -199,14 +178,19 @@ def test_quant_pred(self, output_type: torch.dtype) -> None: dense_device=device, sparse_device=torch.device("meta"), ) - quant_model = _quantize(model, inplace=True, output_type=output_type) + quant_model = quantize( + model, + inplace=True, + output_type=output_type, + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, + ) dmp = DistributedModelParallel( quant_model, sharders=[ cast( ModuleSharder[torch.nn.Module], TestQuantEBCSharder( - sharding_type=ShardingType.TABLE_WISE.value, + sharding_type=sharding_type, kernel_type=EmbeddingComputeKernel.QUANT.value, ), ) @@ -230,9 +214,24 @@ def test_quant_pred(self, output_type: torch.dtype) -> None: torch.float, ] ), + sharding_type_qsplitscalebias=st.sampled_from( + [ + (ShardingType.TABLE_WISE.value, False), + (ShardingType.TABLE_WISE.value, True), + (ShardingType.ROW_WISE.value, True), + (ShardingType.COLUMN_WISE.value, True), + ] + ), ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) - def test_quant_pred_state_dict(self, output_type: torch.dtype) -> None: + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + def test_quant_pred_state_dict( + self, output_type: torch.dtype, sharding_type_qsplitscalebias: Tuple[str, bool] + ) -> None: + ( + sharding_type, + quant_state_dict_split_scale_bias, + ) = sharding_type_qsplitscalebias + device = torch.device("cuda:0") model = TestSparseNN( @@ -242,7 +241,12 @@ 
def test_quant_pred_state_dict(self, output_type: torch.dtype) -> None: dense_device=device, sparse_device=torch.device("meta"), ) - quant_model = _quantize(model, inplace=True, output_type=output_type) + quant_model = quantize( + model, + inplace=True, + output_type=output_type, + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, + ) model.training = False dmp = DistributedModelParallel( @@ -251,7 +255,7 @@ def test_quant_pred_state_dict(self, output_type: torch.dtype) -> None: cast( ModuleSharder[torch.nn.Module], TestQuantEBCSharder( - sharding_type=ShardingType.TABLE_WISE.value, + sharding_type=sharding_type, kernel_type=EmbeddingComputeKernel.QUANT.value, ), ) @@ -267,7 +271,7 @@ def test_quant_pred_state_dict(self, output_type: torch.dtype) -> None: cast( ModuleSharder[torch.nn.Module], TestQuantEBCSharder( - sharding_type=ShardingType.TABLE_WISE.value, + sharding_type=sharding_type, kernel_type=EmbeddingComputeKernel.QUANT.value, ), ) @@ -304,9 +308,24 @@ def test_quant_pred_state_dict(self, output_type: torch.dtype) -> None: torch.float, ] ), + sharding_type_qsplitscalebias=st.sampled_from( + [ + (ShardingType.TABLE_WISE.value, False), + (ShardingType.TABLE_WISE.value, True), + (ShardingType.ROW_WISE.value, True), + (ShardingType.COLUMN_WISE.value, True), + ] + ), ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) - def test_quant_pred_shard(self, output_type: torch.dtype) -> None: + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + def test_quant_pred_shard( + self, output_type: torch.dtype, sharding_type_qsplitscalebias: Tuple[str, bool] + ) -> None: + ( + sharding_type, + quant_state_dict_split_scale_bias, + ) = sharding_type_qsplitscalebias + device = torch.device("cuda:0") device_1 = torch.device("cuda:1") model = TestSparseNN( @@ -316,15 +335,20 @@ def test_quant_pred_shard(self, output_type: torch.dtype) -> None: dense_device=device, sparse_device=torch.device("meta"), ) - quant_model = _quantize(model, inplace=True, output_type=output_type) + quant_model = quantize( + model, + inplace=True, + output_type=output_type, + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, + ) - sharded_model, _sharded_params = shard( + sharded_model = _shard_modules( module=quant_model, sharders=[ cast( ModuleSharder[torch.nn.Module], TestQuantEBCSharder( - sharding_type=ShardingType.TABLE_WISE.value, + sharding_type=sharding_type, kernel_type=EmbeddingComputeKernel.QUANT.value, ), ) @@ -333,8 +357,6 @@ def test_quant_pred_shard(self, output_type: torch.dtype) -> None: env=ShardingEnv.from_local(world_size=2, rank=0), ) - sharded_model = sharded_model.to(device) - sharded_model_copy = copy_to_device( module=sharded_model, current_device=device, to_device=device_1 ) @@ -356,7 +378,7 @@ def test_quant_pred_shard(self, output_type: torch.dtype) -> None: sharded_model_copy(local_batch[0].to(device_1)).cpu(), ) - # pyre-fixme[56] + # pyre-ignore @unittest.skipIf( torch.cuda.device_count() <= 1, "Not enough GPUs available", @@ -371,11 +393,11 @@ def test_copy_mixin(self) -> None: dense_device=device, sparse_device=torch.device("meta"), ) - # pyre-ignore [16] + # pyre-fixme[16]: `TestSparseNN` has no attribute `copy_module`. model.copy_module = CopyModule() - # pyre-ignore [16] + # pyre-fixme[16]: `TestSparseNN` has no attribute `no_copy_module`. 
model.no_copy_module = NoCopyModule() - quant_model = _quantize(model, inplace=True) + quant_model = quantize(model, inplace=True) dmp = DistributedModelParallel( quant_model, sharders=[ @@ -393,9 +415,9 @@ def test_copy_mixin(self) -> None: ) dmp_1 = dmp.copy(device_1) - # pyre-ignore [16] + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `tensor`. self.assertEqual(dmp_1.module.copy_module.tensor.device, device_1) - # pyre-ignore [16] + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `tensor`. self.assertEqual(dmp_1.module.no_copy_module.tensor.device, torch.device("cpu")) @@ -406,7 +428,7 @@ def setUp(self) -> None: self.tables = [ EmbeddingBagConfig( - num_embeddings=(i + 1) * 10000000, + num_embeddings=(i + 1) * 100000, embedding_dim=(i + 1) * 4, name="table_" + str(i), feature_names=["feature_" + str(i)], @@ -423,26 +445,64 @@ def setUp(self) -> None: for i in range(num_weighted_features) ] - # pyre-fixme[56] @unittest.skipIf( torch.cuda.device_count() <= 0, "Not enough GPUs available", ) - def test_shard_one_ebc_cuda(self) -> None: + # pyre-fixme[56] + @given( + sharding_type_qsplitscalebias=st.sampled_from( + [ + (ShardingType.TABLE_WISE.value, False), + (ShardingType.TABLE_WISE.value, True), + (ShardingType.ROW_WISE.value, True), + (ShardingType.COLUMN_WISE.value, True), + ] + ), + per_table_weight_dtypes=st.sampled_from( + [ + None, + {"table_0": torch.quint4x2, "table_1": torch.quint8}, + { + "table_0": torch.quint4x2, + "table_1": torch.quint8, + "table_3": torch.quint4x2, + "table_4": torch.int8, + }, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_shard_one_ebc_cuda( + self, + sharding_type_qsplitscalebias: Tuple[str, bool], + per_table_weight_dtypes: Optional[Dict[str, torch.dtype]], + ) -> None: + ( + sharding_type, + quant_state_dict_split_scale_bias, + ) = sharding_type_qsplitscalebias + device = torch.device("cuda:0") + sparse_device = torch.device("meta") model = TestSparseNN( tables=self.tables, weighted_tables=self.weighted_tables, num_float_features=10, dense_device=device, - sparse_device=torch.device("meta"), + sparse_device=sparse_device, + ) + quant_model = quantize( + model, + inplace=True, + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, + per_table_weight_dtypes=per_table_weight_dtypes, ) - quant_model = _quantize(model, inplace=True) sharders = [ cast( ModuleSharder[torch.nn.Module], TestQuantEBCSharder( - sharding_type=ShardingType.TABLE_WISE.value, + sharding_type=sharding_type, kernel_type=EmbeddingComputeKernel.QUANT.value, shardable_params=[table.name for table in self.tables], ), @@ -462,34 +522,58 @@ def test_shard_one_ebc_cuda(self) -> None: ), ).plan(quant_model, sharders) + sharding_device_type = "cuda" dmp = DistributedModelParallel( quant_model, plan=plan, - device=None, # cpu + device=torch.device(sharding_device_type), env=ShardingEnv.from_local(world_size=1, rank=0), init_data_parallel=False, + init_parameters=False, ) - self.assertTrue( - # pyre-ignore [16] - all([param.device == device for param in dmp.module.sparse.ebc.buffers()]) + # flake8: noqa:C419 + all( + param.device.type == sharding_device_type + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `ebc`. 
+ for param in dmp.module.sparse.ebc.buffers() + ) ) self.assertTrue( + # flake8: noqa:C419 all( - [ - param.device == torch.device("cpu") - # pyre-ignore [16] - for param in dmp.module.sparse.weighted_ebc.buffers() - ] + param.device.type == sparse_device.type + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `weighted_ebc`. + for param in dmp.module.sparse.weighted_ebc.buffers() ) ) - # pyre-fixme[56] @unittest.skipIf( torch.cuda.device_count() <= 0, "Not enough GPUs available", ) - def test_shard_one_ebc_meta(self) -> None: + # pyre-fixme[56] + @given( + sharding_type_qsplitscalebias=st.sampled_from( + [ + (ShardingType.TABLE_WISE.value, False), + (ShardingType.TABLE_WISE.value, True), + (ShardingType.ROW_WISE.value, True), + (ShardingType.COLUMN_WISE.value, True), + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_shard_one_ebc_meta( + self, sharding_type_qsplitscalebias: Tuple[str, bool] + ) -> None: + ( + sharding_type, + quant_state_dict_split_scale_bias, + ) = sharding_type_qsplitscalebias + device = torch.device("cuda:0") model = TestSparseNN( tables=self.tables, @@ -498,12 +582,16 @@ def test_shard_one_ebc_meta(self) -> None: dense_device=device, sparse_device=torch.device("meta"), ) - quant_model = _quantize(model, inplace=True) + quant_model = quantize( + model, + inplace=True, + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, + ) sharders = [ cast( ModuleSharder[torch.nn.Module], TestQuantEBCSharder( - sharding_type=ShardingType.TABLE_WISE.value, + sharding_type=sharding_type, kernel_type=EmbeddingComputeKernel.QUANT.value, shardable_params=[table.name for table in self.tables], ), @@ -526,46 +614,69 @@ def test_shard_one_ebc_meta(self) -> None: dmp = DistributedModelParallel( quant_model, plan=plan, - device=None, # cpu + device=torch.device("cuda"), env=ShardingEnv.from_local(world_size=1, rank=0), init_data_parallel=False, init_parameters=False, ) self.assertTrue( - # pyre-ignore [16] - all([param.device == device for param in dmp.module.sparse.ebc.buffers()]) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + all(param.device == device for param in dmp.module.sparse.ebc.buffers()) ) self.assertTrue( + # flake8: noqa:C419 all( - [ - param.device == torch.device("meta") - # pyre-ignore [16] - for param in dmp.module.sparse.weighted_ebc.buffers() - ] + param.device == torch.device("meta") + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `weighted_ebc`. 
+ for param in dmp.module.sparse.weighted_ebc.buffers() ) ) - # pyre-fixme[56] @unittest.skipIf( torch.cuda.device_count() <= 0, "Not enough GPUs available", ) - def test_shard_all_ebcs(self) -> None: + # pyre-fixme[56] + @given( + sharding_type_qsplitscalebias=st.sampled_from( + [ + (ShardingType.TABLE_WISE.value, False), + (ShardingType.TABLE_WISE.value, True), + (ShardingType.ROW_WISE.value, True), + (ShardingType.COLUMN_WISE.value, True), + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_shard_all_ebcs( + self, sharding_type_qsplitscalebias: Tuple[str, bool] + ) -> None: + ( + sharding_type, + quant_state_dict_split_scale_bias, + ) = sharding_type_qsplitscalebias + device = torch.device("cuda:0") + sparse_device = torch.device("meta") model = TestSparseNN( tables=self.tables, weighted_tables=self.weighted_tables, num_float_features=10, dense_device=device, - sparse_device=torch.device("meta"), + sparse_device=sparse_device, + ) + quant_model = quantize( + model, + inplace=True, + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, ) - quant_model = _quantize(model, inplace=True) sharders = [ cast( ModuleSharder[torch.nn.Module], TestQuantEBCSharder( - sharding_type=ShardingType.TABLE_WISE.value, + sharding_type=sharding_type, kernel_type=EmbeddingComputeKernel.QUANT.value, ), ) @@ -587,31 +698,54 @@ def test_shard_all_ebcs(self) -> None: dmp = DistributedModelParallel( quant_model, plan=plan, - device=None, # cpu + device=torch.device("cuda"), env=ShardingEnv.from_local(world_size=1, rank=0), init_data_parallel=False, + init_parameters=True, ) self.assertTrue( - # pyre-ignore [16] - all([param.device == device for param in dmp.module.sparse.ebc.buffers()]) + all( + param.device.type == device.type + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `ebc`. + for param in dmp.module.sparse.ebc.buffers() + ) ) + # DMP init_parameters == True by default reinits meta parameters on sharding device self.assertTrue( all( - [ - param.device == device - # pyre-ignore [16] - for param in dmp.module.sparse.weighted_ebc.buffers() - ] + param.device.type == device.type + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `weighted_ebc`. 
+ for param in dmp.module.sparse.weighted_ebc.buffers() ) ) - # pyre-fixme[56] @unittest.skipIf( torch.cuda.device_count() <= 0, "Not enough GPUs available", ) - def test_sharder_bad_param_config(self) -> None: + # pyre-fixme[56] + @given( + sharding_type_qsplitscalebias=st.sampled_from( + [ + (ShardingType.TABLE_WISE.value, False), + (ShardingType.TABLE_WISE.value, True), + (ShardingType.ROW_WISE.value, True), + (ShardingType.COLUMN_WISE.value, True), + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_sharder_bad_param_config( + self, sharding_type_qsplitscalebias: Tuple[str, bool] + ) -> None: + ( + sharding_type, + quant_state_dict_split_scale_bias, + ) = sharding_type_qsplitscalebias + device = torch.device("cuda:0") model = TestSparseNN( tables=self.tables, @@ -620,12 +754,16 @@ def test_sharder_bad_param_config(self) -> None: dense_device=device, sparse_device=torch.device("meta"), ) - quant_model = _quantize(model, inplace=True) + quant_model = quantize( + model, + inplace=True, + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, + ) sharders = [ cast( ModuleSharder[torch.nn.Module], TestQuantEBCSharder( - sharding_type=ShardingType.TABLE_WISE.value, + sharding_type=sharding_type, kernel_type=EmbeddingComputeKernel.QUANT.value, shardable_params=[ table.name for table in self.tables[:-1] diff --git a/torchrec/distributed/tests/test_quant_pruning.py b/torchrec/distributed/tests/test_quant_pruning.py new file mode 100644 index 000000000..86ce67048 --- /dev/null +++ b/torchrec/distributed/tests/test_quant_pruning.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import io +import unittest +from typing import Dict, List, Tuple + +import torch +from torchrec.distributed.embedding_types import ShardingType +from torchrec.distributed.quant_state import sharded_tbes_weights_spec, WeightSpec +from torchrec.distributed.test_utils.infer_utils import ( + assert_close, + assert_weight_spec, + create_test_model, + create_test_model_ebc_only_no_quantize, + model_input_to_forward_args, + prep_inputs, + quantize, + shard_qebc, +) +from torchrec.fx import symbolic_trace +from torchrec.inference.modules import set_pruning_data +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + +SPARSE_NN_EBC_MODULE = "_module.sparse.ebc" +SEQUENTIAL_NN_EBC_MODULE = "0" + + +def prune_and_quantize_model( + model: torch.nn.Module, + pruning_ebc_dict: Dict[str, int], +) -> torch.nn.Module: + set_pruning_data(model, pruning_ebc_dict) + + quant_state_dict_split_scale_bias = True + quant_model = quantize( + module=model, + inplace=False, + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, + ) + + return quant_model + + +def create_quant_and_sharded_ebc_models( + num_embedding: int, + emb_dim: int, + world_size: int, + batch_size: int, + sharding_type: ShardingType, + device: torch.device, + feature_processor: bool = False, +) -> Tuple[torch.nn.Module, torch.nn.Module, Dict[str, int]]: + mi = create_test_model_ebc_only_no_quantize( + num_embedding, + emb_dim, + world_size, + batch_size, + num_features=1, + num_weighted_features=0, + dense_device=device, + sparse_device=device, + feature_processor=feature_processor, + ) + mi.model.to(device) + num_rows_post_pruned = num_embedding // 2 + + pruning_ebc_dict = {"table_0": num_rows_post_pruned} + quant_model = prune_and_quantize_model(mi.model, pruning_ebc_dict) + + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. + quant_model = quant_model[0] + mi.quant_model = quant_model + + sharded_model = shard_qebc( + mi, + sharding_type=sharding_type, + device=device, + expected_shards=None, + feature_processor=feature_processor, + ) + + sharded_model.load_state_dict(quant_model.state_dict()) + + return quant_model, sharded_model, pruning_ebc_dict + + +class QuantPruneTest(unittest.TestCase): + def check_tbe_pruned( + self, sharded_model: torch.nn.Module, pruned_dict: Dict[str, int] + ) -> None: + for module in sharded_model.modules(): + if module.__class__.__name__ == "IntNBitTableBatchedEmbeddingBagsCodegen": + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + for i, spec in enumerate(module.embedding_specs): + if spec[0] in pruned_dict: + self.assertEqual( + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. 
+ module.split_embedding_weights()[i][0].size(0), + pruned_dict[spec[0]], + ) + self.assertEqual( + spec[1], + pruned_dict[spec[0]], + ) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + def test_qebc_pruned_tw(self) -> None: + batch_size: int = 4 + world_size = 2 + local_device = torch.device("cuda:0") + + num_embedding = 100 + emb_dim = 64 + pruned_entry = 40 + + # hash, dim, pruned_hash + table_specs: List[Tuple[int, int, int]] = [ + (num_embedding, emb_dim, num_embedding), + (num_embedding, emb_dim, num_embedding - pruned_entry), + ] + pruning_ebc_dict: Dict[str, int] = {} + pruning_ebc_dict["table_1"] = num_embedding - pruned_entry + + mi = create_test_model( + num_embedding, + emb_dim, + world_size, + batch_size, + dense_device=local_device, + sparse_device=local_device, + quant_state_dict_split_scale_bias=True, + num_features=len(table_specs), + pruning_dict=pruning_ebc_dict, + ) + + expected_shards = [ + [ + ( + (0, 0, table_specs[0][2], table_specs[0][1]), + "rank:0/cuda:0", + ), + ], + [ + ( + (0, 0, table_specs[1][2], table_specs[1][1]), + "rank:1/cuda:1", + ), + ], + ] + + quant_model = mi.quant_model + quant_state_dict = quant_model.state_dict() + + sharded_model = shard_qebc( + mi, + sharding_type=ShardingType.TABLE_WISE, + device=local_device, + expected_shards=expected_shards, + ) + + inputs = [ + model_input_to_forward_args(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + + sharded_model.load_state_dict(quant_state_dict) + quant_output = quant_model(*inputs[0]) + sharded_output = sharded_model(*inputs[0]) + assert_close(quant_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + buffer = io.BytesIO() + torch.jit.save(gm_script, buffer) + buffer.seek(0) + loaded_gm_script = torch.jit.load(buffer) + gm_script_output = loaded_gm_script(*inputs[0]) + assert_close(quant_output, gm_script_output) + + weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model) + assert_weight_spec( + weights_spec, + expected_shards, + "_module.sparse.ebc", + "embedding_bags", + ["table_0", "table_1"], + ShardingType.TABLE_WISE.value, + ) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + def test_qebc_pruned_tw_one_ebc(self) -> None: + batch_size: int = 1 + world_size: int = 2 + local_device = torch.device("cuda:0") + num_embedding = 200 + emb_dim = 10 + sharding_type = ShardingType.TABLE_WISE + + quant_model, sharded_model, pruned_dict = create_quant_and_sharded_ebc_models( + num_embedding=num_embedding, + emb_dim=emb_dim, + world_size=world_size, + batch_size=batch_size, + sharding_type=sharding_type, + device=local_device, + ) + + kjt = KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0"], + values=torch.tensor([0, 1, 2], dtype=torch.int32).cuda(), + lengths=torch.tensor([1, 1, 1], dtype=torch.int32).cuda(), + weights=None, + ) + + q_output = quant_model(kjt) + s_output = sharded_model(kjt) + + assert_close(q_output["feature_0"], s_output["feature_0"]) + + self.check_tbe_pruned(sharded_model, pruned_dict) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + def test_qebc_pruned_cw(self) -> None: + batch_size: int = 4 + world_size = 2 + local_device = torch.device("cuda:0") + + num_embedding = 200 + emb_dim = 512 + pruned_entry = 100 + + # hash, dim, pruned_hash + table_specs: 
List[Tuple[int, int, int]] = [ + (num_embedding, emb_dim, num_embedding - pruned_entry), + ] + pruning_ebc_dict: Dict[str, int] = {"table_0": num_embedding - pruned_entry} + + mi = create_test_model( + num_embedding, + emb_dim, + world_size, + batch_size, + dense_device=local_device, + sparse_device=local_device, + quant_state_dict_split_scale_bias=True, + num_features=len(table_specs), + pruning_dict=pruning_ebc_dict, + ) + + expected_shards = [ + [ + ( + (0, 0, table_specs[0][2], table_specs[0][1] // 4), + "rank:0/cuda:0", + ), + ( + (0, 128, table_specs[0][2], table_specs[0][1] // 4), + "rank:1/cuda:1", + ), + ( + (0, 256, table_specs[0][2], table_specs[0][1] // 4), + "rank:0/cuda:0", + ), + ( + (0, 384, table_specs[0][2], table_specs[0][1] // 4), + "rank:1/cuda:1", + ), + ], + ] + + sharded_model = shard_qebc( + mi, + sharding_type=ShardingType.COLUMN_WISE, + device=local_device, + expected_shards=expected_shards, + ) + + inputs = [ + model_input_to_forward_args(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + sharded_model.load_state_dict(mi.quant_model.state_dict()) + quant_output = mi.quant_model(*inputs[0]) + sharded_output = sharded_model(*inputs[0]) + assert_close(quant_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + buffer = io.BytesIO() + torch.jit.save(gm_script, buffer) + buffer.seek(0) + loaded_gm_script = torch.jit.load(buffer) + gm_script_output = loaded_gm_script(*inputs[0]) + assert_close(quant_output, gm_script_output) + + weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model) + assert_weight_spec( + weights_spec, + expected_shards, + "_module.sparse.ebc", + "embedding_bags", + ["table_0"], + ShardingType.COLUMN_WISE.value, + ) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + def test_qebc_pruned_cw_one_ebc(self) -> None: + batch_size: int = 1 + world_size: int = 2 + local_device = torch.device("cuda:0") + num_embedding = 200 + emb_dim = 512 + sharding_type = ShardingType.COLUMN_WISE + + quant_model, sharded_model, pruned_dict = create_quant_and_sharded_ebc_models( + num_embedding=num_embedding, + emb_dim=emb_dim, + world_size=world_size, + batch_size=batch_size, + sharding_type=sharding_type, + device=local_device, + ) + + kjt = KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0"], + values=torch.tensor([0, 1, 2, 59, 60, 99], dtype=torch.int32).cuda(), + lengths=torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.int32).cuda(), + weights=None, + ) + + q_output = quant_model(kjt) + s_output = sharded_model(kjt) + + assert_close(q_output["feature_0"], s_output["feature_0"]) + + self.check_tbe_pruned(sharded_model, pruned_dict) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + def test_fpqebc_pruned_tw_one_fpebc(self) -> None: + batch_size: int = 1 + world_size: int = 2 + local_device = torch.device("cuda:0") + num_embedding = 200 + emb_dim = 512 + sharding_type = ShardingType.COLUMN_WISE + + quant_model, sharded_model, pruned_dict = create_quant_and_sharded_ebc_models( + num_embedding=num_embedding, + emb_dim=emb_dim, + world_size=world_size, + batch_size=batch_size, + sharding_type=sharding_type, + device=local_device, + feature_processor=True, + ) + + kjt = KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0"], + values=torch.tensor([0, 1, 2, 59, 60, 99], dtype=torch.int32).cuda(), + 
lengths=torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.int32).cuda(), + weights=None, + ) + + q_output = quant_model(kjt) + s_output = sharded_model(kjt) + + assert_close(q_output["feature_0"], s_output["feature_0"]) + + self.check_tbe_pruned(sharded_model, pruned_dict) diff --git a/torchrec/distributed/tests/test_quant_sequence_model_parallel.py b/torchrec/distributed/tests/test_quant_sequence_model_parallel.py index 3b33a1833..ffaa937d0 100644 --- a/torchrec/distributed/tests/test_quant_sequence_model_parallel.py +++ b/torchrec/distributed/tests/test_quant_sequence_model_parallel.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest from typing import cast, List, Optional, Type @@ -15,7 +17,7 @@ from torch import nn, quantization as quant from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder -from torchrec.distributed.shard import shard +from torchrec.distributed.shard import _shard_modules from torchrec.distributed.test_utils.test_model import ModelInput, TestSparseNNBase from torchrec.distributed.test_utils.test_model_parallel_base import ( InferenceModelParallelTestBase, @@ -26,11 +28,18 @@ from torchrec.modules.embedding_modules import EmbeddingCollection from torchrec.quant.embedding_modules import ( EmbeddingCollection as QuantEmbeddingCollection, + quant_prep_enable_quant_state_dict_split_scale_bias_for_types, ) from torchrec.test_utils import seed_and_log, skip_if_asan_class -def _quantize(module: nn.Module) -> nn.Module: +def _quantize( + module: nn.Module, quant_state_dict_split_scale_bias: bool = False +) -> nn.Module: + if quant_state_dict_split_scale_bias: + quant_prep_enable_quant_state_dict_split_scale_bias_for_types( + module, [EmbeddingCollection] + ) qconfig = quant.QConfig( activation=quant.PlaceholderObserver, weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8), @@ -80,9 +89,15 @@ class QuantSequenceModelParallelTest(InferenceModelParallelTestBase): EmbeddingComputeKernel.QUANT.value, ] ), + quant_state_dict_split_scale_bias=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) - def test_sharding_nccl_tw(self, sharding_type: str, kernel_type: str) -> None: + def test_sharding_nccl_tw( + self, + sharding_type: str, + kernel_type: str, + quant_state_dict_split_scale_bias: bool, + ) -> None: self._test_sharding( sharders=[ TestQuantECSharder( @@ -91,6 +106,7 @@ def test_sharding_nccl_tw(self, sharding_type: str, kernel_type: str) -> None: ) ], backend="nccl", + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, ) @seed_and_log @@ -119,6 +135,7 @@ def _test_sharding( world_size: int = 2, local_size: Optional[int] = None, model_class: Type[TestSparseNNBase] = TestSequenceSparseNN, + quant_state_dict_split_scale_bias: bool = False, ) -> None: self._test_sharded_forward( world_size=world_size, @@ -128,6 +145,9 @@ def _test_sharding( # pyre-ignore [6] sharders=sharders, quantize_callable=_quantize, + quantize_callable_kwargs={ + "quant_state_dict_split_scale_bias": quant_state_dict_split_scale_bias + }, ) @unittest.skipIf( @@ -142,9 +162,12 @@ def _test_sharding( torch.float, ] ), + quant_state_dict_split_scale_bias=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) - def test_quant_pred_shard(self, output_type: torch.dtype) -> None: + def 
test_quant_pred_shard( + self, output_type: torch.dtype, quant_state_dict_split_scale_bias: bool + ) -> None: device = torch.device("cuda:0") # wrap in sequential because _quantize only applies to submodules... @@ -152,7 +175,7 @@ def test_quant_pred_shard(self, output_type: torch.dtype) -> None: quant_model = _quantize(model) - sharded_quant_model, _sharded_params = shard( + sharded_quant_model = _shard_modules( module=quant_model, sharders=[ cast( @@ -175,6 +198,8 @@ def test_quant_pred_shard(self, output_type: torch.dtype) -> None: num_float_features=10, tables=self.tables, weighted_tables=[], + indices_dtype=torch.int32, + lengths_dtype=torch.int32, ) local_batch = local_batch.to(device) sharded_quant_model(local_batch.idlist_features) diff --git a/torchrec/distributed/tests/test_quantize.py b/torchrec/distributed/tests/test_quantize.py deleted file mode 100644 index d6c39127f..000000000 --- a/torchrec/distributed/tests/test_quantize.py +++ /dev/null @@ -1,194 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os -import unittest - -import hypothesis.strategies as st -import torch -import torch.distributed as dist -import torch.quantization as quant -from hypothesis import given, settings, Verbosity -from torchrec.distributed.embedding_lookup import ( - BatchedDenseEmbedding, - BatchedDenseEmbeddingBag, - BatchedFusedEmbedding, - BatchedFusedEmbeddingBag, - GroupedEmbeddingsLookup, - GroupedPooledEmbeddingsLookup, -) -from torchrec.distributed.embedding_types import ( - EmbeddingComputeKernel, - GroupedEmbeddingConfig, - ShardedEmbeddingTable, - ShardedTensorMetadata, - ShardMetadata, -) -from torchrec.distributed.quant_embedding_kernel import ( - QuantBatchedEmbedding, - QuantBatchedEmbeddingBag, -) -from torchrec.modules.embedding_configs import DataType, PoolingType -from torchrec.test_utils import get_free_port - - -def quantize_sharded_embeddings( - module: torch.nn.Module, dtype: torch.dtype -) -> torch.nn.Module: - qconfig = quant.QConfigDynamic( - activation=quant.PlaceholderObserver, - weight=quant.PlaceholderObserver.with_args(dtype=dtype), - ) - return quant.quantize_dynamic( - module, - qconfig_spec={ - BatchedFusedEmbeddingBag: qconfig, - BatchedDenseEmbeddingBag: qconfig, - BatchedDenseEmbedding: qconfig, - BatchedFusedEmbedding: qconfig, - }, - mapping={ - BatchedFusedEmbeddingBag: QuantBatchedEmbeddingBag, - BatchedDenseEmbeddingBag: QuantBatchedEmbeddingBag, - BatchedDenseEmbedding: QuantBatchedEmbedding, - BatchedFusedEmbedding: QuantBatchedEmbedding, - }, - inplace=False, - ) - - -class QuantizeKernelTest(unittest.TestCase): - def setUp(self) -> None: - os.environ["RANK"] = "0" - os.environ["WORLD_SIZE"] = "1" - os.environ["LOCAL_WORLD_SIZE"] = "1" - os.environ["MASTER_ADDR"] = str("localhost") - os.environ["MASTER_PORT"] = str(get_free_port()) - os.environ["NCCL_SOCKET_IFNAME"] = "lo" - self.device = torch.device("cuda:0") - backend = "nccl" - torch.cuda.set_device(self.device) - dist.init_process_group(backend=backend) - - def tearDown(self) -> None: - dist.destroy_process_group() - del os.environ["NCCL_SOCKET_IFNAME"] - super().tearDown() - - def _create_config( - self, compute_kernel: EmbeddingComputeKernel - ) -> GroupedEmbeddingConfig: - num_embedding_tables = 2 - embedding_tables = [] - for i in range(num_embedding_tables): - rows = (i + 1) * 10 - cols = 16 - 
local_metadata = ShardMetadata( - shard_offsets=[0, 0], - shard_sizes=[rows, cols], - placement=torch.distributed._remote_device("rank:0/cuda:0"), - ) - embedding_tables.append( - ShardedEmbeddingTable( - num_embeddings=rows, - embedding_dim=cols, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - pooling=PoolingType.MEAN, - is_weighted=False, - has_feature_processor=False, - local_rows=rows, - local_cols=cols, - compute_kernel=compute_kernel, - local_metadata=local_metadata, - global_metadata=ShardedTensorMetadata( - shards_metadata=[local_metadata], - size=torch.Size([rows, cols]), - ), - weight_init_max=1.0, - weight_init_min=0.0, - ) - ) - return GroupedEmbeddingConfig( - data_type=DataType.FP32, - pooling=PoolingType.MEAN, - is_weighted=False, - has_feature_processor=False, - compute_kernel=compute_kernel, - embedding_tables=embedding_tables, - ) - - @unittest.skipIf( - torch.cuda.device_count() <= 0, - "Not enough GPUs, this test requires at least one GPU", - ) - # pyre-ignore [56] - @given( - compute_kernel=st.sampled_from( - [ - EmbeddingComputeKernel.DENSE, - EmbeddingComputeKernel.FUSED, - ] - ), - dtype=st.sampled_from( - [ - torch.qint8, - torch.quint4x2, - torch.quint2x4, - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None) - def test_quantize_embedding_bag_kernels( - self, compute_kernel: EmbeddingComputeKernel, dtype: torch.dtype - ) -> None: - config = self._create_config(compute_kernel) - sharded = GroupedPooledEmbeddingsLookup( - grouped_configs=[config], - grouped_score_configs=[], - device=torch.device("cuda:0"), - ) - - quantized = quantize_sharded_embeddings(sharded, dtype=dtype) - - for _, buffer in quantized.named_buffers(): - self.assertEqual(buffer.dtype, torch.uint8) - - @unittest.skipIf( - torch.cuda.device_count() <= 0, - "Not enough GPUs, this test requires at least one GPU", - ) - # pyre-ignore [56] - @given( - compute_kernel=st.sampled_from( - [ - EmbeddingComputeKernel.DENSE, - EmbeddingComputeKernel.FUSED, - ] - ), - dtype=st.sampled_from( - [ - torch.qint8, - torch.quint4x2, - torch.quint2x4, - ] - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None) - def test_quantize_embedding_kernels( - self, compute_kernel: EmbeddingComputeKernel, dtype: torch.dtype - ) -> None: - config = self._create_config(compute_kernel) - sharded = GroupedEmbeddingsLookup( - grouped_configs=[config], - device=torch.device("cuda:0"), - ) - - quantized = quantize_sharded_embeddings(sharded, dtype=dtype) - - for _, buffer in quantized.named_buffers(): - self.assertEqual(buffer.dtype, torch.uint8) diff --git a/torchrec/distributed/tests/test_sequence_model.py b/torchrec/distributed/tests/test_sequence_model.py index bea5fbc74..f5234fbc7 100644 --- a/torchrec/distributed/tests/test_sequence_model.py +++ b/torchrec/distributed/tests/test_sequence_model.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + from typing import Any, cast, Dict, List, Optional, Tuple, Union import torch @@ -154,6 +156,7 @@ def __init__( embedding_groups: Optional[Dict[str, List[str]]] = None, dense_device: Optional[torch.device] = None, sparse_device: Optional[torch.device] = None, + feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, ) -> None: super().__init__( tables=cast(List[BaseEmbeddingConfig], tables), @@ -195,9 +198,11 @@ def __init__( ) self.over = nn.Linear( in_features=8 - # pyre-ignore [16] + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `out_features`. + self.tower_0.interaction.linear.out_features - # pyre-ignore [16] + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `out_features`. + self.tower_1.interaction.linear.out_features, out_features=16, device=dense_device, @@ -256,6 +261,7 @@ def __init__( embedding_groups: Optional[Dict[str, List[str]]] = None, dense_device: Optional[torch.device] = None, sparse_device: Optional[torch.device] = None, + feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, ) -> None: super().__init__( tables=cast(List[BaseEmbeddingConfig], tables), @@ -282,7 +288,9 @@ def forward( input: ModelInput, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: dense_r = self.dense(input.float_features) - sparse_r = self.sparse(input.idlist_features, input.float_features.size(0)) + # multiply the sparse output by 10 since the model output is not sensitive to the + # embedding output. It won't catch the unexpected embedding output without this + sparse_r = 10 * self.sparse(input.idlist_features, input.float_features.size(0)) over_r = self.over(dense_r, sparse_r) pred = torch.sigmoid(torch.mean(over_r, dim=1)) if self.training: @@ -300,7 +308,8 @@ def __init__( sharding_type: str, kernel_type: str, qcomms_config: Optional[QCommsConfig] = None, - variable_batch_size: bool = False, + fused_params: Optional[Dict[str, Any]] = None, + use_index_dedup: bool = False, ) -> None: self._sharding_type = sharding_type self._kernel_type = kernel_type @@ -309,11 +318,15 @@ def __init__( if qcomms_config is not None: qcomm_codecs_registry = get_qcomm_codecs_registry(qcomms_config) - fused_params = {"learning_rate": 0.1} + if fused_params is None: + fused_params = {} + if "learning_rate" not in fused_params: + fused_params["learning_rate"] = 0.1 + super().__init__( fused_params=fused_params, qcomm_codecs_registry=qcomm_codecs_registry, - variable_batch_size=variable_batch_size, + use_index_dedup=use_index_dedup, ) """ diff --git a/torchrec/distributed/tests/test_sequence_model_parallel.py b/torchrec/distributed/tests/test_sequence_model_parallel.py index 861d1502c..d13d819c3 100644 --- a/torchrec/distributed/tests/test_sequence_model_parallel.py +++ b/torchrec/distributed/tests/test_sequence_model_parallel.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import unittest from typing import Any, Dict, List, Optional, Tuple, Type @@ -36,11 +38,7 @@ class SequenceModelParallelTest(MultiProcessTestBase): ) # pyre-fixme[56] @given( - sharding_type=st.sampled_from( - [ - ShardingType.ROW_WISE.value, - ] - ), + sharding_type=st.just(ShardingType.ROW_WISE.value), kernel_type=st.sampled_from( [ EmbeddingComputeKernel.DENSE.value, @@ -59,14 +57,14 @@ class SequenceModelParallelTest(MultiProcessTestBase): [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] ), - variable_batch_size=st.sampled_from([True, False]), + variable_batch_size=st.booleans(), ) - @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) def test_sharding_nccl_rw( self, sharding_type: str, @@ -87,7 +85,6 @@ def test_sharding_nccl_rw( sharding_type=sharding_type, kernel_type=kernel_type, qcomms_config=qcomms_config, - variable_batch_size=variable_batch_size, ) ], backend="nccl", @@ -102,17 +99,9 @@ def test_sharding_nccl_rw( ) # pyre-fixme[56] @given( - sharding_type=st.sampled_from( - [ - ShardingType.DATA_PARALLEL.value, - ] - ), - kernel_type=st.sampled_from( - [ - EmbeddingComputeKernel.DENSE.value, - ] - ), - apply_optimizer_in_backward_config=st.sampled_from([None]), + sharding_type=st.just(ShardingType.DATA_PARALLEL.value), + kernel_type=st.just(EmbeddingComputeKernel.DENSE.value), + apply_optimizer_in_backward_config=st.just(None), # TODO - need to enable optimizer overlapped behavior for data_parallel tables # apply_optimizer_in_backward_config=st.booleans(), ) @@ -142,11 +131,7 @@ def test_sharding_nccl_dp( ) # pyre-fixme[56] @given( - sharding_type=st.sampled_from( - [ - ShardingType.TABLE_WISE.value, - ] - ), + sharding_type=st.just(ShardingType.TABLE_WISE.value), kernel_type=st.sampled_from( [ EmbeddingComputeKernel.DENSE.value, @@ -165,14 +150,14 @@ def test_sharding_nccl_dp( [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] ), - variable_batch_size=st.sampled_from([True, False]), + variable_batch_size=st.booleans(), ) - @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) def test_sharding_nccl_tw( self, sharding_type: str, @@ -193,7 +178,6 @@ def test_sharding_nccl_tw( sharding_type=sharding_type, kernel_type=kernel_type, qcomms_config=qcomms_config, - variable_batch_size=variable_batch_size, ) ], backend="nccl", @@ -208,11 +192,7 @@ def test_sharding_nccl_tw( ) # pyre-fixme[56] @given( - sharding_type=st.sampled_from( - [ - ShardingType.COLUMN_WISE.value, - ] - ), + sharding_type=st.just(ShardingType.COLUMN_WISE.value), kernel_type=st.sampled_from( [ EmbeddingComputeKernel.DENSE.value, @@ -223,14 +203,14 @@ def test_sharding_nccl_tw( [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] ), - variable_batch_size=st.sampled_from([True, False]), + variable_batch_size=st.booleans(), ) - @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) def test_sharding_nccl_cw( self, sharding_type: str, @@ -249,18 +229,79 @@ def test_sharding_nccl_cw( 
TestEmbeddingCollectionSharder( sharding_type=sharding_type, kernel_type=kernel_type, - variable_batch_size=variable_batch_size, ) ], backend="nccl", constraints={ - table.name: ParameterConstraints(min_partition=16) + table.name: ParameterConstraints(min_partition=8) for table in self.tables }, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, variable_batch_size=variable_batch_size, ) + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ] + ), + index_dedup=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) + def test_sharding_variable_batch( + self, + sharding_type: str, + index_dedup: bool, + ) -> None: + self._test_sharding( + sharders=[ + TestEmbeddingCollectionSharder( + sharding_type=sharding_type, + kernel_type=EmbeddingComputeKernel.FUSED.value, + use_index_dedup=index_dedup, + ) + ], + backend="nccl", + constraints={ + table.name: ParameterConstraints(min_partition=4) + for table in self.tables + }, + variable_batch_per_feature=True, + ) + + # pyre-fixme[56] + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_sharding_empty_rank(self) -> None: + table = self.tables[0] + embedding_groups = {"group_0": table.feature_names} + self._run_multi_process_test( + callable=sharding_single_rank_test, + world_size=2, + model_class=TestSequenceSparseNN, + tables=[table], + embedding_groups=embedding_groups, + sharders=[ + TestEmbeddingCollectionSharder( + sharding_type=ShardingType.TABLE_WISE.value, + kernel_type=EmbeddingComputeKernel.FUSED.value, + ) + ], + optim=EmbOptimType.EXACT_SGD, + backend="nccl", + variable_batch_size=True, + ) + @seed_and_log def setUp(self) -> None: super().setUp() @@ -293,9 +334,11 @@ def setUp(self) -> None: self.embedding_groups = { "group_0": [ - f"{feature}@{table.name}" - if feature in self.shared_features - else feature + ( + f"{feature}@{table.name}" + if feature in self.shared_features + else feature + ) for table in self.tables for feature in table.feature_names ] @@ -314,6 +357,47 @@ def _test_sharding( Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ] = None, variable_batch_size: bool = False, + variable_batch_per_feature: bool = False, + ) -> None: + self._run_multi_process_test( + callable=sharding_single_rank_test, + world_size=world_size, + local_size=local_size, + model_class=model_class, + tables=self.tables, + embedding_groups=self.embedding_groups, + sharders=sharders, + optim=EmbOptimType.EXACT_SGD, + backend=backend, + constraints=constraints, + qcomms_config=qcomms_config, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + variable_batch_per_feature=variable_batch_per_feature, + global_constant_batch=True, + ) + + +@skip_if_asan_class +class TDSequenceModelParallelTest(SequenceModelParallelTest): + + def test_sharding_variable_batch(self) -> None: + pass + + def _test_sharding( + self, + sharders: List[TestEmbeddingCollectionSharder], + backend: str = "gloo", + world_size: int = 2, + local_size: Optional[int] = None, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + model_class: Type[TestSparseNNBase] = TestSequenceSparseNN, + qcomms_config: Optional[QCommsConfig] = None, + 
apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ] = None, + variable_batch_size: bool = False, + variable_batch_per_feature: bool = False, ) -> None: self._run_multi_process_test( callable=sharding_single_rank_test, @@ -329,4 +413,7 @@ def _test_sharding( qcomms_config=qcomms_config, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, variable_batch_size=variable_batch_size, + variable_batch_per_feature=variable_batch_per_feature, + global_constant_batch=True, + input_type="td", ) diff --git a/torchrec/distributed/tests/test_sequence_model_parallel_hierarchical.py b/torchrec/distributed/tests/test_sequence_model_parallel_hierarchical.py index cb77fc27a..7ea296a91 100644 --- a/torchrec/distributed/tests/test_sequence_model_parallel_hierarchical.py +++ b/torchrec/distributed/tests/test_sequence_model_parallel_hierarchical.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest from typing import Any, Dict, List, Optional, Tuple, Type @@ -60,13 +62,13 @@ class SequenceModelParallelHierarchicalTest(MultiProcessTestBase): [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] ), ) - @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) def test_seq_emb_tower_nccl( self, sharding_type: str, diff --git a/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py b/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py new file mode 100644 index 000000000..26ca8c55b --- /dev/null +++ b/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + + +import unittest +from typing import cast, OrderedDict + +import hypothesis.strategies as st +import torch +from hypothesis import given, settings, Verbosity +from torch import nn +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.test_utils.test_model_parallel_base import ( + ModelParallelSingleRankBase, +) +from torchrec.distributed.tests.test_sequence_model import ( + TestEmbeddingCollectionSharder, + TestSequenceSparseNN, +) +from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.modules.embedding_configs import DataType, EmbeddingConfig + + +class SequenceModelParallelStateDictTest(ModelParallelSingleRankBase): + def setUp(self, backend: str = "nccl") -> None: + self.shared_features = [] + self.embedding_groups = {} + + super().setUp(backend=backend) + + def _create_tables(self) -> None: + num_features = 4 + shared_features = 2 + + initial_tables = [ + EmbeddingConfig( + num_embeddings=(i + 1) * 11, + embedding_dim=16, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + + shared_features_tables = [ + EmbeddingConfig( + num_embeddings=(i + 1) * 11, + embedding_dim=16, + name="table_" + str(i + num_features), + feature_names=["feature_" + str(i)], + ) + for i in range(shared_features) + ] + + self.tables += initial_tables + shared_features_tables + self.shared_features += [f"feature_{i}" for i in range(shared_features)] + + self.embedding_groups["group_0"] = [ + (f"{feature}@{table.name}" if feature in self.shared_features else feature) + for table in self.tables + for feature in table.feature_names + ] + + def _create_model(self) -> nn.Module: + return TestSequenceSparseNN( + tables=self.tables, + num_float_features=self.num_float_features, + embedding_groups=self.embedding_groups, + dense_device=self.device, + sparse_device=torch.device("meta"), + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ] + ), + is_training=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + def test_load_state_dict( + self, + sharding_type: str, + kernel_type: str, + is_training: bool, + ) -> None: + sharders = [ + cast( + ModuleSharder[nn.Module], + TestEmbeddingCollectionSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + ), + ), + ] + models, batch = self._generate_dmps_and_batch(sharders) + m1, m2 = models + + # load the second's (m2's) with the first (m1's) state_dict + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + + # validate the models are equivalent + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch) + self._compare_models(m1, m2) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + 
EmbeddingComputeKernel.FUSED_UVM.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_numerical_equivalence_between_kernel_types( + self, + sharding_type: str, + kernel_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + self._set_table_weights_precision(dtype) + fused_params = { + "stochastic_rounding": stochastic_rounding, + "cache_precision": dtype, + } + + fused_sharders = [ + cast( + ModuleSharder[nn.Module], + TestEmbeddingCollectionSharder( + sharding_type=sharding_type, + kernel_type=EmbeddingComputeKernel.FUSED.value, + fused_params=fused_params, + ), + ), + ] + sharders = [ + cast( + ModuleSharder[nn.Module], + TestEmbeddingCollectionSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + fused_params=fused_params, + ), + ), + ] + (fused_model, _), _ = self._generate_dmps_and_batch(fused_sharders) + (model, _), batch = self._generate_dmps_and_batch(sharders) + + # load the baseline model's state_dict onto the new model + model.load_state_dict( + cast("OrderedDict[str, torch.Tensor]", fused_model.state_dict()) + ) + + if is_training: + for _ in range(4): + self._train_models(fused_model, model, batch) + # the problem here is output is FP32, but weights are FP16 + # so we should actually use FP16 atol and rtol to check close + if not is_training or not stochastic_rounding: + self._eval_models( + fused_model, model, batch, is_deterministic=not stochastic_rounding + ) + self._compare_models( + fused_model, model, is_deterministic=not stochastic_rounding + ) diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py index 291039ef0..5dc18885a 100644 --- a/torchrec/distributed/tests/test_sharding_plan.py +++ b/torchrec/distributed/tests/test_sharding_plan.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import copy import unittest from typing import Any, Dict, List, Optional @@ -13,12 +15,24 @@ import torch from hypothesis import given, settings, Verbosity from torchrec import distributed as trec_dist +from torchrec.distributed.quant_embedding import ( + QuantManagedCollisionEmbeddingCollectionSharder, +) from torchrec.distributed.sharding_plan import ( column_wise, construct_module_sharding_plan, data_parallel, + EmbeddingBagCollectionSharder, + EmbeddingCollectionSharder, + FeatureProcessedEmbeddingBagCollectionSharder, + FusedEmbeddingBagCollectionSharder, get_module_to_default_sharders, + grid_shard, + ManagedCollisionEmbeddingBagCollectionSharder, + ManagedCollisionEmbeddingCollectionSharder, ParameterShardingGenerator, + QuantEmbeddingBagCollectionSharder, + QuantEmbeddingCollectionSharder, row_wise, table_row_wise, table_wise, @@ -30,18 +44,34 @@ ) from torchrec.distributed.test_utils.test_sharding import copy_state_dict from torchrec.distributed.types import ( + EmbeddingModuleShardingPlan, EnumerableShardingSpec, - ModuleShardingPlan, ParameterSharding, ShardingEnv, + ShardingPlan, ShardingType, ShardMetadata, ) -from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.embedding_configs import data_type_to_dtype, EmbeddingBagConfig +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) +from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection +from torchrec.modules.fused_embedding_modules import FusedEmbeddingBagCollection +from torchrec.modules.mc_embedding_modules import ( + ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, +) +from torchrec.quant.embedding_modules import ( + EmbeddingBagCollection as QuantEmbeddingBagCollection, + EmbeddingCollection as QuantEmbeddingCollection, + QuantManagedCollisionEmbeddingCollection, +) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.test_utils import skip_if_asan_class +from torchrec.types import DataType def _test_sharding( @@ -51,7 +81,7 @@ def _test_sharding( world_size: int, kjt_input_per_rank: List[KeyedJaggedTensor], backend: str, - module_sharding_plan: ModuleShardingPlan, + module_sharding_plan: EmbeddingModuleShardingPlan, local_size: Optional[int] = None, ) -> None: trec_dist.comm_ops.set_gradient_division(False) @@ -72,6 +102,8 @@ def _test_sharding( sharded_model = sharder.shard( module=model, params=module_sharding_plan, + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. 
env=ShardingEnv.from_process_group(ctx.pg), device=ctx.device, ) @@ -92,11 +124,6 @@ def _test_sharding( @skip_if_asan_class class ConstructParameterShardingAndShardTest(MultiProcessTestBase): - # TODO: Remove GPU check after T136512190 is fixed - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) # pyre-fixme[56] @given( per_param_sharding=st.sampled_from( @@ -119,12 +146,15 @@ class ConstructParameterShardingAndShardTest(MultiProcessTestBase): }, ] ), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), ) @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) def test_parameter_sharding_ebc( self, per_param_sharding: Dict[str, ParameterShardingGenerator], + data_type: DataType, ) -> None: + WORLD_SIZE = 2 embedding_bag_config = [ @@ -133,12 +163,14 @@ def test_parameter_sharding_ebc( feature_names=["feature_0"], embedding_dim=16, num_embeddings=4, + data_type=data_type, ), EmbeddingBagConfig( name="table_1", feature_names=["feature_1"], embedding_dim=16, num_embeddings=4, + data_type=data_type, ), ] @@ -187,21 +219,23 @@ def test_parameter_sharding_ebc( world_size=WORLD_SIZE, tables=embedding_bag_config, initial_state_dict={ - "embedding_bags.table_0.weight": torch.Tensor( + "embedding_bags.table_0.weight": torch.tensor( [ [1] * 16, [2] * 16, [3] * 16, [4] * 16, - ] + ], + dtype=data_type_to_dtype(data_type), ), - "embedding_bags.table_1.weight": torch.Tensor( + "embedding_bags.table_1.weight": torch.tensor( [ [101] * 16, [102] * 16, [103] * 16, [104] * 16, - ] + ], + dtype=data_type_to_dtype(data_type), ), }, kjt_input_per_rank=kjt_input_per_rank, @@ -211,15 +245,20 @@ def test_parameter_sharding_ebc( class ConstructParameterShardingTest(unittest.TestCase): - def test_construct_module_sharding_plan(self) -> None: + # pyre-fixme[56] + @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16])) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_construct_module_sharding_plan(self, data_type: DataType) -> None: + embedding_bag_config = [ EmbeddingBagConfig( name=f"table_{idx}", feature_names=[f"feature_{idx}"], embedding_dim=256, num_embeddings=32 * 32, + data_type=data_type, ) - for idx in range(5) + for idx in range(6) ] expected = { @@ -546,6 +585,95 @@ def test_construct_module_sharding_plan(self) -> None: ] ), ), + "table_5": ParameterSharding( + sharding_type="grid_shard", + compute_kernel="dense", + ranks=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[128, 128], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_offsets=[128, 0], + shard_sizes=[128, 128], + placement="rank:1/cuda:1", + ), + ShardMetadata( + shard_offsets=[256, 0], + shard_sizes=[128, 128], + placement="rank:2/cuda:2", + ), + ShardMetadata( + shard_offsets=[384, 0], + shard_sizes=[128, 128], + placement="rank:3/cuda:3", + ), + ShardMetadata( + shard_offsets=[512, 0], + shard_sizes=[128, 128], + placement="rank:4/cuda:4", + ), + ShardMetadata( + shard_offsets=[640, 0], + shard_sizes=[128, 128], + placement="rank:5/cuda:5", + ), + ShardMetadata( + shard_offsets=[768, 0], + shard_sizes=[128, 128], + placement="rank:6/cuda:6", + ), + ShardMetadata( + shard_offsets=[896, 0], + shard_sizes=[128, 128], + placement="rank:7/cuda:7", + ), + ShardMetadata( + shard_offsets=[0, 128], + shard_sizes=[128, 128], + placement="rank:8/cuda:0", + ), + ShardMetadata( + shard_offsets=[128, 
128], + shard_sizes=[128, 128], + placement="rank:9/cuda:1", + ), + ShardMetadata( + shard_offsets=[256, 128], + shard_sizes=[128, 128], + placement="rank:10/cuda:2", + ), + ShardMetadata( + shard_offsets=[384, 128], + shard_sizes=[128, 128], + placement="rank:11/cuda:3", + ), + ShardMetadata( + shard_offsets=[512, 128], + shard_sizes=[128, 128], + placement="rank:12/cuda:4", + ), + ShardMetadata( + shard_offsets=[640, 128], + shard_sizes=[128, 128], + placement="rank:13/cuda:5", + ), + ShardMetadata( + shard_offsets=[768, 128], + shard_sizes=[128, 128], + placement="rank:14/cuda:6", + ), + ShardMetadata( + shard_offsets=[896, 128], + shard_sizes=[128, 128], + placement="rank:15/cuda:7", + ), + ] + ), + ), } module_sharding_plan = construct_module_sharding_plan( @@ -556,19 +684,150 @@ def test_construct_module_sharding_plan(self) -> None: "table_2": row_wise(), "table_3": column_wise(ranks=[8, 9]), "table_4": table_row_wise(host_index=3), + "table_5": grid_shard(host_indexes=[0, 1]), }, local_size=8, world_size=32, + device_type="cuda", ) self.assertDictEqual(expected, module_sharding_plan) - def test_column_wise(self) -> None: + # pyre-fixme[56] + @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16])) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_table_wise_set_device(self, data_type: DataType) -> None: + embedding_bag_config = [ EmbeddingBagConfig( name=f"table_{idx}", feature_names=[f"feature_{idx}"], embedding_dim=64, num_embeddings=4096, + data_type=data_type, + ) + for idx in range(2) + ] + module_sharding_plan = construct_module_sharding_plan( + EmbeddingBagCollection(tables=embedding_bag_config), + per_param_sharding={ + "table_0": table_wise(rank=0, device="cpu"), + "table_1": table_wise(rank=1, device="cpu"), + }, + local_size=2, + world_size=2, + device_type="cuda", + ) + + # Make sure per_param_sharding setting override the default device_type + self.assertEqual( + # pyre-ignore[16] + module_sharding_plan["table_0"] + .sharding_spec.shards[0] + .placement.device() + .type, + "cpu", + ) + + self.assertEqual( + module_sharding_plan["table_1"] + .sharding_spec.shards[0] + .placement.device() + .type, + "cpu", + ) + + # pyre-fixme[56] + @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16])) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_row_wise_set_heterogenous_device(self, data_type: DataType) -> None: + + embedding_bag_config = [ + EmbeddingBagConfig( + name=f"table_{idx}", + feature_names=[f"feature_{idx}"], + embedding_dim=64, + num_embeddings=4096, + data_type=data_type, + ) + for idx in range(2) + ] + module_sharding_plan = construct_module_sharding_plan( + EmbeddingBagCollection(tables=embedding_bag_config), + per_param_sharding={ + "table_0": row_wise( + sizes_placement=( + [2048, 1024, 1024], + ["cpu", "cuda", "cuda"], + ) + ), + "table_1": row_wise( + sizes_placement=([2048, 1024, 1024], ["cpu", "cpu", "cpu"]) + ), + }, + local_size=1, + world_size=2, + device_type="cuda", + ) + + # Make sure per_param_sharding setting override the default device_type + device_table_0_shard_0 = ( + # pyre-ignore[16] + module_sharding_plan["table_0"] + .sharding_spec.shards[0] + .placement + ) + self.assertEqual( + device_table_0_shard_0.device().type, + "cpu", + ) + # cpu always has rank 0 + self.assertEqual( + device_table_0_shard_0.rank(), + 0, + ) + for i in range(1, 3): + device_table_0_shard_i = ( + module_sharding_plan["table_0"].sharding_spec.shards[i].placement + ) + self.assertEqual( + 
device_table_0_shard_i.device().type, + "cuda", + ) + # first rank is assigned to cpu so index = rank - 1 + self.assertEqual( + device_table_0_shard_i.device().index, + i - 1, + ) + self.assertEqual( + device_table_0_shard_i.rank(), + i, + ) + for i in range(3): + device_table_1_shard_i = ( + module_sharding_plan["table_1"].sharding_spec.shards[i].placement + ) + self.assertEqual( + device_table_1_shard_i.device().type, + "cpu", + ) + # cpu always has rank 0 + self.assertEqual( + device_table_1_shard_i.rank(), + 0, + ) + + # pyre-fixme[56] + @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16])) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_column_wise(self, data_type: DataType) -> None: + + embedding_bag_config = [ + EmbeddingBagConfig( + name=f"table_{idx}", + feature_names=[f"feature_{idx}"], + embedding_dim=64, + num_embeddings=4096, + data_type=data_type, ) for idx in range(2) ] @@ -580,6 +839,7 @@ def test_column_wise(self) -> None: }, local_size=2, world_size=2, + device_type="cuda", ) expected = { "table_0": ParameterSharding( @@ -622,3 +882,119 @@ def test_column_wise(self) -> None: ), } self.assertDictEqual(expected, module_sharding_plan) + + +class ShardingPlanTest(unittest.TestCase): + def test_str(self) -> None: + plan = ShardingPlan( + { + "ebc": EmbeddingModuleShardingPlan( + { + "user_id": ParameterSharding( + sharding_type="table_wise", + compute_kernel="dense", + ranks=[0], + sharding_spec=EnumerableShardingSpec( + [ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[4096, 32], + placement="rank:0/cuda:0", + ), + ] + ), + ), + "movie_id": ParameterSharding( + sharding_type="row_wise", + compute_kernel="dense", + ranks=[0, 1], + sharding_spec=EnumerableShardingSpec( + [ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[2048, 32], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_offsets=[2048, 0], + shard_sizes=[2048, 32], + placement="rank:0/cuda:1", + ), + ] + ), + ), + } + ) + } + ) + expected = """module: ebc + + param | sharding type | compute kernel | ranks +-------- | ------------- | -------------- | ------ +user_id | table_wise | dense | [0] +movie_id | row_wise | dense | [0, 1] + + param | shard offsets | shard sizes | placement +-------- | ------------- | ----------- | ------------- +user_id | [0, 0] | [4096, 32] | rank:0/cuda:0 +movie_id | [0, 0] | [2048, 32] | rank:0/cuda:0 +movie_id | [2048, 0] | [2048, 32] | rank:0/cuda:1 +""" + self.maxDiff = None + for i in range(len(expected.splitlines())): + self.assertEqual( + expected.splitlines()[i].strip(), str(plan).splitlines()[i].strip() + ) + + def test_module_to_default_sharders(self) -> None: + default_sharder_map = get_module_to_default_sharders() + self.assertCountEqual( + default_sharder_map, + [ + EmbeddingBagCollection, + FeatureProcessedEmbeddingBagCollection, + EmbeddingCollection, + FusedEmbeddingBagCollection, + QuantEmbeddingBagCollection, + QuantEmbeddingCollection, + ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, + QuantManagedCollisionEmbeddingCollection, + ], + ) + self.assertIsInstance( + default_sharder_map[EmbeddingBagCollection], EmbeddingBagCollectionSharder + ) + self.assertIsInstance( + default_sharder_map[FeatureProcessedEmbeddingBagCollection], + FeatureProcessedEmbeddingBagCollectionSharder, + ) + self.assertIsInstance( + default_sharder_map[EmbeddingCollection], EmbeddingCollectionSharder + ) + self.assertIsInstance( + default_sharder_map[FusedEmbeddingBagCollection], + 
FusedEmbeddingBagCollectionSharder, + ) + self.assertIsInstance( + default_sharder_map[QuantEmbeddingBagCollection], + QuantEmbeddingBagCollectionSharder, + ) + self.assertIsInstance( + default_sharder_map[QuantEmbeddingCollection], + QuantEmbeddingCollectionSharder, + ) + self.assertIsInstance( + default_sharder_map[ManagedCollisionEmbeddingBagCollection], + ManagedCollisionEmbeddingBagCollectionSharder, + ) + + self.assertIsInstance( + default_sharder_map[ManagedCollisionEmbeddingCollection], + ManagedCollisionEmbeddingCollectionSharder, + ) + + self.assertIsInstance( + default_sharder_map[QuantManagedCollisionEmbeddingCollection], + QuantManagedCollisionEmbeddingCollectionSharder, + ) diff --git a/torchrec/distributed/tests/test_shards_wrapper.py b/torchrec/distributed/tests/test_shards_wrapper.py new file mode 100644 index 000000000..7199552dd --- /dev/null +++ b/torchrec/distributed/tests/test_shards_wrapper.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import List, Optional, Union + +import torch +from torch import distributed as dist +from torchrec.distributed.shards_wrapper import LocalShardsWrapper +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.test_utils import seed_and_log, skip_if_asan_class + + +def all_gather_into_tensor( + rank: int, + world_size: int, + backend: str, + expected_result: Union[torch.Tensor, List[torch.Tensor]], + shards_wrapper: List[LocalShardsWrapper], + local_size: Optional[int] = None, + async_op: bool = False, +) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + local_shards_wrapper = shards_wrapper[ctx.rank] + output_tensor = torch.empty((8, 5), device=torch.device(f"cuda:{ctx.rank}")) + res = dist.all_gather_into_tensor( + output_tensor, local_shards_wrapper, group=ctx.pg, async_op=async_op + ) + if async_op: + res.wait() + torch.testing.assert_close( + output_tensor.cpu(), + expected_result, + ) + + +def all_gather( + rank: int, + world_size: int, + backend: str, + expected_result: Union[torch.Tensor, List[torch.Tensor]], + shards_wrapper: List[LocalShardsWrapper], + local_size: Optional[int] = None, + async_op: bool = False, +) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + local_shards_wrapper = shards_wrapper[ctx.rank] + tensor_list = [ + torch.zeros((4, 5), dtype=torch.float32, device=f"cuda:{rank}") + for _ in range(2) + ] + res = dist.distributed_c10d.all_gather( + tensor_list, + local_shards_wrapper, + async_op=True, + ) + if async_op: + res.wait() + for tensor, expected in zip(tensor_list, expected_result): + torch.testing.assert_close( + tensor.cpu(), + expected.cpu(), + ) + + +def all_gather_object( + rank: int, + world_size: int, + backend: str, + expected_result: Union[torch.Tensor, List[torch.Tensor]], + shards_wrapper: List[LocalShardsWrapper], + local_size: Optional[int] = None, +) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + local_shards_wrapper = shards_wrapper[ctx.rank] + output = [None] * world_size + dist.distributed_c10d.all_gather_object( + output, + local_shards_wrapper, + ) + for i in range(world_size): + torch.testing.assert_close( + output[i]._local_shards[0], # pyre-ignore[16] + 
shards_wrapper[i]._local_shards[0], + ) + + +@skip_if_asan_class +class LocalShardsWrapperDistributedTest(MultiProcessTestBase): + @seed_and_log + def setUp(self, backend: str = "nccl") -> None: + super().setUp() + + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + @unittest.skip("Need to fix circular import errors with Torch") + def test_shards_wrapper_all_gather_into_tensor(self) -> None: + world_size = 2 + backend = "nccl" + shards_0 = [torch.rand((4, 5), device=torch.device("cuda:0"))] + shards_1 = [torch.rand((4, 5), device=torch.device("cuda:1"))] + expected_result = torch.cat( + [torch.cat(shards_0, dim=0).cpu(), torch.cat(shards_1, dim=0).cpu()], dim=0 + ) + offsets = [(0, 0)] + + # shards wrapper for rank 0 and rank 1, offsets don't matter + ls_0 = LocalShardsWrapper(local_shards=shards_0, local_offsets=offsets) + ls_1 = LocalShardsWrapper(local_shards=shards_1, local_offsets=offsets) + + self._run_multi_process_test( + callable=all_gather_into_tensor, + shards_wrapper=[ + ls_0, + ls_1, + ], + expected_result=expected_result, + world_size=world_size, + backend=backend, + ) + + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + @unittest.skip("Need to fix circular import errors with Torch") + def test_shards_wrapper_all_gather(self) -> None: + world_size = 2 + backend = "nccl" + shards_0 = [torch.rand((4, 5), device=torch.device("cuda:0"))] + shards_1 = [torch.zeros((4, 5), device=torch.device("cuda:1"))] + expected_result = [shards_0[0], shards_1[0]] + offsets = [(0, 0)] + + # shards wrapper for rank 0 and rank 1, offsets don't matter + ls_0 = LocalShardsWrapper(local_shards=shards_0, local_offsets=offsets) + ls_1 = LocalShardsWrapper(local_shards=shards_1, local_offsets=offsets) + + self._run_multi_process_test( + callable=all_gather, + shards_wrapper=[ + ls_0, + ls_1, + ], + expected_result=expected_result, + world_size=world_size, + backend=backend, + ) + + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + @unittest.skip("Need to fix circular import errors with Torch") + def test_shards_wrapper_all_gather_object(self) -> None: + world_size = 2 + backend = "nccl" + shards_0 = [torch.rand((4, 5), device=torch.device("cuda:0"))] + shards_1 = [torch.zeros((4, 5), device=torch.device("cuda:1"))] + expected_result = [shards_0[0], shards_1[0]] + offsets = [(0, 0)] + + # shards wrapper for rank 0 and rank 1, offsets don't matter + ls_0 = LocalShardsWrapper(local_shards=shards_0, local_offsets=offsets) + ls_1 = LocalShardsWrapper(local_shards=shards_1, local_offsets=offsets) + + self._run_multi_process_test( + callable=all_gather_object, + shards_wrapper=[ + ls_0, + ls_1, + ], + expected_result=expected_result, + world_size=world_size, + backend=backend, + ) diff --git a/torchrec/distributed/tests/test_tensor_pool.py b/torchrec/distributed/tests/test_tensor_pool.py new file mode 100644 index 000000000..24aeb030d --- /dev/null +++ b/torchrec/distributed/tests/test_tensor_pool.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
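+#
+# Distributed tests for TensorPool sharded row-wise via TensorPoolSharder:
+# cross-rank update/lookup round trips, conflicting updates from different
+# ranks, and gathering the sharded "_pool" state_dict back into a full tensor.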
+ +# pyre-strict + +import unittest + +import torch +from hypothesis import given, settings, strategies as st +from torchrec.distributed.tensor_pool import TensorPoolSharder + +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.types import ( + ObjectPoolShardingPlan, + ObjectPoolShardingType, + ShardedTensor, + ShardingEnv, +) +from torchrec.modules.tensor_pool import TensorPool + + +class TestShardedTensorPool(MultiProcessTestBase): + @staticmethod + def _test_sharded_tensor_pool( + rank: int, world_size: int, enable_uvm: bool = False + ) -> None: + + pool_size = 5 + dim = 4 + backend = "nccl" + dtype = torch.float32 + sharding_plan = ObjectPoolShardingPlan( + sharding_type=ObjectPoolShardingType.ROW_WISE + ) + with MultiProcessContext( + rank, world_size, backend, local_size=world_size + ) as ctx: + torch.use_deterministic_algorithms(False) + tensor_pool = TensorPool( + pool_size=pool_size, + dim=dim, + dtype=dtype, + enable_uvm=enable_uvm, + ) + + sharded_tensor_pool = TensorPoolSharder().shard( + module=tensor_pool, + plan=sharding_plan, + device=ctx.device, + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + env=ShardingEnv.from_process_group(ctx.pg), + ) + + if ctx.rank == 0: + ids = [4, 1] + values = [[0.1, 0.2, 0.3, 0.4], [0.4, 0.5, 0.6, 0.7]] + + else: + ids = [3, 0] + values = [[0.11, 0.21, 0.31, 1.0], [0.41, 0.51, 0.61, 2.0]] + + ids = torch.tensor(ids, dtype=torch.int, device=ctx.device) + values = torch.tensor(values, dtype=torch.float, device=ctx.device) + + sharded_tensor_pool.update( + ids=ids, + values=values, + ) + + lookup_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int, device=ctx.device) + + values = sharded_tensor_pool.lookup(ids=lookup_ids).wait() + torch.testing.assert_close( + values.cpu(), + torch.tensor( + [ + [0.41, 0.51, 0.61, 2.0], + [0.4, 0.5, 0.6, 0.7], + [0.0, 0.0, 0.0, 0.0], + [0.11, 0.21, 0.31, 1.0], + ], + device=torch.device("cpu"), + ), + ) + + state_dict = sharded_tensor_pool.state_dict() + ut = unittest.TestCase() + ut.assertIn("_pool", state_dict) + sharded_pool_state = state_dict["_pool"] + ut.assertIsInstance(sharded_pool_state, ShardedTensor) + pool_state = ( + torch.empty(size=sharded_pool_state.size(), device=ctx.device) + if ctx.rank == 0 + else None + ) + sharded_pool_state.gather(out=pool_state) + if ctx.rank == 0: + torch.testing.assert_close( + pool_state, + torch.tensor( + [ + [0.41, 0.51, 0.61, 2.0], + [0.4, 0.5, 0.6, 0.7], + [0.0, 0.0, 0.0, 0.0], + [0.11, 0.21, 0.31, 1.0], + [0.1000, 0.2000, 0.3000, 0.4000], + ], + device=ctx.device, + ), + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given( + enable_uvm=st.booleans(), + ) + @settings(deadline=None) + def test_sharded_tensor_pool(self, enable_uvm: bool) -> None: + world_size = 2 + self._run_multi_process_test( + callable=self._test_sharded_tensor_pool, + world_size=world_size, + enable_uvm=enable_uvm, + ) + + @staticmethod + def _test_sharded_tensor_pool_conflict_update( + rank: int, + world_size: int, + ) -> None: + + pool_size = 5 + dim = 3 + backend = "nccl" + dtype = torch.float32 + sharding_plan = ObjectPoolShardingPlan( + sharding_type=ObjectPoolShardingType.ROW_WISE + ) + with MultiProcessContext( + rank, world_size, backend, local_size=world_size + ) as ctx: + torch.use_deterministic_algorithms(False) + tensor_pool = TensorPool( + pool_size=pool_size, 
+ dim=dim, + dtype=dtype, + ) + + sharded_tensor_pool = TensorPoolSharder().shard( + module=tensor_pool, + plan=sharding_plan, + device=ctx.device, + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + env=ShardingEnv.from_process_group(ctx.pg), + ) + + if ctx.rank == 0: + ids = [4, 1] + values = [0.1, 0.2, 0.3], [0.4, 0.5, 0.6] + + else: + ids = [3, 1] + values = [0.11, 0.21, 0.31], [0.41, 0.51, 0.61] + + ids = torch.tensor(ids, dtype=torch.int, device=ctx.device) + values = torch.tensor(values, dtype=torch.float, device=ctx.device) + + sharded_tensor_pool.update( + ids=ids, + values=values, + ) + + lookup_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int, device=ctx.device) + + values = sharded_tensor_pool(ids=lookup_ids).wait() + torch.testing.assert_close( + values.cpu(), + torch.tensor( + [ + [0.0, 0.0, 0.0], + [0.41, 0.51, 0.61], + [0.0, 0.0, 0.0], + [0.11, 0.21, 0.31], + ], + device=torch.device("cpu"), + ), + ) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_sharded_tensor_pool_conflict_update( + self, + ) -> None: + world_size = 2 + self._run_multi_process_test( + callable=self._test_sharded_tensor_pool_conflict_update, + world_size=world_size, + ) diff --git a/torchrec/distributed/tests/test_tensor_pool_rw_sharding.py b/torchrec/distributed/tests/test_tensor_pool_rw_sharding.py new file mode 100644 index 000000000..0a6992c0d --- /dev/null +++ b/torchrec/distributed/tests/test_tensor_pool_rw_sharding.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +import torch +from torchrec.distributed.sharding.rw_tensor_pool_sharding import TensorPoolRwSharding +from torchrec.distributed.tensor_sharding import TensorPoolRwShardingContext +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.types import ShardingEnv + + +class TestTensorPoolRwSharding(MultiProcessTestBase): + @staticmethod + def _test_update( + rank: int, + world_size: int, + ) -> None: + backend = "nccl" + dtype = torch.float32 + with MultiProcessContext( + rank, world_size, backend, local_size=world_size + ) as ctx: + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. 
+ sharding_env = ShardingEnv.from_process_group(ctx.pg) + if ctx.rank == 0: + ids = [4, 1] + values = [0.1, 0.2, 0.3], [0.4, 0.5, 0.6] + + else: + ids = [3, 0] + values = [0.11, 0.21, 0.31], [0.41, 0.51, 0.61] + + ids = torch.tensor(ids, dtype=torch.int, device=ctx.device) + values = torch.tensor(values, dtype=torch.float, device=ctx.device) + + block_size = torch.tensor([3], dtype=torch.int, device=ctx.device) + update_ctx = TensorPoolRwShardingContext(block_size=block_size) + rw_sharding = TensorPoolRwSharding( + env=sharding_env, device=ctx.device, dim=3, pool_size=4 + ) + input_dist = rw_sharding.create_lookup_ids_dist() + update_values_dist = rw_sharding.create_update_values_dist() + dist_ids = input_dist(ctx=update_ctx, ids=ids).wait().wait() + + torch.testing.assert_close( + dist_ids.cpu(), + torch.tensor( + [1, 0], + device=torch.device("cpu"), + dtype=torch.int, + ), + ) + + dist_values = update_values_dist(ctx=update_ctx, values=values).wait() + if rank == 0: + torch.testing.assert_close( + dist_values.cpu(), + torch.tensor( + [[0.4, 0.5, 0.6], [0.41, 0.51, 0.61]], + device=torch.device("cpu"), + dtype=dtype, + ), + ) + else: + torch.testing.assert_close( + dist_values.cpu(), + torch.tensor( + [[0.1, 0.2, 0.3], [0.11, 0.21, 0.31]], + device=torch.device("cpu"), + dtype=dtype, + ), + ) + + @staticmethod + def _test_lookup( + rank: int, + world_size: int, + ) -> None: + backend = "nccl" + dtype = torch.float32 + with MultiProcessContext( + rank, world_size, backend, local_size=world_size + ) as ctx: + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + sharding_env = ShardingEnv.from_process_group(ctx.pg) + + block_size = torch.tensor([3], dtype=torch.int, device=ctx.device) + lookup_ctx = TensorPoolRwShardingContext(block_size=block_size) + rw_sharding = TensorPoolRwSharding( + env=sharding_env, device=ctx.device, dim=3, pool_size=5 + ) + input_dist = rw_sharding.create_lookup_ids_dist() + lookup_values_dist = rw_sharding.create_lookup_values_dist() + + ids = torch.tensor([0, 1, 2, 3], dtype=torch.int, device=ctx.device) + dist_ids = input_dist(ctx=lookup_ctx, ids=ids).wait().wait() + if rank == 0: + torch.testing.assert_close( + dist_ids.cpu(), + torch.tensor( + [0, 1, 2, 0, 1, 2], + dtype=torch.int, + device=torch.device("cpu"), + ), + ) + else: + torch.testing.assert_close( + dist_ids.cpu(), + torch.tensor( + [0, 0], + dtype=torch.int, + device=torch.device("cpu"), + ), + ) + + # assume the _local_pool on rank 0 is + # [ + # [0.41, 0.51, 0.61], + # [0.4, 0.5, 0.6], + # [0.0, 0.0, 0.0], + # ] + + # on rank 1 is + # [ + # [0.11, 0.21, 0.31], + # [0.1, 0.2, 0.3], + # ] + + if rank == 0: + lookup_values = torch.tensor( + [ + [0.41, 0.51, 0.61], + [0.4, 0.5, 0.6], + [0.0, 0.0, 0.0], + [0.41, 0.51, 0.61], + [0.4, 0.5, 0.6], + [0.0, 0.0, 0.0], + ], + dtype=dtype, + device=ctx.device, + ) + + else: + lookup_values = torch.tensor( + [ + [0.11, 0.21, 0.31], + [0.11, 0.21, 0.31], + ], + dtype=dtype, + device=ctx.device, + ) + + dist_output_values = lookup_values_dist( + ctx=lookup_ctx, values=lookup_values + ).wait() + + torch.testing.assert_close( + dist_output_values.cpu(), + torch.tensor( + [ + [0.41, 0.51, 0.61], + [0.4, 0.5, 0.6], + [0.0, 0.0, 0.0], + [0.11, 0.21, 0.31], + ], + device=torch.device("cpu"), + ), + ) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_update( + self, + ) -> None: + 
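+        # Launches two processes (one per GPU) and runs _test_update under an NCCL process group.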
world_size = 2 + self._run_multi_process_test(callable=self._test_update, world_size=world_size) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_lookup( + self, + ) -> None: + world_size = 2 + self._run_multi_process_test(callable=self._test_lookup, world_size=world_size) diff --git a/torchrec/distributed/tests/test_train_pipeline.py b/torchrec/distributed/tests/test_train_pipeline.py deleted file mode 100644 index 2f15d2114..000000000 --- a/torchrec/distributed/tests/test_train_pipeline.py +++ /dev/null @@ -1,371 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os -import unittest -from dataclasses import dataclass -from typing import cast, Dict, List, Optional, Tuple - -import torch -import torch.distributed as dist -from torch import nn, optim -from torchrec.distributed import DistributedModelParallel -from torchrec.distributed.embedding_types import ( - EmbeddingComputeKernel, - SparseFeaturesList, -) -from torchrec.distributed.embeddingbag import ( - EmbeddingBagCollectionContext, - EmbeddingBagCollectionSharder, - ShardedEmbeddingBagCollection, -) -from torchrec.distributed.test_utils.test_model import ( - ModelInput, - TestEBCSharder, - TestSparseNN, -) -from torchrec.distributed.test_utils.test_sharding import copy_state_dict -from torchrec.distributed.train_pipeline import ( - TrainPipelineBase, - TrainPipelineSparseDist, -) -from torchrec.distributed.types import ( - Awaitable, - ModuleSharder, - ParameterSharding, - ShardingEnv, - ShardingType, -) -from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.modules.embedding_modules import EmbeddingBagCollection - -from torchrec.optim.keyed import KeyedOptimizerWrapper -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor -from torchrec.streamable import Pipelineable -from torchrec.test_utils import get_free_port, init_distributed_single_host - - -class TestShardedEmbeddingBagCollection(ShardedEmbeddingBagCollection): - def input_dist( - self, - ctx: EmbeddingBagCollectionContext, - features: KeyedJaggedTensor, - ) -> Awaitable[SparseFeaturesList]: - return super().input_dist(ctx, features) - - -class TestCustomEBCSharder(EmbeddingBagCollectionSharder): - def shard( - self, - module: EmbeddingBagCollection, - params: Dict[str, ParameterSharding], - env: ShardingEnv, - device: Optional[torch.device] = None, - ) -> TestShardedEmbeddingBagCollection: - return TestShardedEmbeddingBagCollection( - module, params, env, self.fused_params, device - ) - - def sharding_types(self, compute_device_type: str) -> List[str]: - return [ - ShardingType.ROW_WISE.value, - ] - - def compute_kernels( - self, sharding_type: str, compute_device_type: str - ) -> List[str]: - return [EmbeddingComputeKernel.DENSE.value] - - -@dataclass -class ModelInputSimple(Pipelineable): - float_features: torch.Tensor - label: torch.Tensor - - def to(self, device: torch.device, non_blocking: bool) -> "ModelInputSimple": - return ModelInputSimple( - float_features=self.float_features.to( - device=device, non_blocking=non_blocking - ), - label=self.label.to(device=device, non_blocking=non_blocking), - ) - - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - # pyre-fixme[6]: 
For 1st param expected `Stream` but got `Stream`. - self.float_features.record_stream(stream) - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. - self.label.record_stream(stream) - - -class TestModule(nn.Module): - def __init__(self) -> None: - super().__init__() - self.model = nn.Linear(10, 1) - self.loss_fn = nn.BCEWithLogitsLoss() - - def forward( - self, model_input: ModelInputSimple - ) -> Tuple[torch.Tensor, torch.Tensor]: - pred = self.model(model_input.float_features) - loss = self.loss_fn(pred, model_input.label) - return (loss, pred) - - -class TrainPipelineBaseTest(unittest.TestCase): - def setUp(self) -> None: - self.device = torch.device("cuda:0") - torch.backends.cudnn.allow_tf32 = False - torch.backends.cuda.matmul.allow_tf32 = False - - # pyre-fixme[56]: Pyre was not able to infer the type of argument - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) - def test_equal_to_non_pipelined(self) -> None: - model_cpu = TestModule() - model_gpu = TestModule().to(self.device) - model_gpu.load_state_dict(model_cpu.state_dict()) - optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01) - optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01) - data = [ - ModelInputSimple( - float_features=torch.rand((10,)), - label=torch.randint(2, (1,), dtype=torch.float32), - ) - for b in range(5) - ] - dataloader = iter(data) - pipeline = TrainPipelineBase(model_gpu, optimizer_gpu, self.device) - - for example in data[:-1]: - optimizer_cpu.zero_grad() - loss, pred = model_cpu(example) - loss.backward() - optimizer_cpu.step() - - pred_gpu = pipeline.progress(dataloader) - - self.assertEqual(pred_gpu.device, self.device) - self.assertTrue(torch.isclose(pred_gpu.cpu(), pred)) - - -class TrainPipelineSparseDistTest(unittest.TestCase): - def setUp(self) -> None: - os.environ["MASTER_ADDR"] = str("localhost") - os.environ["MASTER_PORT"] = str(get_free_port()) - self.pg = init_distributed_single_host(backend="gloo", rank=0, world_size=1) - - num_features = 4 - num_weighted_features = 2 - - self.tables = [ - EmbeddingBagConfig( - num_embeddings=(i + 1) * 100, - embedding_dim=(i + 1) * 4, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - ) - for i in range(num_features) - ] - self.weighted_tables = [ - EmbeddingBagConfig( - num_embeddings=(i + 1) * 100, - embedding_dim=(i + 1) * 4, - name="weighted_table_" + str(i), - feature_names=["weighted_feature_" + str(i)], - ) - for i in range(num_weighted_features) - ] - - self.device = torch.device("cuda:0") - - def tearDown(self) -> None: - super().tearDown() - dist.destroy_process_group(self.pg) - - def _test_feature_processor_helper( - self, - unsharded_model: TestSparseNN, - distributed_model: DistributedModelParallel, - fp_tables: List[EmbeddingBagConfig], - ) -> None: - copy_state_dict(unsharded_model.state_dict(), distributed_model.state_dict()) - optimizer_cpu = optim.SGD(unsharded_model.parameters(), lr=0.1) - optimizer_distributed = KeyedOptimizerWrapper( - dict(distributed_model.named_parameters()), - lambda params: optim.SGD(params, lr=0.1), - ) - pipeline = TrainPipelineSparseDist( - distributed_model, optimizer_distributed, self.device - ) - - data = [ - ModelInput.generate( - tables=self.tables + fp_tables, - weighted_tables=self.weighted_tables, - batch_size=1, - world_size=1, - num_float_features=10, - )[0] - for i in range(5) - ] - dataloader = iter(data) - - for example in data[:-2]: - optimizer_cpu.zero_grad() - loss, pred = 
unsharded_model(example) - example.idlist_features._jt_dict = None - example.idscore_features._jt_dict = None - loss.backward() - optimizer_cpu.step() - pred_gpu = pipeline.progress(dataloader) - - self.assertEqual(pred_gpu.device, self.device) - self.assertEqual(pred_gpu.cpu().size(), pred.size()) - torch.testing.assert_close(pred_gpu.cpu(), pred) - self.assertEqual(len(pipeline._pipelined_modules), 3) - - # pyre-fixme[56]: Pyre was not able to infer the type of argument - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) - def test_position_weighted_feature_processor(self) -> None: - max_feature_length = 100 - table_num = 2 - fp_tables = [ - EmbeddingBagConfig( - num_embeddings=(i + 1) * 100, - embedding_dim=(i + 1) * 4, - name="fp_table_" + str(i), - feature_names=["fp_feature_" + str(i)], - need_pos=True, - ) - for i in range(table_num) - ] - # chained feature_processors, the output is only 1 feature - max_feature_lengths_list = [ - { - name: max_feature_length - for table in reversed(fp_tables) - for name in table.feature_names - } - for i in range(table_num) - ] - - unsharded_model = TestSparseNN( - tables=self.tables + fp_tables, - weighted_tables=self.weighted_tables, - dense_device=self.device, - sparse_device=torch.device("meta"), - max_feature_lengths_list=max_feature_lengths_list, - ) - distributed_model = DistributedModelParallel( - unsharded_model, - env=ShardingEnv.from_process_group(self.pg), - init_data_parallel=True, - device=self.device, - sharders=[cast(ModuleSharder[nn.Module], TestCustomEBCSharder())], - ) - test_unsharded_model = TestSparseNN( - tables=self.tables + fp_tables, - weighted_tables=self.weighted_tables, - max_feature_lengths_list=max_feature_lengths_list, - ) - self._test_feature_processor_helper( - test_unsharded_model, distributed_model, fp_tables - ) - - def _test_move_cpu_gpu_helper( - self, distributed_model: DistributedModelParallel - ) -> None: - model_cpu = TestSparseNN( - tables=self.tables, weighted_tables=self.weighted_tables - ) - optimizer_cpu = optim.SGD(model_cpu.parameters(), lr=0.1) - optimizer_distributed = KeyedOptimizerWrapper( - dict(distributed_model.named_parameters()), - lambda params: optim.SGD(params, lr=0.1), - ) - pipeline = TrainPipelineSparseDist( - distributed_model, optimizer_distributed, self.device - ) - - data = [ - ModelInput.generate( - tables=self.tables, - weighted_tables=self.weighted_tables, - batch_size=1, - world_size=1, - num_float_features=10, - )[0] - for i in range(5) - ] - dataloader = iter(data) - - for example in data[:-2]: - optimizer_cpu.zero_grad() - loss, pred = model_cpu(example) - loss.backward() - optimizer_cpu.step() - - pred_gpu = pipeline.progress(dataloader) - - self.assertEqual(pred_gpu.device, self.device) - self.assertEqual(pred_gpu.cpu().size(), pred.size()) - self.assertEqual(len(pipeline._pipelined_modules), 2) - - # pyre-fixme[56]: Pyre was not able to infer the type of argument - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) - def test_move_cpu_gpu(self) -> None: - unsharded_model = TestSparseNN( - tables=self.tables, - weighted_tables=self.weighted_tables, - dense_device=self.device, - sparse_device=torch.device("meta"), - ) - distributed_model = DistributedModelParallel( - unsharded_model, - env=ShardingEnv.from_process_group(self.pg), - init_data_parallel=False, - device=self.device, - sharders=[ - cast( - ModuleSharder[nn.Module], - TestEBCSharder( - 
sharding_type=ShardingType.TABLE_WISE.value, - kernel_type=EmbeddingComputeKernel.DENSE.value, - ), - ) - ], - ) - self._test_move_cpu_gpu_helper(distributed_model) - - # pyre-fixme[56]: Pyre was not able to infer the type of argument - @unittest.skipIf( - torch.cuda.device_count() <= 1, - "Not enough GPUs, this test requires at least two GPUs", - ) - def test_pipelining(self) -> None: - unsharded_model = TestSparseNN( - tables=self.tables, - weighted_tables=self.weighted_tables, - dense_device=self.device, - sparse_device=torch.device("meta"), - ) - distributed_model = DistributedModelParallel( - unsharded_model, - env=ShardingEnv.from_process_group(self.pg), - init_data_parallel=False, - device=self.device, - sharders=[cast(ModuleSharder[nn.Module], TestCustomEBCSharder())], - ) - self._test_move_cpu_gpu_helper(distributed_model) diff --git a/torchrec/distributed/tests/test_utils.py b/torchrec/distributed/tests/test_utils.py index bfddf9a66..d25c213e8 100644 --- a/torchrec/distributed/tests/test_utils.py +++ b/torchrec/distributed/tests/test_utils.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import itertools import math import os @@ -12,7 +14,6 @@ import unittest from typing import cast, List, Optional, Tuple -import numpy as np import torch import torch.distributed as dist from hypothesis import given, settings, strategies as st, Verbosity @@ -20,8 +21,24 @@ from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.model_parallel import DistributedModelParallel from torchrec.distributed.test_utils.test_model import TestSparseNN -from torchrec.distributed.types import ModuleSharder -from torchrec.distributed.utils import get_unsharded_module_names, merge_fused_params +from torchrec.distributed.types import ( + BoundsCheckMode, + CacheAlgorithm, + CacheParams, + DataType, + ModuleSharder, + MultiPassPrefetchConfig, + ParameterSharding, + ShardingBucketMetadata, + ShardMetadata, +) +from torchrec.distributed.utils import ( + add_params_from_parameter_sharding, + convert_to_fbgemm_types, + get_bucket_metadata_from_shard_metadata, + get_unsharded_module_names, + merge_fused_params, +) from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.sparse.test_utils import keyed_jagged_tensor_equals @@ -71,7 +88,7 @@ def test_get_unsharded_module_names(self) -> None: ], ) - np.testing.assert_array_equal( + self.assertListEqual( sorted(get_unsharded_module_names(dmp)), sorted(["_dmp_wrapped_module.over", "_dmp_wrapped_module.dense"]), ) @@ -120,7 +137,7 @@ def _compute_translated_indices_with_weights( ) -> List[Tuple[int, int]]: translated_indices_with_weights = [(0, 0)] * len(row_indices) - translated_indices_offsets = np.cumsum([0] + translated_lengths) + translated_indices_offsets = list(itertools.accumulate([0] + translated_lengths)) batch_size = int(lengths_size / len(block_sizes)) iteration = feature_offset = batch_iteration = 0 for start_offset, end_offset in zip(indices_offsets, indices_offsets[1:]): @@ -162,6 +179,7 @@ def block_bucketize_ref( keyed_jagged_tensor: KeyedJaggedTensor, trainers_size: int, block_sizes: torch.Tensor, + device: str = "cuda", ) -> KeyedJaggedTensor: lengths_list = keyed_jagged_tensor.lengths().view(-1).tolist() indices_list = keyed_jagged_tensor.values().view(-1).tolist() @@ -187,7 +205,7 @@ def block_bucketize_ref( 
elements in indices_list[4:6] belongs to feature 1 batch 0 elements in indices_list[6:6] belongs to feature 1 batch 1 """ - indices_offsets = np.cumsum([0] + lengths_list) + indices_offsets = list(itertools.accumulate([0] + lengths_list)) translated_lengths = _compute_translated_lengths( row_indices=indices_list, @@ -217,28 +235,37 @@ def block_bucketize_ref( expected_keys = [ key for index in range(trainers_size) for key in keyed_jagged_tensor.keys() ] - - return KeyedJaggedTensor( - keys=expected_keys, - lengths=torch.tensor( - translated_lengths, dtype=keyed_jagged_tensor.lengths().dtype + if device == "cuda": + return KeyedJaggedTensor( + keys=expected_keys, + lengths=torch.tensor( + translated_lengths, dtype=keyed_jagged_tensor.lengths().dtype + ) + .view(-1) + .cuda(), + values=torch.tensor( + translated_indices, dtype=keyed_jagged_tensor.values().dtype + ).cuda(), + weights=( + torch.tensor(translated_weights).float().cuda() + if weights_list + else None + ), + ) + else: + return KeyedJaggedTensor( + keys=expected_keys, + lengths=torch.tensor( + translated_lengths, dtype=keyed_jagged_tensor.lengths().dtype + ).view(-1), + values=torch.tensor( + translated_indices, dtype=keyed_jagged_tensor.values().dtype + ), + weights=torch.tensor(translated_weights).float() if weights_list else None, ) - .view(-1) - .cuda(), - values=torch.tensor( - translated_indices, dtype=keyed_jagged_tensor.values().dtype - ).cuda(), - weights=torch.tensor(translated_weights).float().cuda() - if weights_list - else None, - ) class KJTBucketizeTest(unittest.TestCase): - @unittest.skipIf( - torch.cuda.device_count() <= 0, - "CUDA is not available", - ) # pyre-ignore[56] @given( index_type=st.sampled_from([torch.int, torch.long]), @@ -246,8 +273,12 @@ class KJTBucketizeTest(unittest.TestCase): world_size=st.integers(1, 129), num_features=st.integers(1, 15), batch_size=st.integers(1, 15), + variable_bucket_pos=st.booleans(), + device=st.sampled_from( + ["cpu"] + (["cuda"] if torch.cuda.device_count() > 0 else []) + ), ) - @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=50, deadline=None) def test_kjt_bucketize_before_all2all( self, index_type: torch.dtype, @@ -255,6 +286,8 @@ def test_kjt_bucketize_before_all2all( world_size: int, num_features: int, batch_size: int, + variable_bucket_pos: bool, + device: str, ) -> None: MAX_BATCH_SIZE = 15 MAX_LENGTH = 10 @@ -285,39 +318,56 @@ def test_kjt_bucketize_before_all2all( # for each feature, calculate the minimum block size needed to # distribute all rows to the available trainers block_sizes_list = [ - math.ceil((max(feature_indices_list) + 1) / world_size) - if feature_indices_list - else 1 + ( + math.ceil((max(feature_indices_list) + 1) / world_size) + if feature_indices_list + else 1 + ) for feature_indices_list in indices_lists ] + block_bucketize_row_pos = [] if variable_bucket_pos else None + if variable_bucket_pos: + for block_size in block_sizes_list: + # pyre-ignore + block_bucketize_row_pos.append( + torch.tensor( + [w * block_size for w in range(world_size + 1)], + dtype=index_type, + ) + ) kjt = KeyedJaggedTensor( keys=keys_list, - lengths=torch.tensor(lengths_list, dtype=offset_type) - .view(num_features * batch_size) - .cuda(), - values=torch.tensor(indices_list, dtype=index_type).cuda(), - weights=torch.tensor(weights_list, dtype=torch.float).cuda(), + lengths=torch.tensor(lengths_list, dtype=offset_type, device=device).view( + num_features * batch_size + ), + 
values=torch.tensor(indices_list, dtype=index_type, device=device), + weights=torch.tensor(weights_list, dtype=torch.float, device=device), ) """ each entry in block_sizes identifies how many hashes for each feature goes to every rank; we have three featues in `self.features` """ - block_sizes = torch.tensor(block_sizes_list, dtype=index_type).cuda() - + block_sizes = torch.tensor(block_sizes_list, dtype=index_type, device=device) block_bucketized_kjt, _ = bucketize_kjt_before_all2all( - kjt, world_size, block_sizes, False, False + kjt=kjt, + num_buckets=world_size, + block_sizes=block_sizes, + block_bucketize_row_pos=block_bucketize_row_pos, ) expected_block_bucketized_kjt = block_bucketize_ref( kjt, world_size, block_sizes, + device, ) self.assertTrue( keyed_jagged_tensor_equals( - block_bucketized_kjt, expected_block_bucketized_kjt + block_bucketized_kjt, + expected_block_bucketized_kjt, + is_pooled_features=True, ) ) @@ -340,3 +390,243 @@ def test_merge_fused_params_update(self) -> None: ) self.assertFalse(configured_fused_params is None) self.assertEqual(configured_fused_params, {"learning_rate": 0.0}) + + +class AddParamsFromParameterShardingTest(unittest.TestCase): + def setUp(self) -> None: + self.parameter_sharding = ParameterSharding( + sharding_type="data_parallel", + compute_kernel="dense", + ranks=[0, 1], + sharding_spec=None, + cache_params=CacheParams( + algorithm=CacheAlgorithm.LFU, + reserved_memory=1.0, + prefetch_pipeline=False, + multipass_prefetch_config=MultiPassPrefetchConfig(num_passes=2), + ), + enforce_hbm=False, + stochastic_rounding=True, + bounds_check_mode=BoundsCheckMode.WARNING, + ) + + def test_add_params_from_parameter_sharding(self) -> None: + fused_params = None + fused_params = add_params_from_parameter_sharding( + fused_params, self.parameter_sharding + ) + expected_fused_params = { + "cache_algorithm": CacheAlgorithm.LFU, + "cache_reserved_memory": 1.0, + "prefetch_pipeline": False, + "enforce_hbm": False, + "stochastic_rounding": True, + "bounds_check_mode": BoundsCheckMode.WARNING, + "multipass_prefetch_config": MultiPassPrefetchConfig(num_passes=2), + } + self.assertEqual(fused_params, expected_fused_params) + + def test_add_params_from_parameter_sharding_override(self) -> None: + fused_params = { + "learning_rate": 0.1, + "cache_algorithm": CacheAlgorithm.LRU, + "stochastic_rounding": False, + "prefetch_pipeline": True, + "multipass_prefetch_config": MultiPassPrefetchConfig(num_passes=5), + } + fused_params = add_params_from_parameter_sharding( + fused_params, self.parameter_sharding + ) + expected_fused_params = { + "learning_rate": 0.1, + "cache_algorithm": CacheAlgorithm.LFU, + "cache_reserved_memory": 1.0, + "prefetch_pipeline": False, + "enforce_hbm": False, + "stochastic_rounding": True, + "bounds_check_mode": BoundsCheckMode.WARNING, + "multipass_prefetch_config": MultiPassPrefetchConfig(num_passes=2), + } + self.assertEqual(fused_params, expected_fused_params) + + +class ConvertFusedParamsTest(unittest.TestCase): + def test_convert_to_fbgemm_types(self) -> None: + per_table_fused_params = { + "cache_precision": DataType.FP32, + "weights_precision": DataType.FP32, + "output_dtype": DataType.FP32, + } + self.assertTrue(isinstance(per_table_fused_params["cache_precision"], DataType)) + self.assertTrue( + isinstance(per_table_fused_params["weights_precision"], DataType) + ) + self.assertTrue(isinstance(per_table_fused_params["output_dtype"], DataType)) + + per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params) + 
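+        # After conversion, the precision/dtype entries should no longer be TorchRec DataType enums.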
self.assertFalse( + isinstance(per_table_fused_params["cache_precision"], DataType) + ) + self.assertFalse( + isinstance(per_table_fused_params["weights_precision"], DataType) + ) + self.assertFalse(isinstance(per_table_fused_params["output_dtype"], DataType)) + + +class TestBucketMetadata(unittest.TestCase): + def test_bucket_metadata(self) -> None: + # Given no shards + # When we get bucket metadata from get_bucket_metadata_from_shard_metadata + # Then an error should be raised + self.assertRaisesRegex( + AssertionError, + "Shards cannot be empty", + get_bucket_metadata_from_shard_metadata, + [], + num_buckets=4, + ) + + # Given 1 shard and 5 buckets + shards = [ + ShardMetadata(shard_offsets=[0], shard_sizes=[5], placement="rank:0/cuda:0") + ] + + # When we get bucket offsets from get_bucket_metadata_from_shard_metadata + bucket_metadata = get_bucket_metadata_from_shard_metadata(shards, num_buckets=5) + # Then we should get 1 offset with value 0 + expected_metadata = ShardingBucketMetadata( + num_buckets_per_shard=[5], bucket_offsets_per_shard=[0], bucket_size=1 + ) + self.assertEqual(bucket_metadata, expected_metadata) + + # Given 2 shards of size 5 and 4 buckets + shards = [ + ShardMetadata( + shard_offsets=[0], shard_sizes=[5], placement="rank:0/cuda:0" + ), + ShardMetadata( + shard_offsets=[5], shard_sizes=[5], placement="rank:0/cuda:0" + ), + ] + + # When we get bucket offsets from get_bucket_metadata_from_shard_metadata + # Then an error should be raised + self.assertRaisesRegex( + AssertionError, + "Table size '10' must be divisible by num_buckets '4'", + get_bucket_metadata_from_shard_metadata, + shards, + num_buckets=4, + ) + + # Given 2 shards of size 2 and 5 buckets + shards = [ + ShardMetadata( + shard_offsets=[0], shard_sizes=[2], placement="rank:0/cuda:0" + ), + ShardMetadata( + shard_offsets=[2], shard_sizes=[2], placement="rank:0/cuda:0" + ), + ] + + # When we get bucket offsets from get_bucket_metadata_from_shard_metadata + # Then an error should be raised + self.assertRaisesRegex( + AssertionError, + "Table size '4' must be divisible by num_buckets '5'", + get_bucket_metadata_from_shard_metadata, + shards, + num_buckets=5, + ) + + # Given 2 shards sharded by column + shards = [ + ShardMetadata( + shard_offsets=[0, 0], shard_sizes=[20, 5], placement="rank:0/cuda:0" + ), + ShardMetadata( + shard_offsets=[0, 5], shard_sizes=[20, 5], placement="rank:0/cuda:0" + ), + ] + + # When we get bucket offsets from get_bucket_metadata_from_shard_metadata + # Then an error should be raised + self.assertRaisesRegex( + AssertionError, + r"Shard shard_offsets\[1\] '5' is not 0. 
Table should be only row-wise sharded for bucketization", + get_bucket_metadata_from_shard_metadata, + shards, + num_buckets=2, + ) + + # Given 2 shards of size 10 and 5 buckets + shards = [ + ShardMetadata( + shard_offsets=[0], shard_sizes=[10], placement="rank:0/cuda:0" + ), + ShardMetadata( + shard_offsets=[10], shard_sizes=[10], placement="rank:0/cuda:0" + ), + ] + + # When we get bucket offsets from get_bucket_metadata_from_shard_metadata + # Then an error should be raised + self.assertRaisesRegex( + AssertionError, + r"Shard size\[0\] '10' is not divisible by bucket size '4'", + get_bucket_metadata_from_shard_metadata, + shards, + num_buckets=5, + ) + + # Given 2 shards of size 20 and 10 buckets + shards = [ + ShardMetadata( + shard_offsets=[0], shard_sizes=[20], placement="rank:0/cuda:0" + ), + ShardMetadata( + shard_offsets=[20], shard_sizes=[20], placement="rank:0/cuda:0" + ), + ] + # When we get bucket offsets from get_bucket_metadata_from_shard_metadata + bucket_metadata = get_bucket_metadata_from_shard_metadata( + shards, + num_buckets=10, + ) + # Then num_buckets_per_shard should be set to [5, 5] + self.assertEqual( + bucket_metadata, + ShardingBucketMetadata( + num_buckets_per_shard=[5, 5], + bucket_offsets_per_shard=[0, 5], + bucket_size=4, + ), + ) + + # Given 3 uneven shards of sizes 12, 16 and 20 and 12 buckets + shards = [ + ShardMetadata( + shard_offsets=[0, 0], shard_sizes=[12, 0], placement="rank:0/cuda:0" + ), + ShardMetadata( + shard_offsets=[12, 0], shard_sizes=[16, 0], placement="rank:0/cuda:0" + ), + ShardMetadata( + shard_offsets=[28, 0], shard_sizes=[20, 0], placement="rank:0/cuda:0" + ), + ] + + # When we get bucket offsets from get_bucket_metadata_from_shard_metadata + bucket_metadata = get_bucket_metadata_from_shard_metadata( + shards, + num_buckets=12, + ) + # Then num_buckets_per_shard should be set to [3, 4, 5] + self.assertEqual( + bucket_metadata, + ShardingBucketMetadata( + num_buckets_per_shard=[3, 4, 5], + bucket_offsets_per_shard=[0, 3, 7], + bucket_size=4, + ), + ) diff --git a/torchrec/distributed/train_pipeline.py b/torchrec/distributed/train_pipeline.py deleted file mode 100644 index c30c3f8b9..000000000 --- a/torchrec/distributed/train_pipeline.py +++ /dev/null @@ -1,573 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import abc -import logging -from dataclasses import dataclass, field -from typing import ( - Any, - cast, - Dict, - Generic, - Iterator, - List, - Optional, - Set, - Tuple, - TypeVar, -) - -import torch -from torch.autograd.profiler import record_function -from torch.fx.node import Node -from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule -from torchrec.distributed.types import Awaitable -from torchrec.modules.feature_processor import BaseGroupedFeatureProcessor -from torchrec.streamable import Multistreamable, Pipelineable - -logger: logging.Logger = logging.getLogger(__name__) - - -In = TypeVar("In", bound=Pipelineable) -Out = TypeVar("Out") - - -class TrainPipeline(abc.ABC, Generic[In, Out]): - @abc.abstractmethod - def progress(self, dataloader_iter: Iterator[In]) -> Out: - pass - - -def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In: - assert isinstance( - batch, (torch.Tensor, Pipelineable) - ), f"{type(batch)} must implement Pipelineable interface" - return cast(In, batch.to(device=device, non_blocking=non_blocking)) - - -def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None: - if stream is None: - return - torch.cuda.current_stream().wait_stream(stream) - # As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html, - # PyTorch uses the "caching allocator" for memroy allocation for tensors. When a tensor is - # freed, its memory is likely to be reused by newly constructed tenosrs. By default, - # this allocator traces whether a tensor is still in use by only the CUDA stream where it - # was created. When a tensor is used by additional CUDA streams, we need to call record_stream - # to tell the allocator about all these streams. Otherwise, the allocator might free the - # underlying memory of the tensor once it is no longer used by the creator stream. This is - # a notable programming trick when we write programs using multi CUDA streams. - cur_stream = torch.cuda.current_stream() - assert isinstance( - batch, (torch.Tensor, Multistreamable) - ), f"{type(batch)} must implement Multistreamable interface" - batch.record_stream(cur_stream) - - -class TrainPipelineBase(TrainPipeline[In, Out]): - """ - This class runs training iterations using a pipeline of two stages, each as a CUDA - stream, namely, the current (default) stream and `self._memcpy_stream`. For each - iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU - memory, and the default stream runs forward, backward, and optimization. 
- """ - - def __init__( - self, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - device: torch.device, - ) -> None: - self._model = model - self._optimizer = optimizer - self._device = device - self._memcpy_stream: Optional[torch.cuda.streams.Stream] = ( - torch.cuda.Stream() if device.type == "cuda" else None - ) - self._cur_batch: Optional[In] = None - self._connected = False - - def _connect(self, dataloader_iter: Iterator[In]) -> None: - cur_batch = next(dataloader_iter) - self._cur_batch = cur_batch - with torch.cuda.stream(self._memcpy_stream): - self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True) - self._connected = True - - def progress(self, dataloader_iter: Iterator[In]) -> Out: - if not self._connected: - self._connect(dataloader_iter) - - # Fetch next batch - with record_function("## next_batch ##"): - next_batch = next(dataloader_iter) - cur_batch = self._cur_batch - assert cur_batch is not None - - if self._model.training: - with record_function("## zero_grad ##"): - self._optimizer.zero_grad() - - with record_function("## wait_for_batch ##"): - _wait_for_batch(cur_batch, self._memcpy_stream) - - with record_function("## forward ##"): - (losses, output) = self._model(cur_batch) - - if self._model.training: - with record_function("## backward ##"): - torch.sum(losses, dim=0).backward() - - # Copy the next batch to GPU - self._cur_batch = cur_batch = next_batch - with record_function("## copy_batch_to_gpu ##"): - with torch.cuda.stream(self._memcpy_stream): - self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True) - - # Update - if self._model.training: - with record_function("## optimizer ##"): - self._optimizer.step() - - return output - - -class Tracer(torch.fx.Tracer): - # Disable proxying buffers during tracing. Ideally, proxying buffers would - # be disabled, but some models are currently mutating buffer values, which - # causes errors during tracing. If those models can be rewritten to not do - # that, we can likely remove this line - proxy_buffer_attributes = False - - def __init__(self, leaf_modules: Optional[List[str]] = None) -> None: - super().__init__() - self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else [] - - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: - if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules: - return True - return super().is_leaf_module(m, module_qualified_name) - - -@dataclass -class TrainPipelineContext: - # pyre-ignore [4] - input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict) - module_contexts: Dict[str, Multistreamable] = field(default_factory=dict) - # pyre-ignore [4] - feature_processor_forwards: List[Any] = field(default_factory=list) - - -@dataclass -class ArgInfo: - # attributes of input batch, e.g. 
batch.attr1.attr2 call - # will produce ["attr1", "attr2"] - input_attrs: List[str] - # batch[attr1].attr2 will produce [True, False] - is_getitems: List[bool] - # name for kwarg of pipelined forward() call or None - # for a positional arg - name: Optional[str] - - -class PipelinedForward: - def __init__( - self, - name: str, - args: List[ArgInfo], - module: ShardedModule, - context: TrainPipelineContext, - dist_stream: Optional[torch.cuda.streams.Stream], - ) -> None: - self._name = name - self._args = args - self._module = module - self._context = context - self._dist_stream = dist_stream - - # pyre-ignore [2, 24] - def __call__(self, *input, **kwargs) -> Awaitable: - assert self._name in self._context.input_dist_requests - request = self._context.input_dist_requests[self._name] - assert isinstance(request, Awaitable) - with record_function("## wait_sparse_data_dist ##"): - # Finish waiting on the dist_stream, - # in case some delayed stream scheduling happens during the wait() call. - with torch.cuda.stream(self._dist_stream): - data = request.wait() - - # Make sure that both result of input_dist and context - # are properly transferred to the current stream. - if self._dist_stream is not None: - torch.cuda.current_stream().wait_stream(self._dist_stream) - cur_stream = torch.cuda.current_stream() - - assert isinstance( - data, (torch.Tensor, Multistreamable) - ), f"{type(data)} must implement Multistreamable interface" - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. - data.record_stream(cur_stream) - - ctx = self._context.module_contexts[self._name] - ctx.record_stream(cur_stream) - - if len(self._context.feature_processor_forwards) > 0: - with record_function("## feature_processor ##"): - for sparse_feature in data: - if sparse_feature.id_score_list_features is not None: - for fp_forward in self._context.feature_processor_forwards: - sparse_feature.id_score_list_features = fp_forward( - sparse_feature.id_score_list_features - ) - - return self._module.compute_and_output_dist( - self._context.module_contexts[self._name], data - ) - - @property - def name(self) -> str: - return self._name - - @property - def args(self) -> List[ArgInfo]: - return self._args - - -def _start_data_dist( - pipelined_modules: List[ShardedModule], - batch: In, - context: TrainPipelineContext, -) -> None: - context.input_dist_requests.clear() - context.module_contexts.clear() - for module in pipelined_modules: - forward = module.forward - assert isinstance(forward, PipelinedForward) - - # Retrieve argument for the input_dist of EBC - # is_getitem True means this argument could be retrieved by a list - # False means this argument is getting while getattr - # and this info was done in the _rewrite_model by tracing the - # entire model to get the arg_info_list - args = [] - kwargs = {} - for arg_info in forward.args: - if arg_info.input_attrs: - arg = batch - for attr, is_getitem in zip(arg_info.input_attrs, arg_info.is_getitems): - if is_getitem: - arg = arg[attr] - else: - arg = getattr(arg, attr) - if arg_info.name: - kwargs[arg_info.name] = arg - else: - args.append(arg) - else: - args.append(None) - # Start input distribution. 
- module_ctx = module.create_context() - context.module_contexts[forward.name] = module_ctx - context.input_dist_requests[forward.name] = module.input_dist( - module_ctx, *args, **kwargs - ) - - -def _get_node_args_helper( - # pyre-ignore - arguments, - num_found: int, - feature_processor_arguments: Optional[List[Node]] = None, -) -> Tuple[List[ArgInfo], int]: - """ - Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. - It also counts the number of (args + kwargs) found. - """ - - arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))] - for arg, arg_info in zip(arguments, arg_info_list): - if arg is None: - num_found += 1 - continue - while True: - if not isinstance(arg, torch.fx.Node): - break - child_node = arg - - if child_node.op == "placeholder": - num_found += 1 - break - # skip this fp node - elif ( - feature_processor_arguments is not None - and child_node in feature_processor_arguments - ): - arg = child_node.args[0] - elif ( - child_node.op == "call_function" - and child_node.target.__module__ == "builtins" - # pyre-ignore[16] - and child_node.target.__name__ == "getattr" - ): - arg_info.input_attrs.insert(0, child_node.args[1]) - arg_info.is_getitems.insert(0, False) - arg = child_node.args[0] - elif ( - child_node.op == "call_function" - and child_node.target.__module__ == "_operator" - # pyre-ignore[16] - and child_node.target.__name__ == "getitem" - ): - arg_info.input_attrs.insert(0, child_node.args[1]) - arg_info.is_getitems.insert(0, True) - arg = child_node.args[0] - else: - break - return arg_info_list, num_found - - -def _get_node_args( - node: Node, feature_processor_nodes: Optional[List[Node]] = None -) -> Tuple[List[ArgInfo], int]: - num_found = 0 - pos_arg_info_list, num_found = _get_node_args_helper( - node.args, num_found, feature_processor_nodes - ) - kwargs_arg_info_list, num_found = _get_node_args_helper( - node.kwargs.values(), num_found - ) - - # Replace with proper names for kwargs - for name, arg_info_list in zip(node.kwargs, kwargs_arg_info_list): - arg_info_list.name = name - - arg_info_list = pos_arg_info_list + kwargs_arg_info_list - return arg_info_list, num_found - - -def _get_unsharded_module_names_helper( - model: torch.nn.Module, - path: str, - unsharded_module_names: Set[str], -) -> bool: - sharded_children = set() - for name, child in model.named_children(): - curr_path = path + name - if isinstance(child, ShardedModule): - sharded_children.add(name) - else: - child_sharded = _get_unsharded_module_names_helper( - child, - curr_path + ".", - unsharded_module_names, - ) - if child_sharded: - sharded_children.add(name) - - if len(sharded_children) > 0: - for name, _ in model.named_children(): - if name not in sharded_children: - unsharded_module_names.add(path + name) - - return len(sharded_children) > 0 - - -def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]: - """ - Returns a list of top level modules do not contain any sharded sub modules. - """ - - unsharded_module_names: Set[str] = set() - _get_unsharded_module_names_helper( - model, - "", - unsharded_module_names, - ) - return list(unsharded_module_names) - - -def _rewrite_model( # noqa C901 - model: torch.nn.Module, - context: TrainPipelineContext, - dist_stream: Optional[torch.cuda.streams.Stream], -) -> List[ShardedModule]: - - # Get underlying nn.Module - if isinstance(model, DistributedModelParallel): - model = model.module - - # Collect a list of sharded modules. 
- sharded_modules = {} - fp_modules = {} - for name, m in model.named_modules(): - if isinstance(m, ShardedModule): - sharded_modules[name] = m - if isinstance(m, BaseGroupedFeatureProcessor): - fp_modules[name] = m - - # Trace a model. - tracer = Tracer(leaf_modules=_get_unsharded_module_names(model)) - graph = tracer.trace(model) - - feature_processor_nodes = [] - # find the fp node - for node in graph.nodes: - if node.op == "call_module" and node.target in fp_modules: - feature_processor_nodes.append(node) - # Select sharded modules, which are top-level in the forward call graph, - # i.e. which don't have input transformations, i.e. - # rely only on 'builtins.getattr'. - ret = [] - for node in graph.nodes: - if node.op == "call_module" and node.target in sharded_modules: - total_num_args = len(node.args) + len(node.kwargs) - if total_num_args == 0: - continue - arg_info_list, num_found = _get_node_args(node, feature_processor_nodes) - if num_found == total_num_args: - logger.info(f"Module '{node.target}'' will be pipelined") - child = sharded_modules[node.target] - child.forward = PipelinedForward( - node.target, - arg_info_list, - child, - context, - dist_stream, - ) - ret.append(child) - return ret - - -class TrainPipelineSparseDist(TrainPipeline[In, Out]): - """ - This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with - forward and backward. This helps hide the all2all latency while preserving the - training forward / backward ordering. - - stage 3: forward, backward - uses default CUDA stream - stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream - stage 1: device transfer - uses memcpy CUDA stream - - `ShardedModule.input_dist()` is only done for top-level modules in the call graph. - To be considered a top-level module, a module can only depend on 'getattr' calls on - input. - - Input model must be symbolically traceable with the exception of `ShardedModule` and - `DistributedDataParallel` modules. - """ - - def __init__( - self, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - device: torch.device, - ) -> None: - self._model = model - self._optimizer = optimizer - self._device = device - # use two data streams to support two concurrent batches - if device.type == "cuda": - self._memcpy_stream: Optional[ - torch.cuda.streams.Stream - ] = torch.cuda.Stream() - self._data_dist_stream: Optional[ - torch.cuda.streams.Stream - ] = torch.cuda.Stream() - else: - self._memcpy_stream: Optional[torch.cuda.streams.Stream] = None - self._data_dist_stream: Optional[torch.cuda.streams.Stream] = None - self._batch_i: Optional[In] = None - self._batch_ip1: Optional[In] = None - self._batch_ip2: Optional[In] = None - self._connected = False - self._context = TrainPipelineContext() - self._pipelined_modules: List[ShardedModule] = [] - - def _replace_fp_forward(self, model: torch.nn.Module) -> None: - for _, m in model.named_modules(): - if isinstance(m, BaseGroupedFeatureProcessor): - self._context.feature_processor_forwards.append(m.forward) - # pyre-ignore[8]: Incompatible attribute type - m.forward = lambda x: x - - def _connect(self, dataloader_iter: Iterator[In]) -> None: - self._replace_fp_forward(cast(torch.nn.Module, self._model.module)) - # batch 1 - with torch.cuda.stream(self._memcpy_stream): - batch_i = next(dataloader_iter) - self._batch_i = batch_i = _to_device( - batch_i, self._device, non_blocking=True - ) - # Try to pipeline input data dist. 
- self._pipelined_modules = _rewrite_model( - self._model, self._context, self._data_dist_stream - ) - - with torch.cuda.stream(self._data_dist_stream): - _wait_for_batch(batch_i, self._memcpy_stream) - _start_data_dist(self._pipelined_modules, batch_i, self._context) - - # batch 2 - with torch.cuda.stream(self._memcpy_stream): - batch_ip1 = next(dataloader_iter) - self._batch_ip1 = batch_ip1 = _to_device( - batch_ip1, self._device, non_blocking=True - ) - self._connected = True - - def progress(self, dataloader_iter: Iterator[In]) -> Out: - if not self._connected: - self._connect(dataloader_iter) - - if self._model.training: - with record_function("## zero_grad ##"): - self._optimizer.zero_grad() - - with record_function("## copy_batch_to_gpu ##"): - with torch.cuda.stream(self._memcpy_stream): - batch_ip2 = next(dataloader_iter) - self._batch_ip2 = batch_ip2 = _to_device( - batch_ip2, self._device, non_blocking=True - ) - batch_i = cast(In, self._batch_i) - batch_ip1 = cast(In, self._batch_ip1) - - with record_function("## wait_for_batch ##"): - _wait_for_batch(batch_i, self._data_dist_stream) - - # Forward - with record_function("## forward ##"): - # if using multiple streams (ie. CUDA), create an event in default stream - # before starting forward pass - if self._data_dist_stream: - event = torch.cuda.current_stream().record_event() - (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i)) - - # Data Distribution - with record_function("## sparse_data_dist ##"): - with torch.cuda.stream(self._data_dist_stream): - _wait_for_batch(batch_ip1, self._memcpy_stream) - # Ensure event in default stream has been called before - # starting data dist - if self._data_dist_stream: - # pyre-ignore [61]: Local variable `event` is undefined, or not always defined - self._data_dist_stream.wait_event(event) - _start_data_dist(self._pipelined_modules, batch_ip1, self._context) - - if self._model.training: - # Backward - with record_function("## backward ##"): - torch.sum(losses, dim=0).backward() - - # Update - with record_function("## optimizer ##"): - self._optimizer.step() - - self._batch_i = batch_ip1 - self._batch_ip1 = batch_ip2 - - return output diff --git a/torchrec/distributed/train_pipeline/__init__.py b/torchrec/distributed/train_pipeline/__init__.py new file mode 100644 index 000000000..d7b38d2b0 --- /dev/null +++ b/torchrec/distributed/train_pipeline/__init__.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
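+#
+# Re-exports the pipeline classes and helpers from the train_pipelines and
+# utils submodules, so `torchrec.distributed.train_pipeline` keeps a single
+# import surface now that the monolithic train_pipeline.py module is removed.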
+ +# pyre-strict + + +from torchrec.distributed.train_pipeline.train_pipelines import ( # noqa + EvalPipelineSparseDist, # noqa + PrefetchTrainPipelineSparseDist, # noqa + StagedTrainPipeline, # noqa + TorchCompileConfig, # noqa + TrainPipeline, # noqa + TrainPipelineBase, # noqa + TrainPipelinePT2, # noqa + TrainPipelineSparseDist, # noqa + TrainPipelineSparseDistCompAutograd, # noqa +) +from torchrec.distributed.train_pipeline.utils import ( # noqa + _override_input_dist_forwards, # noqa + _rewrite_model, # noqa + _start_data_dist, # noqa + _to_device, # noqa + _wait_for_batch, # noqa + ArgInfo, # noqa + DataLoadingThread, # noqa + In, # noqa + Out, # noqa + SparseDataDistUtil, # noqa + StageOut, # noqa + Tracer, # noqa + TrainPipelineContext, # noqa +) diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py new file mode 100644 index 000000000..7e5f45532 --- /dev/null +++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +import copy +import os +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union + +import click + +import torch +import torch.distributed as dist +from fbgemm_gpu.split_embedding_configs import EmbOptimType +from torch import nn, optim +from torch.optim import Optimizer +from torchrec.distributed import DistributedModelParallel +from torchrec.distributed.benchmark.benchmark_utils import benchmark_func +from torchrec.distributed.embedding_types import EmbeddingComputeKernel + +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + run_multi_process_func, +) +from torchrec.distributed.test_utils.test_input import ModelInput, TdModelInput +from torchrec.distributed.test_utils.test_model import ( + TestEBCSharder, + TestOverArchLarge, + TestSparseNN, +) +from torchrec.distributed.train_pipeline import ( + TrainPipeline, + TrainPipelineBase, + TrainPipelineSparseDist, +) +from torchrec.distributed.train_pipeline.train_pipelines import ( + PrefetchTrainPipelineSparseDist, + TrainPipelineSemiSync, +) +from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig + + +_pipeline_cls: Dict[str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]]] = { + "base": TrainPipelineBase, + "sparse": TrainPipelineSparseDist, + "semi": TrainPipelineSemiSync, + "prefetch": PrefetchTrainPipelineSparseDist, +} + + +def _gen_pipelines( + pipelines: str, +) -> List[Type[Union[TrainPipelineBase, TrainPipelineSparseDist]]]: + if pipelines == "all": + return list(_pipeline_cls.values()) + else: + return [_pipeline_cls[pipelines]] + + +@click.command() +@click.option( + "--world_size", + type=int, + default=4, + help="Num of GPUs to run", +) +@click.option( + "--n_features", + default=100, + help="Total number of sparse embeddings to be used.", +) +@click.option( + "--ratio_features_weighted", + default=0.4, + help="percentage of features weighted vs unweighted", +) +@click.option( + "--dim_emb", + type=int, + default=512, + help="Dim embeddings embedding.", +) +@click.option( + "--num_batches", + type=int, + default=20, + help="Num of batchs to run.", +) 
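# --- Editorial note, not part of this diff -----------------------------------
# Example invocation of this benchmark script (flag values are arbitrary
# samples; options not listed keep the defaults declared in this decorator
# stack):
#
#   python pipeline_benchmarks.py \
#       --world_size 2 --pipeline sparse --num_batches 10 --input_type kjt
#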
+@click.option( + "--batch_size", + type=int, + default=8192, + help="Batch size.", +) +@click.option( + "--sharding_type", + type=ShardingType, + default=ShardingType.TABLE_WISE, + help="ShardingType.", +) +@click.option( + "--pooling_factor", + type=int, + default=100, + help="Pooling Factor.", +) +@click.option( + "--input_type", + type=str, + default="kjt", + help="Input type: kjt, td", +) +@click.option( + "--pipeline", + type=str, + default="all", + help="Pipeline to run: all, base, sparse, semi, prefetch", +) +@click.option( + "--profile", + type=str, + default="", + help="profile output directory", +) +def main( + world_size: int, + n_features: int, + ratio_features_weighted: float, + dim_emb: int, + num_batches: int, + batch_size: int, + sharding_type: ShardingType, + pooling_factor: int, + input_type: str, + pipeline: str, + profile: str, +) -> None: + """ + Checks that pipelined training is equivalent to non-pipelined training. + """ + + num_weighted_features = int(n_features * ratio_features_weighted) + num_features = n_features - num_weighted_features + + tables = [ + EmbeddingBagConfig( + num_embeddings=max(i + 1, 100) * 1000, + embedding_dim=dim_emb, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=max(i + 1, 100) * 1000, + embedding_dim=dim_emb, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(num_weighted_features) + ] + + run_multi_process_func( + func=runner, + tables=tables, + weighted_tables=weighted_tables, + sharding_type=sharding_type.value, + kernel_type=EmbeddingComputeKernel.FUSED.value, + fused_params={}, + batch_size=batch_size, + world_size=world_size, + num_batches=num_batches, + pooling_factor=pooling_factor, + input_type=input_type, + pipelines=pipeline, + profile=profile, + ) + + +def _generate_data( + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + num_float_features: int = 10, + num_batches: int = 100, + batch_size: int = 4096, + pooling_factor: int = 10, + input_type: str = "kjt", +) -> List[ModelInput]: + if input_type == "kjt": + return [ + ModelInput.generate( + tables=tables, + weighted_tables=weighted_tables, + batch_size=batch_size, + num_float_features=num_float_features, + pooling_avg=pooling_factor, + ) + for _ in range(num_batches) + ] + else: + return [ + TdModelInput.generate( + tables=tables, + weighted_tables=weighted_tables, + batch_size=batch_size, + num_float_features=num_float_features, + pooling_avg=pooling_factor, + ) + for _ in range(num_batches) + ] + + +def _generate_sharded_model_and_optimizer( + model: nn.Module, + sharding_type: str, + kernel_type: str, + pg: dist.ProcessGroup, + device: torch.device, + fused_params: Optional[Dict[str, Any]] = None, +) -> Tuple[nn.Module, Optimizer]: + sharder = TestEBCSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + fused_params=fused_params, + ) + sharded_model = DistributedModelParallel( + module=copy.deepcopy(model), + env=ShardingEnv.from_process_group(pg), + init_data_parallel=True, + device=device, + sharders=[ + cast( + ModuleSharder[nn.Module], + sharder, + ) + ], + ).to(device) + optimizer = optim.SGD( + [ + param + for name, param in sharded_model.named_parameters() + if "sparse" not in name + ], + lr=0.1, + ) + return sharded_model, optimizer + + +def runner( + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + rank: int, + 
sharding_type: str, + kernel_type: str, + fused_params: Dict[str, Any], + batch_size: int, + world_size: int, + num_batches: int, + pooling_factor: int, + input_type: str, + pipelines: str, + profile: str, +) -> None: + + torch.autograd.set_detect_anomaly(True) + with MultiProcessContext( + rank=rank, + world_size=world_size, + backend="nccl", + use_deterministic_algorithms=False, + ) as ctx: + + unsharded_model = TestSparseNN( + tables=tables, + weighted_tables=weighted_tables, + dense_device=ctx.device, + sparse_device=torch.device("meta"), + over_arch_clazz=TestOverArchLarge, + ) + + sharded_model, optimizer = _generate_sharded_model_and_optimizer( + model=unsharded_model, + sharding_type=sharding_type, + kernel_type=kernel_type, + # pyre-ignore + pg=ctx.pg, + device=ctx.device, + fused_params={ + "optimizer": EmbOptimType.EXACT_ADAGRAD, + "learning_rate": 0.1, + }, + ) + bench_inputs = _generate_data( + tables=tables, + weighted_tables=weighted_tables, + num_float_features=10, + num_batches=num_batches, + batch_size=batch_size, + pooling_factor=pooling_factor, + input_type=input_type, + ) + for pipeline_clazz in _gen_pipelines(pipelines=pipelines): + if pipeline_clazz == TrainPipelineSemiSync: + # pyre-ignore [28] + pipeline = pipeline_clazz( + model=sharded_model, + optimizer=optimizer, + device=ctx.device, + start_batch=0, + ) + else: + pipeline = pipeline_clazz( + model=sharded_model, + optimizer=optimizer, + device=ctx.device, + ) + pipeline.progress(iter(bench_inputs)) + + def _func_to_benchmark( + bench_inputs: List[ModelInput], + model: nn.Module, + pipeline: TrainPipeline, + ) -> None: + dataloader = iter(bench_inputs) + while True: + try: + pipeline.progress(dataloader) + except StopIteration: + break + + result = benchmark_func( + name=pipeline_clazz.__name__, + bench_inputs=bench_inputs, # pyre-ignore + prof_inputs=bench_inputs, # pyre-ignore + num_benchmarks=5, + num_profiles=2, + profile_dir=profile, + world_size=world_size, + func_to_benchmark=_func_to_benchmark, + benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline}, + rank=rank, + ) + if rank == 0: + print(result) + + +if __name__ == "__main__": + main() diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py new file mode 100644 index 000000000..a563281ca --- /dev/null +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -0,0 +1,2497 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
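# --- Editorial sketch, not part of this diff ---------------------------------
# The equivalence tests in this file repeatedly use the same lock-step
# pattern: run one eager optimizer step and one pipelined step on
# weight-synchronized copies of the model, then compare predictions. A
# distilled version follows; every argument is a placeholder for an object the
# test base class constructs, `batches` is a list of model inputs, and the
# pipeline is assumed to execute all batches (as the tests below configure via
# execute_all_batches=True):
import torch

def _assert_pipeline_matches_eager(
    eager_model: torch.nn.Module,
    eager_optimizer: torch.optim.Optimizer,
    pipeline,
    batches,
    device: torch.device,
) -> None:
    pipelined_iter = iter(batches)
    for batch in batches:
        batch = batch.to(device)
        eager_optimizer.zero_grad()
        loss, pred = eager_model(batch)
        loss.backward()
        eager_optimizer.step()
        # progress() consumes from its own iterator (prefetching ahead) but
        # returns the prediction for the batch whose forward ran in this call.
        pred_pipelined = pipeline.progress(pipelined_iter)
        torch.testing.assert_close(pred, pred_pipelined)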
+ +# pyre-strict + +import copy + +import unittest +from contextlib import ExitStack +from dataclasses import dataclass +from functools import partial +from typing import cast, Dict, List, Optional, Tuple, Type, Union +from unittest.mock import MagicMock + +import torch +from hypothesis import assume, given, settings, strategies as st, Verbosity +from torch import nn, optim +from torch._dynamo.testing import reduce_to_scalar_loss +from torch._dynamo.utils import counters +from torch.fx._symbolic_trace import is_fx_tracing +from torchrec.distributed import DistributedModelParallel +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.fp_embeddingbag import ( + FeatureProcessedEmbeddingBagCollectionSharder, + ShardedFeatureProcessedEmbeddingBagCollection, +) +from torchrec.distributed.sharding_plan import ( + construct_module_sharding_plan, + table_wise, +) +from torchrec.distributed.test_utils.test_model import ( + ModelInput, + TestEBCSharder, + TestModelWithPreproc, + TestModelWithPreprocCollectionArgs, + TestNegSamplingModule, + TestPositionWeightedPreprocModule, + TestSparseNN, +) +from torchrec.distributed.test_utils.test_sharding import copy_state_dict +from torchrec.distributed.tests.test_fp_embeddingbag_utils import ( + create_module_and_freeze, +) +from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import ( + TrainPipelineSparseDistTestBase, +) +from torchrec.distributed.train_pipeline.train_pipelines import ( + EvalPipelineSparseDist, + PrefetchTrainPipelineSparseDist, + StagedTrainPipeline, + TrainPipelineBase, + TrainPipelinePT2, + TrainPipelineSemiSync, + TrainPipelineSparseDist, + TrainPipelineSparseDistCompAutograd, +) +from torchrec.distributed.train_pipeline.utils import ( + DataLoadingThread, + EmbeddingPipelinedForward, + get_h2d_func, + PipelinedForward, + PipelinedPostproc, + PipelineStage, + SparseDataDistUtil, + StageOut, + TrainPipelineContext, +) +from torchrec.distributed.types import ( + ModuleSharder, + ShardingEnv, + ShardingPlan, + ShardingType, +) +from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection + +from torchrec.optim.keyed import KeyedOptimizerWrapper +from torchrec.optim.optimizers import in_backward_optimizer_filter +from torchrec.pt2.utils import kjt_for_pt2_tracing +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor +from torchrec.streamable import Pipelineable + + +@dataclass +class ModelInputSimple(Pipelineable): + float_features: torch.Tensor + label: torch.Tensor + + def to(self, device: torch.device, non_blocking: bool) -> "ModelInputSimple": + return ModelInputSimple( + float_features=self.float_features.to( + device=device, non_blocking=non_blocking + ), + label=self.label.to(device=device, non_blocking=non_blocking), + ) + + def record_stream(self, stream: torch.Stream) -> None: + self.float_features.record_stream(stream) + self.label.record_stream(stream) + + +class TestModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.model = nn.Linear(10, 1) + self.loss_fn = nn.BCEWithLogitsLoss() + self._dummy_setting: str = "dummy" + + def forward( + self, model_input: ModelInputSimple + ) -> Tuple[torch.Tensor, torch.Tensor]: + pred = self.model(model_input.float_features) + loss = self.loss_fn(pred, model_input.label) + return (loss, pred) + + +class 
Tracer(torch.fx.Tracer): + _leaf_module_names: List[str] + + def __init__(self, leaf_module_names: Optional[List[str]] = None) -> None: + super().__init__() + self._leaf_module_names = leaf_module_names or [] + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + if module_qualified_name in self._leaf_module_names: + return True + return super().is_leaf_module(m, module_qualified_name) + + +class TrainPipelineBaseTest(unittest.TestCase): + def setUp(self) -> None: + self.device = torch.device("cuda:0") + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_equal_to_non_pipelined(self) -> None: + model_cpu = TestModule() + model_gpu = TestModule().to(self.device) + model_gpu.load_state_dict(model_cpu.state_dict()) + optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01) + optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01) + data = [ + ModelInputSimple( + float_features=torch.rand((10,)), + label=torch.randint(2, (1,), dtype=torch.float32), + ) + for b in range(5) + ] + dataloader = iter(data) + pipeline = TrainPipelineBase(model_gpu, optimizer_gpu, self.device) + + for batch in data[:-1]: + optimizer_cpu.zero_grad() + loss, pred = model_cpu(batch) + loss.backward() + optimizer_cpu.step() + + pred_gpu = pipeline.progress(dataloader) + + self.assertEqual(pred_gpu.device, self.device) + # Results will be close but not exactly equal as one model is on CPU and other on GPU + # If both were on GPU, the results will be exactly the same + self.assertTrue(torch.isclose(pred_gpu.cpu(), pred)) + + +class TrainPipelinePT2Test(unittest.TestCase): + def setUp(self) -> None: + self.device = torch.device("cuda:0") + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + + def gen_eb_conf_list(self, is_weighted: bool = False) -> List[EmbeddingBagConfig]: + weighted_prefix = "weighted_" if is_weighted else "" + + return [ + EmbeddingBagConfig( + num_embeddings=256, + embedding_dim=12, + name=weighted_prefix + "table_0", + feature_names=[weighted_prefix + "f0"], + ), + EmbeddingBagConfig( + num_embeddings=256, + embedding_dim=12, + name=weighted_prefix + "table_1", + feature_names=[weighted_prefix + "f1"], + ), + ] + + def gen_model( + self, device: torch.device, ebc_list: List[EmbeddingBagConfig] + ) -> nn.Module: + class M_ebc(torch.nn.Module): + def __init__(self, vle: EmbeddingBagCollection) -> None: + super().__init__() + self.model = vle + + def forward(self, x: KeyedJaggedTensor) -> List[JaggedTensor]: + kt: KeyedTensor = self.model(x) + return list(kt.to_dict().values()) + + return M_ebc( + EmbeddingBagCollection( + device=device, + tables=ebc_list, + ) + ) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_equal_to_non_pipelined(self) -> None: + model_cpu = TestModule() + model_gpu = TestModule().to(self.device) + model_gpu.load_state_dict(model_cpu.state_dict()) + optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01) + optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01) + data = [ + ModelInputSimple( + float_features=torch.rand((10,)), + label=torch.randint(2, (1,), dtype=torch.float32), + ) + for b in range(5) + ] + 
dataloader = iter(data) + pipeline = TrainPipelinePT2(model_gpu, optimizer_gpu, self.device) + + for batch in data[:-1]: + optimizer_cpu.zero_grad() + loss, pred = model_cpu(batch) + loss.backward() + optimizer_cpu.step() + + pred_gpu = pipeline.progress(dataloader) + + self.assertEqual(pred_gpu.device, self.device) + self.assertTrue(torch.isclose(pred_gpu.cpu(), pred)) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pre_compile_fn(self) -> None: + model_cpu = TestModule() + model_gpu = TestModule().to(self.device) + model_gpu.load_state_dict(model_cpu.state_dict()) + optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01) + data = [ + ModelInputSimple( + float_features=torch.rand((10,)), + label=torch.randint(2, (1,), dtype=torch.float32), + ) + for b in range(5) + ] + + def pre_compile_fn(model: nn.Module) -> None: + # pyre-fixme[16]: `Module` has no attribute `_dummy_setting`. + model._dummy_setting = "dummy modified" + + dataloader = iter(data) + pipeline = TrainPipelinePT2( + model_gpu, optimizer_gpu, self.device, pre_compile_fn=pre_compile_fn + ) + self.assertEqual(model_gpu._dummy_setting, "dummy") + for _ in range(len(data)): + pipeline.progress(dataloader) + self.assertEqual(model_gpu._dummy_setting, "dummy modified") + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_equal_to_non_pipelined_with_input_transformer(self) -> None: + cpu = torch.device("cpu:0") + eb_conf_list = self.gen_eb_conf_list() + eb_conf_list_weighted = self.gen_eb_conf_list(is_weighted=True) + + model_cpu = self.gen_model(cpu, eb_conf_list) + model_gpu = self.gen_model(self.device, eb_conf_list).to(self.device) + + _, local_model_inputs = ModelInput.generate( + batch_size=10, + world_size=4, + num_float_features=8, + tables=eb_conf_list, + weighted_tables=eb_conf_list_weighted, + variable_batch_size=False, + ) + + model_gpu.load_state_dict(model_cpu.state_dict()) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `parameters`. + optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `parameters`. 
+ optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01) + + data = [ + i.idlist_features + for i in local_model_inputs + if isinstance(i.idlist_features, KeyedJaggedTensor) + ] + dataloader = iter(data) + pipeline = TrainPipelinePT2( + model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing + ) + + for batch in data[:-1]: + optimizer_cpu.zero_grad() + loss, pred = model_cpu(batch) + loss = reduce_to_scalar_loss(loss) + loss.backward() + pred_gpu = pipeline.progress(dataloader) + + self.assertEqual(pred_gpu.device, self.device) + torch.testing.assert_close(pred_gpu.cpu(), pred) + + +class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase): + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_feature_processed_ebc(self) -> None: + # don't need weighted tables here + self.weighted_tables = [] + + sharder = cast( + ModuleSharder[nn.Module], FeatureProcessedEmbeddingBagCollectionSharder() + ) + + class DummyWrapper(nn.Module): + def __init__(self, sparse_arch): + super().__init__() + self.m = sparse_arch + + def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]: + return self.m(model_input.idlist_features) + + max_feature_lengths = [10, 10, 12, 12] + sparse_arch = DummyWrapper( + create_module_and_freeze( + tables=self.tables, + device=self.device, + use_fp_collection=False, + max_feature_lengths=max_feature_lengths, + ) + ) + compute_kernel = EmbeddingComputeKernel.FUSED.value + module_sharding_plan = construct_module_sharding_plan( + sparse_arch.m._fp_ebc, + per_param_sharding={ + "table_0": table_wise(rank=0, compute_kernel=compute_kernel), + "table_1": table_wise(rank=0, compute_kernel=compute_kernel), + "table_2": table_wise(rank=0, compute_kernel=compute_kernel), + "table_3": table_wise(rank=0, compute_kernel=compute_kernel), + }, + local_size=1, + world_size=1, + device_type=self.device.type, + sharder=sharder, + ) + sharded_sparse_arch_no_pipeline = DistributedModelParallel( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}), + env=ShardingEnv.from_process_group(self.pg), + sharders=[sharder], + device=self.device, + ) + + sharded_sparse_arch_pipeline = DistributedModelParallel( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}), + env=ShardingEnv.from_process_group(self.pg), + sharders=[sharder], + device=self.device, + ) + + copy_state_dict( + sharded_sparse_arch_no_pipeline.state_dict(), + sharded_sparse_arch_pipeline.state_dict(), + ) + + data = self._generate_data( + num_batches=5, batch_size=1, max_feature_lengths=max_feature_lengths + ) + dataloader = iter(data) + + optimizer_no_pipeline = optim.SGD( + sharded_sparse_arch_no_pipeline.parameters(), lr=0.1 + ) + optimizer_pipeline = optim.SGD( + sharded_sparse_arch_pipeline.parameters(), lr=0.1 + ) + + pipeline = self.pipeline_class( + sharded_sparse_arch_pipeline, + optimizer_pipeline, + self.device, + ) + + for batch in data[:-2]: + batch = batch.to(self.device) + optimizer_no_pipeline.zero_grad() + loss, pred = sharded_sparse_arch_no_pipeline(batch) + loss.backward() + optimizer_no_pipeline.step() + + pred_pipeline = pipeline.progress(dataloader) + self.assertTrue(torch.equal(pred_pipeline.cpu(), pred.cpu())) + + self.assertEqual(len(pipeline._pipelined_modules), 1) + self.assertIsInstance( + pipeline._pipelined_modules[0], + 
ShardedFeatureProcessedEmbeddingBagCollection, + ) + + def _setup_pipeline( + self, + sharder: EmbeddingBagCollectionSharder, + execute_all_batches: bool, + enable_fsdp: bool = False, + unsharded_model: Optional[nn.Module] = None, + ) -> TrainPipelineSparseDist[ModelInput, torch.Tensor]: + if unsharded_model is None: + unsharded_model = self._setup_model(enable_fsdp=enable_fsdp) + + distributed_model = DistributedModelParallel( + unsharded_model, + env=ShardingEnv.from_process_group(self.pg), + init_data_parallel=False, + device=self.device, + sharders=[ + cast( + ModuleSharder[nn.Module], + sharder, + ) + ], + ) + optimizer_distributed = KeyedOptimizerWrapper( + dict(in_backward_optimizer_filter(distributed_model.named_parameters())), + lambda params: optim.SGD(params, lr=0.1), + ) + return self.pipeline_class( + model=distributed_model, + optimizer=optimizer_distributed, + device=self.device, + execute_all_batches=execute_all_batches, + ) + + def _setup_cpu_model_and_opt(self) -> Tuple[TestSparseNN, optim.SGD]: + cpu_model = TestSparseNN( + tables=self.tables, weighted_tables=self.weighted_tables + ) + cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=0.1) + return cpu_model, cpu_optimizer + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + @settings(max_examples=4, deadline=None) + # pyre-ignore[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + ] + ), + execute_all_batches=st.booleans(), + ) + def test_equal_to_non_pipelined( + self, + sharding_type: str, + kernel_type: str, + execute_all_batches: bool, + ) -> None: + """ + Checks that pipelined training is equivalent to non-pipelined training. 
+ """ + data = self._generate_data( + num_batches=12, + batch_size=32, + ) + dataloader = iter(data) + + fused_params = {} + fused_params_pipelined = {} + + model = self._setup_model() + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params_pipelined + ) + copy_state_dict( + sharded_model.state_dict(), sharded_model_pipelined.state_dict() + ) + + pipeline = self.pipeline_class( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=execute_all_batches, + ) + if not execute_all_batches: + data = data[:-2] + + for batch in data: + # Forward + backward w/o pipelining + batch = batch.to(self.device) + optim.zero_grad() + loss, pred = sharded_model(batch) + loss.backward() + optim.step() + + # Forward + backward w/ pipelining + pred_pipeline = pipeline.progress(dataloader) + torch.testing.assert_close(pred, pred_pipeline) + + self.assertRaises(StopIteration, pipeline.progress, dataloader) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given(execute_all_batches=st.booleans()) + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + def test_pipelining_fsdp_pre_trace(self, execute_all_batches: bool) -> None: + unsharded_model = self._setup_model(enable_fsdp=True) + leaf_module_names = [] + for i, _ in unsharded_model.named_children(): + leaf_module_names.append(i) + # Simulate a corner case where we trace into the child module + # so direct children is not part of the graph. This will break the + # original pipelining logic, because the leaf module is only the direct + # children, and when the root node calls directly into child's child. 
+ # It was broken because the child'child is a FSDP module and it + # breaks because FSDP is not trace-able + leaf_module_names.remove("over") + leaf_module_names.append("over.dhn_arch") + tracer = Tracer(leaf_module_names=leaf_module_names) + graph = tracer.trace(unsharded_model) + + traced_model = torch.fx.GraphModule(unsharded_model, graph) + pipeline = self._setup_pipeline( + TestEBCSharder( + sharding_type=ShardingType.TABLE_WISE.value, + kernel_type=EmbeddingComputeKernel.FUSED.value, + ), + execute_all_batches, + enable_fsdp=True, + unsharded_model=traced_model, + ) + cpu_model, cpu_optimizer = self._setup_cpu_model_and_opt() + data = self._generate_data() + + dataloader = iter(data) + if not execute_all_batches: + data = data[:-2] + + for batch in data: + cpu_optimizer.zero_grad() + loss, pred = cpu_model(batch) + loss.backward() + cpu_optimizer.step() + + pred_gpu = pipeline.progress(dataloader) + + self.assertEqual(len(pipeline._pipelined_modules), 2) + self.assertEqual(pred_gpu.device, self.device) + self.assertEqual(pred_gpu.cpu().size(), pred.size()) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_multi_dataloader_pipelining(self) -> None: + pipeline = self._setup_pipeline( + sharder=TestEBCSharder( + sharding_type=ShardingType.TABLE_WISE.value, + kernel_type=EmbeddingComputeKernel.FUSED.value, + ), + execute_all_batches=True, + ) + cpu_model, cpu_optimizer = self._setup_cpu_model_and_opt() + data = self._generate_data(num_batches=7) + + cpu_preds = [] + for batch in data: + cpu_optimizer.zero_grad() + loss, pred = cpu_model(batch) + loss.backward() + cpu_optimizer.step() + cpu_preds.append(pred) + + dataloaders = [iter(data[:-3]), iter(data[-3:-2]), iter(data[-2:])] + gpu_preds = [] + for dataloader in dataloaders: + while True: + try: + pred = pipeline.progress(dataloader) + self.assertEqual(pred.device, self.device) + self.assertEqual(len(pipeline._pipelined_modules), 2) + gpu_preds.append(pred.cpu()) + except StopIteration: + break + + self.assertEqual(len(pipeline._pipelined_modules), 2) + self.assertEqual(len(cpu_preds), len(gpu_preds)) + self.assertTrue( + all( + cpu_pred.size() == gpu_pred.size() + for cpu_pred, gpu_pred in zip(cpu_preds, gpu_preds) + ) + ) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_custom_fwd( + self, + ) -> None: + data = self._generate_data( + num_batches=4, + batch_size=32, + ) + dataloader = iter(data) + + fused_params_pipelined = {} + sharding_type = ShardingType.ROW_WISE.value + kernel_type = EmbeddingComputeKernel.FUSED.value + sharded_model_pipelined: torch.nn.Module + + model = self._setup_model() + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params_pipelined + ) + + def custom_model_fwd( + input: Optional[ModelInput], + ) -> Tuple[torch.Tensor, torch.Tensor]: + loss, pred = sharded_model_pipelined(input) + batch_size = pred.size(0) + return loss, pred.expand(batch_size * 2, -1) + + pipeline = self.pipeline_class( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=True, + custom_model_fwd=custom_model_fwd, + ) + + for _ in data: + # Forward + backward w/ pipelining + pred_pipeline = pipeline.progress(dataloader) + 
self.assertEqual(pred_pipeline.size(0), 64) + + +class TrainPipelineAttachDetachTest(TrainPipelineSparseDistTestBase): + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + with_postproc=st.booleans(), + pipeline_class=st.sampled_from( + [ + TrainPipelineSparseDist, + TrainPipelineSemiSync, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_model_detach_during_train( + self, + with_postproc: bool, + # pyre-ignore + pipeline_class: Union[TrainPipelineSparseDist, TrainPipelineSemiSync], + ) -> None: + """ + Test the scenario in which: + 1) Model training with pipeline.progress() + 2) Mid-training, model is detached + 3) Check that fwd of detached model is same as non-pipelined model + 4) Pipeline progress() re-attaches the model and we can continue progressing + """ + num_batches = 7 + batch_size = 32 + + sharding_type = ShardingType.TABLE_WISE.value + kernel_type = EmbeddingComputeKernel.FUSED.value + fused_params = {} + pipelined_forward_type = ( + PipelinedForward + if pipeline_class == TrainPipelineSparseDist + else EmbeddingPipelinedForward + ) + + postproc_module = None + if with_postproc: + extra_input = ModelInput.generate( + tables=self.tables, + weighted_tables=self.weighted_tables, + batch_size=batch_size, + world_size=1, + num_float_features=10, + randomize_indices=False, + )[0].to(self.device) + + postproc_module = TestNegSamplingModule( + extra_input=extra_input, + ) + + model = self._setup_model(postproc_module=postproc_module) + + data = self._generate_data( + num_batches=num_batches, + batch_size=batch_size, + ) + dataloader = iter(data) + + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params + ) + copy_state_dict( + sharded_model.state_dict(), sharded_model_pipelined.state_dict() + ) + + # pyre-ignore + pipeline = pipeline_class( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=True, + pipeline_postproc=True, + ) + + for i in range(3): + batch = data[i] + # Forward + backward w/o pipelining + batch = batch.to(self.device) + optim.zero_grad() + loss, pred = sharded_model(batch) + loss.backward() + optim.step() + + pred_pipelined = pipeline.progress(dataloader) + self.assertTrue(torch.equal(pred, pred_pipelined)) + + # Check internal states + ebcs = [ + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. + sharded_model_pipelined.module.sparse.ebc, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. 
+ sharded_model_pipelined.module.sparse.weighted_ebc, + ] + for ebc in ebcs: + self.assertIsInstance(ebc.forward, pipelined_forward_type) + + if with_postproc: + self.assertIsInstance( + # pyre-ignore + sharded_model_pipelined.module.postproc_module, + PipelinedPostproc, + ) + + detached_model = pipeline.detach() + + if with_postproc: + # Check we removed pipelined postproc wrapping after detach + self.assertIsInstance( + # pyre-ignore + sharded_model_pipelined.module.postproc_module, + TestNegSamplingModule, + ) + + # Check internal states + for ebc in ebcs: + self.assertNotIsInstance(ebc.forward, pipelined_forward_type) + + # Check fwd of detached model is same as non-pipelined model + with torch.no_grad(): + batch = data[3].to(self.device) + _, detached_out = detached_model(batch) + _, out = sharded_model(batch) + self.assertTrue(torch.equal(detached_out, out)) + + # Check that pipeline re-attaches the model again without issues + for i in range(3, 7): + batch = data[i] + # Forward + backward w/o pipelining + batch = batch.to(self.device) + optim.zero_grad() + loss, pred = sharded_model(batch) + loss.backward() + optim.step() + + pred_pipelined = pipeline.progress(dataloader) + self.assertTrue(torch.equal(pred, pred_pipelined)) + + for ebc in ebcs: + self.assertIsInstance(ebc.forward, pipelined_forward_type) + + if with_postproc: + # Check we have pipelined postproc after re-attaching + self.assertIsInstance( + # pyre-ignore + sharded_model_pipelined.module.postproc_module, + PipelinedPostproc, + ) + + # Check pipeline exhausted + self.assertRaises(StopIteration, pipeline.progress, dataloader) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + with_postproc=st.booleans(), + pipeline_class=st.sampled_from( + [ + TrainPipelineSparseDist, + TrainPipelineSemiSync, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_model_detach_after_train( + self, + with_postproc: bool, + # pyre-ignore + pipeline_class: Union[TrainPipelineSparseDist, TrainPipelineSemiSync], + ) -> None: + """ + Test the scenario in which: + 1) Model training with pipeline.progress() + 2) Pipeline exhausts dataloader and raises StopIteration + 4) Model is detached + 5) Check that fwd of detached model is same as non-pipelined model + 6) Pipeline progress() with new dataloader re-attaches model + """ + num_batches = 7 + batch_size = 32 + + sharding_type = ShardingType.TABLE_WISE.value + kernel_type = EmbeddingComputeKernel.FUSED.value + fused_params = {} + pipelined_forward_type = ( + PipelinedForward + if pipeline_class == TrainPipelineSparseDist + else EmbeddingPipelinedForward + ) + + postproc_module = None + if with_postproc: + extra_input = ModelInput.generate( + tables=self.tables, + weighted_tables=self.weighted_tables, + batch_size=batch_size, + world_size=1, + num_float_features=10, + randomize_indices=False, + )[0].to(self.device) + + postproc_module = TestNegSamplingModule( + extra_input=extra_input, + ) + + model = self._setup_model(postproc_module=postproc_module) + + data = self._generate_data( + num_batches=num_batches, + batch_size=batch_size, + ) + dataloader = iter(data) + + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params + ) + copy_state_dict( + 
sharded_model.state_dict(), sharded_model_pipelined.state_dict() + ) + + # pyre-ignore + pipeline = pipeline_class( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=True, + pipeline_postproc=True, + ) + + for i in range(7): + batch = data[i] + # Forward + backward w/o pipelining + batch = batch.to(self.device) + optim.zero_grad() + loss, pred = sharded_model(batch) + loss.backward() + optim.step() + + pred_pipelined = pipeline.progress(dataloader) + self.assertTrue(torch.equal(pred, pred_pipelined)) + + # Check pipeline exhausted + self.assertRaises(StopIteration, pipeline.progress, dataloader) + + if with_postproc: + self.assertIsInstance( + # pyre-ignore + sharded_model_pipelined.module.postproc_module, + PipelinedPostproc, + ) + + detached_model = pipeline.detach() + + if with_postproc: + # Check we removed pipelined postproc wrapping after detach + self.assertIsInstance( + # pyre-ignore + sharded_model_pipelined.module.postproc_module, + TestNegSamplingModule, + ) + + # Check internal states + ebcs = [ + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. + sharded_model_pipelined.module.sparse.ebc, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. + sharded_model_pipelined.module.sparse.weighted_ebc, + ] + for ebc in ebcs: + self.assertNotIsInstance(ebc.forward, pipelined_forward_type) + + # Check fwd of detached model is same as non-pipelined model + with torch.no_grad(): + for i in range(2): + batch = data[i].to(self.device) + _, detached_out = detached_model(batch) + _, out = sharded_model(batch) + self.assertTrue(torch.equal(detached_out, out)) + + # Provide new loaded dataloader and check model is re-attached + data = self._generate_data( + num_batches=4, + batch_size=32, + ) + dataloader = iter(data) + + for i in range(4): + batch = data[i] + batch = batch.to(self.device) + optim.zero_grad() + loss, pred = sharded_model(batch) + loss.backward() + optim.step() + + pred_pipelined = pipeline.progress(dataloader) + self.assertTrue(torch.equal(pred, pred_pipelined)) + + if with_postproc: + self.assertIsInstance( + # pyre-ignore + sharded_model_pipelined.module.postproc_module, + PipelinedPostproc, + ) + + # Check pipeline exhausted + self.assertRaises(StopIteration, pipeline.progress, dataloader) + + +class TrainPipelinePostprocTest(TrainPipelineSparseDistTestBase): + def setUp(self) -> None: + super().setUp() + self.num_batches = 10 + self.batch_size = 32 + self.sharding_type = ShardingType.TABLE_WISE.value + self.kernel_type = EmbeddingComputeKernel.FUSED.value + self.fused_params = {} + + def _check_output_equal( + self, + model: torch.nn.Module, + sharding_type: str, + max_feature_lengths: Optional[List[int]] = None, + ) -> Tuple[nn.Module, TrainPipelineSparseDist[ModelInput, torch.Tensor]]: + data = self._generate_data( + num_batches=self.num_batches, + batch_size=self.batch_size, + max_feature_lengths=max_feature_lengths, + ) + dataloader = iter(data) + + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, self.kernel_type, self.fused_params + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, self.kernel_type, self.fused_params + ) + copy_state_dict( + sharded_model.state_dict(), sharded_model_pipelined.state_dict() + ) + + pipeline = self.pipeline_class( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + 
execute_all_batches=True, + pipeline_postproc=True, + ) + + for i in range(self.num_batches): + batch = data[i] + # Forward + backward w/o pipelining + batch = batch.to(self.device) + optim.zero_grad() + loss, pred = sharded_model(batch) + loss.backward() + optim.step() + + pred_pipelined = pipeline.progress(dataloader) + self.assertTrue(torch.equal(pred, pred_pipelined)) + + return sharded_model_pipelined, pipeline + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_modules_share_postproc(self) -> None: + """ + Setup: postproc module takes in input batch and returns modified + input batch. EBC and weighted EBC inside model sparse arch subsequently + uses this modified KJT. + + Test case where single postproc module is shared by multiple sharded modules + and output of postproc module needs to be transformed in the SAME way + """ + + extra_input = ModelInput.generate( + tables=self.tables, + weighted_tables=self.weighted_tables, + batch_size=self.batch_size, + world_size=1, + num_float_features=10, + randomize_indices=False, + )[0].to(self.device) + + postproc_module = TestNegSamplingModule( + extra_input=extra_input, + ) + model = self._setup_model(postproc_module=postproc_module) + + pipelined_model, pipeline = self._check_output_equal( + model, + self.sharding_type, + ) + + # Check that both EC and EBC pipelined + self.assertEqual(len(pipeline._pipelined_modules), 2) + self.assertEqual(len(pipeline._pipelined_postprocs), 1) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None: + """ + Test case where arguments to postproc module is some non-modifying + transformation of the input batch (no nested postproc modules) AND + arguments to multiple sharded modules can be derived from the output + of different postproc modules (i.e. postproc modules not shared). + """ + model = TestModelWithPreproc( + tables=self.tables[:-1], # ignore last table as postproc will remove + weighted_tables=self.weighted_tables[:-1], # ignore last table + device=self.device, + ) + + pipelined_model, pipeline = self._check_output_equal( + model, + self.sharding_type, + ) + + # Check that both EBC and weighted EBC pipelined + self.assertEqual(len(pipeline._pipelined_modules), 2) + + pipelined_ebc = pipeline._pipelined_modules[0] + pipelined_weighted_ebc = pipeline._pipelined_modules[1] + + # Check pipelined args + for ebc in [pipelined_ebc, pipelined_weighted_ebc]: + self.assertEqual(len(ebc.forward._args), 1) + self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0]) + self.assertEqual(ebc.forward._args[0].is_getitems, [False, True]) + self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2) + self.assertIsInstance( + ebc.forward._args[0].postproc_modules[0], PipelinedPostproc + ) + self.assertEqual(ebc.forward._args[0].postproc_modules[1], None) + + self.assertEqual( + pipelined_ebc.forward._args[0].postproc_modules[0], + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `postproc_nonweighted`. + pipelined_model.module.postproc_nonweighted, + ) + self.assertEqual( + pipelined_weighted_ebc.forward._args[0].postproc_modules[0], + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `postproc_weighted`. 
+ pipelined_model.module.postproc_weighted, + ) + + # postproc args + self.assertEqual(len(pipeline._pipelined_postprocs), 2) + input_attr_names = {"idlist_features", "idscore_features"} + for i in range(len(pipeline._pipelined_postprocs)): + postproc_mod = pipeline._pipelined_postprocs[i] + self.assertEqual(len(postproc_mod._args), 1) + + input_attr_name = postproc_mod._args[0].input_attrs[1] + self.assertTrue(input_attr_name in input_attr_names) + self.assertEqual(postproc_mod._args[0].input_attrs, ["", input_attr_name]) + input_attr_names.remove(input_attr_name) + + self.assertEqual(postproc_mod._args[0].is_getitems, [False, False]) + # no parent postproc module in FX graph + self.assertEqual(postproc_mod._args[0].postproc_modules, [None, None]) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_postproc_recursive(self) -> None: + """ + Test recursive case where multiple arguments to postproc module is derived + from output of another postproc module. For example, + + out_a, out_b, out_c = postproc_1(input) + out_d = postproc_2(out_a, out_b) + # do something with out_c + out = ebc(out_d) + """ + extra_input = ModelInput.generate( + tables=self.tables[:-1], + weighted_tables=self.weighted_tables[:-1], + batch_size=self.batch_size, + world_size=1, + num_float_features=10, + randomize_indices=False, + )[0].to(self.device) + + postproc_module = TestNegSamplingModule( + extra_input=extra_input, + ) + + model = TestModelWithPreproc( + tables=self.tables[:-1], + weighted_tables=self.weighted_tables[:-1], + device=self.device, + postproc_module=postproc_module, + ) + + pipelined_model, pipeline = self._check_output_equal(model, self.sharding_type) + + # Check that both EBC and weighted EBC pipelined + self.assertEqual(len(pipeline._pipelined_modules), 2) + + pipelined_ebc = pipeline._pipelined_modules[0] + pipelined_weighted_ebc = pipeline._pipelined_modules[1] + + # Check pipelined args + for ebc in [pipelined_ebc, pipelined_weighted_ebc]: + self.assertEqual(len(ebc.forward._args), 1) + self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0]) + self.assertEqual(ebc.forward._args[0].is_getitems, [False, True]) + self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2) + self.assertIsInstance( + ebc.forward._args[0].postproc_modules[0], PipelinedPostproc + ) + self.assertEqual(ebc.forward._args[0].postproc_modules[1], None) + + self.assertEqual( + pipelined_ebc.forward._args[0].postproc_modules[0], + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `postproc_nonweighted`. + pipelined_model.module.postproc_nonweighted, + ) + self.assertEqual( + pipelined_weighted_ebc.forward._args[0].postproc_modules[0], + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `postproc_weighted`. + pipelined_model.module.postproc_weighted, + ) + + # postproc args + self.assertEqual(len(pipeline._pipelined_postprocs), 3) + + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_postproc_module`. + parent_postproc_mod = pipelined_model.module._postproc_module + + for postproc_mod in pipeline._pipelined_postprocs: + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `postproc_nonweighted`. 
+ if postproc_mod == pipelined_model.module.postproc_nonweighted: + self.assertEqual(len(postproc_mod._args), 1) + args = postproc_mod._args[0] + self.assertEqual(args.input_attrs, ["", "idlist_features"]) + self.assertEqual(args.is_getitems, [False, False]) + self.assertEqual(len(args.postproc_modules), 2) + self.assertEqual( + args.postproc_modules[0], + parent_postproc_mod, + ) + self.assertEqual(args.postproc_modules[1], None) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `postproc_weighted`. + elif postproc_mod == pipelined_model.module.postproc_weighted: + self.assertEqual(len(postproc_mod._args), 1) + args = postproc_mod._args[0] + self.assertEqual(args.input_attrs, ["", "idscore_features"]) + self.assertEqual(args.is_getitems, [False, False]) + self.assertEqual(len(args.postproc_modules), 2) + self.assertEqual( + args.postproc_modules[0], + parent_postproc_mod, + ) + self.assertEqual(args.postproc_modules[1], None) + elif postproc_mod == parent_postproc_mod: + self.assertEqual(len(postproc_mod._args), 1) + args = postproc_mod._args[0] + self.assertEqual(args.input_attrs, [""]) + self.assertEqual(args.is_getitems, [False]) + self.assertEqual(args.postproc_modules, [None]) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_invalid_postproc_inputs_has_trainable_params(self) -> None: + """ + Test case where postproc module sits in front of sharded module but this cannot be + safely pipelined as it contains trainable params in its child modules + """ + max_feature_lengths = { + "feature_0": 10, + "feature_1": 10, + "feature_2": 10, + "feature_3": 10, + } + + postproc_module = TestPositionWeightedPreprocModule( + max_feature_lengths=max_feature_lengths, + device=self.device, + ) + + model = self._setup_model(postproc_module=postproc_module) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, self.sharding_type, self.kernel_type, self.fused_params + ) + + pipeline = self.pipeline_class( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=True, + pipeline_postproc=True, + ) + + data = self._generate_data( + num_batches=self.num_batches, + batch_size=self.batch_size, + max_feature_lengths=list(max_feature_lengths.values()), + ) + dataloader = iter(data) + + pipeline.progress(dataloader) + + # Check that no modules are pipelined + self.assertEqual(len(pipeline._pipelined_modules), 0) + self.assertEqual(len(pipeline._pipelined_postprocs), 0) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_invalid_postproc_trainable_params_recursive( + self, + ) -> None: + max_feature_lengths = { + "feature_0": 10, + "feature_1": 10, + "feature_2": 10, + "feature_3": 10, + } + + postproc_module = TestPositionWeightedPreprocModule( + max_feature_lengths=max_feature_lengths, + device=self.device, + ) + + model = TestModelWithPreproc( + tables=self.tables[:-1], + weighted_tables=self.weighted_tables[:-1], + device=self.device, + postproc_module=postproc_module, + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, self.sharding_type, self.kernel_type, self.fused_params + ) + + pipeline = self.pipeline_class( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + 
execute_all_batches=True, + pipeline_postproc=True, + ) + + data = self._generate_data( + num_batches=self.num_batches, + batch_size=self.batch_size, + max_feature_lengths=list(max_feature_lengths.values()), + ) + dataloader = iter(data) + pipeline.progress(dataloader) + + # Check that no modules are pipelined + self.assertEqual(len(pipeline._pipelined_modules), 0) + self.assertEqual(len(pipeline._pipelined_postprocs), 0) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_invalid_postproc_inputs_modify_kjt_recursive(self) -> None: + """ + Test case where postproc module cannot be pipelined because at least one of args + is derived from output of another postproc module whose arg(s) cannot be derived + from input batch (i.e. it has modifying transformations) + """ + model = TestModelWithPreproc( + tables=self.tables[:-1], + weighted_tables=self.weighted_tables[:-1], + device=self.device, + postproc_module=None, + run_postproc_inline=True, # run postproc inline, outside a module + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, self.sharding_type, self.kernel_type, self.fused_params + ) + + pipeline = self.pipeline_class( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=True, + pipeline_postproc=True, + ) + + data = self._generate_data( + num_batches=self.num_batches, + batch_size=self.batch_size, + ) + dataloader = iter(data) + pipeline.progress(dataloader) + + # Check that only weighted EBC is pipelined + self.assertEqual(len(pipeline._pipelined_modules), 1) + self.assertEqual(len(pipeline._pipelined_postprocs), 1) + self.assertEqual(pipeline._pipelined_modules[0]._is_weighted, True) + self.assertEqual( + pipeline._pipelined_postprocs[0], + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `postproc_weighted`. + sharded_model_pipelined.module.postproc_weighted, + ) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_postproc_fwd_values_cached(self) -> None: + """ + Test to check that during model forward, the postproc module pipelined uses the + saved result from previous iteration(s) and doesn't perform duplicate work + check that fqns for ALL postproc modules are populated in the right train pipeline + context. 
+ """ + extra_input = ModelInput.generate( + tables=self.tables[:-1], + weighted_tables=self.weighted_tables[:-1], + batch_size=self.batch_size, + world_size=1, + num_float_features=10, + randomize_indices=False, + )[0].to(self.device) + + postproc_module = TestNegSamplingModule( + extra_input=extra_input, + ) + + model = TestModelWithPreproc( + tables=self.tables[:-1], + weighted_tables=self.weighted_tables[:-1], + device=self.device, + postproc_module=postproc_module, + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, self.sharding_type, self.kernel_type, self.fused_params + ) + + pipeline = self.pipeline_class( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=True, + pipeline_postproc=True, + ) + + data = self._generate_data( + num_batches=self.num_batches, + batch_size=self.batch_size, + ) + dataloader = iter(data) + + pipeline.progress(dataloader) + + # This was second context that was appended + current_context = pipeline.contexts[0] + cached_results = current_context.postproc_fwd_results + self.assertEqual( + list(cached_results.keys()), + ["_postproc_module", "postproc_nonweighted", "postproc_weighted"], + ) + + # next context cached results should be empty + next_context = pipeline.contexts[1] + next_cached_results = next_context.postproc_fwd_results + self.assertEqual(len(next_cached_results), 0) + + # After progress, next_context should be populated + pipeline.progress(dataloader) + self.assertEqual( + list(next_cached_results.keys()), + ["_postproc_module", "postproc_nonweighted", "postproc_weighted"], + ) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_nested_postproc(self) -> None: + """ + If postproc module is nested, we should still be able to pipeline it + """ + extra_input = ModelInput.generate( + tables=self.tables, + weighted_tables=self.weighted_tables, + batch_size=self.batch_size, + world_size=1, + num_float_features=10, + randomize_indices=False, + )[0].to(self.device) + + postproc_module = TestNegSamplingModule( + extra_input=extra_input, + ) + model = self._setup_model(postproc_module=postproc_module) + + class ParentModule(nn.Module): + def __init__( + self, + nested_model: nn.Module, + ) -> None: + super().__init__() + self.nested_model = nested_model + + def forward( + self, + input: ModelInput, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + return self.nested_model(input) + + model = ParentModule(model) + + pipelined_model, pipeline = self._check_output_equal( + model, + self.sharding_type, + ) + + # Check that both EC and EBC pipelined + self.assertEqual(len(pipeline._pipelined_modules), 2) + self.assertEqual(len(pipeline._pipelined_postprocs), 1) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_postproc_with_collection_args(self) -> None: + """ + Exercises scenario when postproc module has an argument that is a list or dict + with some elements being: + * static scalars + * static tensors (e.g. torch.ones()) + * tensors derived from input batch (e.g. input.idlist_features["feature_0"]) + * tensors derived from input batch and other postproc module (e.g. 
other_postproc(input.idlist_features["feature_0"])) + """ + test_runner = self + + class PostprocOuter(nn.Module): + def __init__( + self, + ) -> None: + super().__init__() + + def forward( + self, + model_input: ModelInput, + ) -> torch.Tensor: + return model_input.float_features * 0.1 + + class PostprocInner(nn.Module): + def __init__( + self, + ) -> None: + super().__init__() + + def forward( + self, + model_input: ModelInput, + input_list: List[Union[torch.Tensor, int]], + input_dict: Dict[str, Union[torch.Tensor, int]], + ) -> ModelInput: + if not is_fx_tracing(): + for idx, value in enumerate(input_list): + if isinstance(value, torch.fx.Node): + test_runner.fail( + f"input_list[{idx}] was a fx.Node: {value}" + ) + model_input.float_features += value + + for key, value in input_dict.items(): + if isinstance(value, torch.fx.Node): + test_runner.fail( + f"input_dict[{key}] was a fx.Node: {value}" + ) + model_input.float_features += value + + return model_input + + model = TestModelWithPreprocCollectionArgs( + tables=self.tables[:-1], # ignore last table as postproc will remove + weighted_tables=self.weighted_tables[:-1], # ignore last table + device=self.device, + postproc_module_outer=PostprocOuter(), + postproc_module_nested=PostprocInner(), + ) + + pipelined_model, pipeline = self._check_output_equal( + model, + self.sharding_type, + ) + + # both EC end EBC are pipelined + self.assertEqual(len(pipeline._pipelined_modules), 2) + # both outer and nested postproces are pipelined + self.assertEqual(len(pipeline._pipelined_postprocs), 4) + + +class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase): + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + @settings(max_examples=8, deadline=None) + # pyre-ignore[56] + @given( + start_batch=st.sampled_from([0, 6]), + stash_gradients=st.booleans(), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.ROW_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + ] + ), + zch=st.booleans(), + ) + def test_equal_to_non_pipelined( + self, + start_batch: int, + stash_gradients: bool, + sharding_type: str, + kernel_type: str, + zch: bool, + ) -> None: + """ + Checks that pipelined training is equivalent to non-pipelined training. + """ + # ZCH only supports row-wise currently + assume(not zch or (zch and sharding_type != ShardingType.TABLE_WISE.value)) + torch.autograd.set_detect_anomaly(True) + data = self._generate_data( + num_batches=12, + batch_size=32, + ) + dataloader = iter(data) + + fused_params = { + "stochastic_rounding": False, + } + fused_params_pipelined = { + **fused_params, + } + + model = self._setup_model(zch=zch) + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params_pipelined + ) + copy_state_dict( + sharded_model.state_dict(), sharded_model_pipelined.state_dict() + ) + + pipeline = TrainPipelineSemiSync( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=True, + start_batch=start_batch, + stash_gradients=stash_gradients, + ) + + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse_forward`. 
+ prior_sparse_out = sharded_model._dmp_wrapped_module.sparse_forward( + data[0].to(self.device) + ) + prior_batch = data[0].to(self.device) + prior_stashed_grads = None + batch_index = 0 + sparse_out = None + for batch in data[1:]: + batch_index += 1 + # Forward + backward w/o pipelining + batch = batch.to(self.device) + + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `dense_forward`. + loss, pred = sharded_model._dmp_wrapped_module.dense_forward( + prior_batch, prior_sparse_out + ) + if batch_index - 1 >= start_batch: + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `sparse_forward`. + sparse_out = sharded_model._dmp_wrapped_module.sparse_forward(batch) + + loss.backward() + + stashed_grads = None + if batch_index - 1 >= start_batch and stash_gradients: + stashed_grads = [] + for param in optim.param_groups[0]["params"]: + stashed_grads.append( + param.grad.clone() if param.grad is not None else None + ) + param.grad = None + + if prior_stashed_grads is not None: + for param, stashed_grad in zip( + optim.param_groups[0]["params"], prior_stashed_grads + ): + param.grad = stashed_grad + optim.step() + optim.zero_grad() + + if batch_index - 1 < start_batch: + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `sparse_forward`. + sparse_out = sharded_model._dmp_wrapped_module.sparse_forward(batch) + + prior_stashed_grads = stashed_grads + prior_batch = batch + prior_sparse_out = sparse_out + # Forward + backward w/ pipelining + pred_pipeline = pipeline.progress(dataloader) + + if batch_index >= start_batch: + self.assertTrue( + pipeline.is_semi_sync(), msg="pipeline is not semi_sync" + ) + else: + self.assertFalse(pipeline.is_semi_sync(), msg="pipeline is semi_sync") + self.assertTrue( + torch.equal(pred, pred_pipeline), + msg=f"batch {batch_index} doesn't match", + ) + + # one more batch + pred_pipeline = pipeline.progress(dataloader) + self.assertRaises(StopIteration, pipeline.progress, dataloader) + + +class PrefetchTrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase): + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + @settings(max_examples=4, deadline=None) + # pyre-ignore[56] + @given( + execute_all_batches=st.booleans(), + weight_precision=st.sampled_from( + [ + DataType.FP16, + DataType.FP32, + ] + ), + cache_precision=st.sampled_from( + [ + DataType.FP16, + DataType.FP32, + ] + ), + load_factor=st.sampled_from( + [ + 0.2, + 0.4, + 0.6, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + ] + ), + ) + def test_equal_to_non_pipelined( + self, + execute_all_batches: bool, + weight_precision: DataType, + cache_precision: DataType, + load_factor: float, + sharding_type: str, + kernel_type: str, + ) -> None: + """ + Checks that pipelined training is equivalent to non-pipelined training. 
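        The reference (non-pipelined) model is stepped manually batch by batch
        while the pipelined copy advances via ``pipeline.progress``. Predictions
        are compared with exact equality when weight and cache precisions match,
        and with ``torch.testing.assert_close`` tolerances when they differ,
        since mixed precision introduces rounding error.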
+ """ + mixed_precision: bool = weight_precision != cache_precision + self._set_table_weights_precision(weight_precision) + data = self._generate_data( + num_batches=12, + batch_size=32, + ) + dataloader = iter(data) + + fused_params = { + "cache_load_factor": load_factor, + "cache_precision": cache_precision, + "stochastic_rounding": False, # disable non-deterministic behavior when converting fp32<->fp16 + } + fused_params_pipelined = { + **fused_params, + "prefetch_pipeline": True, + } + + model = self._setup_model() + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params_pipelined + ) + copy_state_dict( + sharded_model.state_dict(), sharded_model_pipelined.state_dict() + ) + + pipeline = PrefetchTrainPipelineSparseDist( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=execute_all_batches, + ) + + if not execute_all_batches: + data = data[:-3] + + for batch in data: + # Forward + backward w/o pipelining + batch = batch.to(self.device) + optim.zero_grad() + loss, pred = sharded_model(batch) + loss.backward() + optim.step() + + # Forward + backward w/ pipelining + pred_pipeline = pipeline.progress(dataloader) + + if not mixed_precision: + # Rounding error is expected when using different precisions for weights and cache + self.assertTrue(torch.equal(pred, pred_pipeline)) + else: + torch.testing.assert_close(pred, pred_pipeline) + + +class DataLoadingThreadTest(unittest.TestCase): + def test_fetch_data(self) -> None: + data = [] + for i in range(7): + data.append(torch.tensor([i])) + data_iter = iter(data) + data_loader = DataLoadingThread(torch.device("cpu"), data_iter, True) + data_loader.start() + for i in range(7): + item = data_loader.get_next_batch() + self.assertEqual(item.item(), i) + + self.assertIsNone(data_loader.get_next_batch(False)) + with self.assertRaises(StopIteration): + data_loader.get_next_batch(True) + data_loader.stop() + + +class EvalPipelineSparseDistTest(unittest.TestCase): + def test_processing(self) -> None: + mock_model = MagicMock() + + def model_side_effect( + item: Pipelineable, + ) -> Tuple[Optional[Pipelineable], Pipelineable]: + return (None, item) + + mock_model.side_effect = model_side_effect + mock_optimizer = MagicMock() + + class MockPipeline(EvalPipelineSparseDist): + def __init__(self, model, optimizer, device: torch.device) -> None: + super().__init__(model, optimizer, device) + + def _init_pipelined_modules( + self, + item: Pipelineable, + context: TrainPipelineContext, + pipelined_forward: Type[PipelinedForward], + ) -> None: + pass + + def _start_sparse_data_dist( + self, item: Pipelineable, context: TrainPipelineContext + ) -> None: + pass + + def _wait_sparse_data_dist(self, context: TrainPipelineContext) -> None: + pass + + pipeline = MockPipeline(mock_model, mock_optimizer, torch.device("cpu")) + + data = [] + for i in range(7): + data.append(torch.tensor([i])) + data_iter = iter(data) + + for i in range(7): + item = pipeline.progress(data_iter) + self.assertEqual(item.item(), i) + + self.assertRaises(StopIteration, pipeline.progress, data_iter) + + +class StagedTrainPipelineTest(TrainPipelineSparseDistTestBase): + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipelining(self) -> 
None: + model = self._setup_model() + + sharding_type = ShardingType.TABLE_WISE.value + kernel_type = EmbeddingComputeKernel.FUSED.value + + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type + ) + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type + ) + + copy_state_dict( + sharded_model.state_dict(), sharded_model_pipelined.state_dict() + ) + + num_batches = 12 + data = self._generate_data( + num_batches=num_batches, + batch_size=32, + ) + + non_pipelined_outputs = [] + for batch in data: + batch = batch.to(self.device) + optim.zero_grad() + loss, pred = sharded_model(batch) + loss.backward() + optim.step() + non_pipelined_outputs.append(pred) + + def gpu_postproc(x: StageOut) -> StageOut: + return x + + sdd = SparseDataDistUtil[ModelInput]( + model=sharded_model_pipelined, + data_dist_stream=torch.cuda.Stream(), + apply_jit=False, + ) + + pipeline_stages = [ + PipelineStage( + name="data_copy", + runnable=partial(get_h2d_func, device=self.device), + stream=torch.cuda.Stream(), + ), + PipelineStage( + name="gpu_postproc", + runnable=gpu_postproc, + stream=torch.cuda.Stream(), + ), + PipelineStage( + name="gpu_postproc_1", + runnable=gpu_postproc, + stream=torch.cuda.Stream(), + ), + PipelineStage( + name="gpu_postproc_2", + runnable=gpu_postproc, + stream=torch.cuda.Stream(), + ), + PipelineStage( + name="start_sparse_data_dist", + runnable=sdd.start_sparse_data_dist, + stream=sdd.data_dist_stream, + fill_callback=sdd.wait_sparse_data_dist, + ), + ] + pipeline = StagedTrainPipeline( + pipeline_stages=pipeline_stages, compute_stream=torch.cuda.current_stream() + ) + dataloader = iter(data) + + pipelined_out = [] + num_batches_processed = 0 + + while model_in := pipeline.progress(dataloader): + num_batches_processed += 1 + optim_pipelined.zero_grad() + loss, pred = sharded_model_pipelined(model_in) + loss.backward() + optim_pipelined.step() + pipelined_out.append(pred) + + self.assertEqual(num_batches_processed, num_batches) + + self.assertEqual(len(pipelined_out), len(non_pipelined_outputs)) + for out, ref_out in zip(pipelined_out, non_pipelined_outputs): + torch.testing.assert_close(out, ref_out) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_flush(self) -> None: + model = self._setup_model() + + sharding_type = ShardingType.TABLE_WISE.value + kernel_type = EmbeddingComputeKernel.FUSED.value + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type + ) + + def gpu_postproc(x: StageOut) -> StageOut: + return x + + sdd = SparseDataDistUtil[ModelInput]( + model=sharded_model_pipelined, + data_dist_stream=torch.cuda.Stream(), + apply_jit=False, + ) + + pipeline_stages = [ + PipelineStage( + name="data_copy", + runnable=partial(get_h2d_func, device=self.device), + stream=torch.cuda.Stream(), + ), + PipelineStage( + name="gpu_postproc", + runnable=gpu_postproc, + stream=torch.cuda.Stream(), + ), + PipelineStage( + name="start_sparse_data_dist", + runnable=sdd.start_sparse_data_dist, + stream=sdd.data_dist_stream, + fill_callback=sdd.wait_sparse_data_dist, + ), + ] + + flush_end_called: int = 0 + + def on_flush_end() -> None: + nonlocal flush_end_called + flush_end_called += 1 + + pipeline = StagedTrainPipeline( + pipeline_stages=pipeline_stages, + 
compute_stream=torch.cuda.current_stream(), + on_flush_end=on_flush_end, + ) + self.assertEqual(pipeline._flushing, False) + + data = self._generate_data( + num_batches=10, + batch_size=32, + ) + dataloader = iter(data) + + # Run pipeline for 1 iteration, now internal state should be: + # pipeline._stage_outputs = [stage 2 output, stage 1 output, stage 0 output] + # and we exhaust 4 batches from dataloader (3 + 1 in _fill_pipeline) + out = pipeline.progress(dataloader) + self.assertIsNotNone(out) + + # Flush pipeline + pipeline.set_flush(True) + + # Run pipeline for 3 iterations + # Iteration 1: pipeline returns output from second batch + # Iteration 2: pipeline returns output from third batch + # Iteration 3: pipeline returns output from fourth batch + for _ in range(3): + out = pipeline.progress(dataloader) + self.assertIsNotNone(out) + + # Flush end not reached + self.assertEqual(flush_end_called, 0) + + # After this iteration, pipeline has been completely flushed + out = pipeline.progress(dataloader) + self.assertEqual(flush_end_called, 1) + # output shouldn't be None as we restart pipeline + # this should be output from fifth batch + self.assertIsNotNone(out) + + # Pipeline internal state + self.assertEqual(pipeline._flushing, False) + self.assertIsNotNone(pipeline._stage_outputs[0]) + self.assertIsNotNone(pipeline._stage_outputs[1]) + self.assertIsNotNone(pipeline._stage_outputs[2]) + + # Check that we get 5 more iterations before pipeline exhausts all data + for _ in range(5): + out = pipeline.progress(dataloader) + self.assertIsNotNone(out) + + # Check that pipeline has exhausted all data + out = pipeline.progress(dataloader) + self.assertIsNone(out) + + # Flush end not called this time + self.assertEqual(flush_end_called, 1) + + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_model_detach(self) -> None: + model = self._setup_model() + + sharding_type = ShardingType.TABLE_WISE.value + fused_params = {} + kernel_type = EmbeddingComputeKernel.FUSED.value + + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type + ) + + copy_state_dict( + sharded_model.state_dict(), sharded_model_pipelined.state_dict() + ) + + sdd = SparseDataDistUtil[ModelInput]( + model=sharded_model_pipelined, + data_dist_stream=torch.cuda.Stream(), + apply_jit=False, + ) + + pipeline_stages = [ + PipelineStage( + name="data_copy", + runnable=partial(get_h2d_func, device=self.device), + stream=torch.cuda.Stream(), + ), + PipelineStage( + name="start_sparse_data_dist", + runnable=sdd.start_sparse_data_dist, + stream=sdd.data_dist_stream, + fill_callback=sdd.wait_sparse_data_dist, + ), + ] + + pipeline = StagedTrainPipeline( + pipeline_stages=pipeline_stages, + compute_stream=torch.cuda.current_stream(), + ) + + data = self._generate_data( + num_batches=12, + batch_size=32, + ) + dataloader = iter(data) + + for i in range(5): + batch = data[i] + # Forward + backward w/o pipelining + batch = batch.to(self.device) + optim.zero_grad() + loss, pred = sharded_model(batch) + loss.backward() + optim.step() + + model_in = pipeline.progress(dataloader) + optim_pipelined.zero_grad() + loss_pred, pred_pipelined = sharded_model_pipelined(model_in) + loss_pred.backward() + optim_pipelined.step() + + self.assertTrue(torch.equal(pred, 
pred_pipelined)) + + # Check internal states + ebcs = [ + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. + sharded_model_pipelined.module.sparse.ebc, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. + sharded_model_pipelined.module.sparse.weighted_ebc, + ] + for ebc in ebcs: + self.assertIsInstance(ebc.forward, PipelinedForward) + self.assertEqual(len(sharded_model_pipelined._forward_hooks.items()), 1) + + detached_model = sdd.detach() + + # Check internal states + for ebc in ebcs: + self.assertNotIsInstance(ebc.forward, PipelinedForward) + self.assertEqual(len(sharded_model_pipelined._forward_hooks.items()), 0) + + # Check we can run backward and optimizer ond detached model + batch = data[5].to(self.device) + loss_detached, detached_out = detached_model(batch) + loss_sharded, out = sharded_model(batch) + self.assertTrue(torch.equal(detached_out, out)) + loss_detached.backward() + loss_sharded.backward() + optim.step() + optim_pipelined.step() + + # Check fwd of detached model is same as non-pipelined model + with torch.no_grad(): + batch = data[6].to(self.device) + _, detached_out = detached_model(batch) + _, out = sharded_model(batch) + self.assertTrue(torch.equal(detached_out, out)) + + # Check that pipeline re-attaches the model again without issues + for i in range(5, 12): + batch = data[i] + # Forward + backward w/o pipelining + batch = batch.to(self.device) + optim.zero_grad() + loss, pred = sharded_model(batch) + loss.backward() + optim.step() + + model_in = pipeline.progress(dataloader) + optim_pipelined.zero_grad() + loss_pred, pred_pipelined = sharded_model_pipelined(model_in) + loss_pred.backward() + optim_pipelined.step() + + self.assertTrue(torch.equal(pred, pred_pipelined)) + + for ebc in ebcs: + self.assertIsInstance(ebc.forward, PipelinedForward) + self.assertEqual(len(sharded_model_pipelined._forward_hooks.items()), 1) + + # Check pipeline exhausted + postproc_input = pipeline.progress(dataloader) + self.assertIsNone(postproc_input) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + @settings(max_examples=4, deadline=None) + # pyre-ignore[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + ] + ), + cache_precision=st.sampled_from( + [ + DataType.FP16, + DataType.FP32, + ] + ), + load_factor=st.sampled_from( + [ + 0.2, + 0.4, + ] + ), + ) + def test_pipelining_prefetch( + self, + sharding_type: str, + kernel_type: str, + cache_precision: DataType, + load_factor: float, + ) -> None: + model = self._setup_model() + + fused_params = { + "cache_load_factor": load_factor, + "cache_precision": cache_precision, + "stochastic_rounding": False, # disable non-deterministic behavior when converting fp32<->fp16 + } + fused_params_pipelined = { + **fused_params, + "prefetch_pipeline": True, + } + + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params + ) + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params_pipelined + ) + + copy_state_dict( + sharded_model.state_dict(), sharded_model_pipelined.state_dict() + ) + + num_batches = 12 + data = self._generate_data( + num_batches=num_batches, + batch_size=32, + ) + + non_pipelined_outputs = [] + for batch in data: + 
batch = batch.to(self.device) + optim.zero_grad() + loss, pred = sharded_model(batch) + loss.backward() + optim.step() + non_pipelined_outputs.append(pred) + + def gpu_postproc(x: StageOut) -> StageOut: + return x + + sdd = SparseDataDistUtil[ModelInput]( + model=sharded_model_pipelined, + data_dist_stream=torch.cuda.Stream(), + apply_jit=False, + prefetch_stream=torch.cuda.Stream(), + ) + + pipeline_stages = [ + PipelineStage( + name="data_copy", + runnable=partial(get_h2d_func, device=self.device), + stream=torch.cuda.Stream(), + ), + PipelineStage( + name="start_sparse_data_dist", + runnable=sdd.start_sparse_data_dist, + stream=sdd.data_dist_stream, + fill_callback=sdd.wait_sparse_data_dist, + ), + PipelineStage( + name="prefetch", + runnable=sdd.prefetch, + # pyre-ignore + stream=sdd.prefetch_stream, + fill_callback=sdd.load_prefetch, + ), + ] + pipeline = StagedTrainPipeline( + pipeline_stages=pipeline_stages, compute_stream=torch.cuda.current_stream() + ) + dataloader = iter(data) + + pipelined_out = [] + num_batches_processed = 0 + + while model_in := pipeline.progress(dataloader): + num_batches_processed += 1 + optim_pipelined.zero_grad() + loss, pred = sharded_model_pipelined(model_in) + loss.backward() + optim_pipelined.step() + pipelined_out.append(pred) + + self.assertEqual(num_batches_processed, num_batches) + + self.assertEqual(len(pipelined_out), len(non_pipelined_outputs)) + for out, ref_out in zip(pipelined_out, non_pipelined_outputs): + torch.testing.assert_close(out, ref_out) + + +class TrainPipelineSparseDistCompAutogradTest(TrainPipelineSparseDistTest): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + self.pipeline_class = TrainPipelineSparseDistCompAutograd + torch._dynamo.reset() + counters["compiled_autograd"].clear() + # Compiled Autograd don't work with Anomaly Mode + torch.autograd.set_detect_anomaly(False) + self._exit_stack = ExitStack() + self._exit_stack.enter_context( + # type: ignore[attr-defined] + torch._dynamo.config.patch( + optimize_ddp="python_reducer_without_compiled_forward" + ), + ) + + def tearDown(self) -> None: + self._exit_stack.close() + self.assertEqual(counters["compiled_autograd"]["captures"], 3) + return super().tearDown() + + @unittest.skip("Dynamo only supports FSDP with use_orig_params=True") + # pyre-ignore[56] + @given(execute_all_batches=st.booleans()) + def test_pipelining_fsdp_pre_trace(self, execute_all_batches: bool) -> None: + super().test_pipelining_fsdp_pre_trace() + + @unittest.skip( + "TrainPipelineSparseDistTest.test_equal_to_non_pipelined was called from multiple different executors, which fails hypothesis HealthChek, so we skip it here" + ) + def test_equal_to_non_pipelined( + self, + sharding_type: str, + kernel_type: str, + execute_all_batches: bool, + ) -> None: + super().test_equal_to_non_pipelined() diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py new file mode 100644 index 000000000..56e6ac636 --- /dev/null +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
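# This module defines TrainPipelineSparseDistTestBase, the shared single-rank
# fixture for the pipeline tests: it creates a gloo/nccl process group on a free
# port, builds small EmbeddingBagConfig tables, and provides helpers to generate
# ModelInput batches, construct TestSparseNN models (optionally with a postproc
# module, FSDP wrapping, or zch), and shard them with DistributedModelParallel
# alongside a plain SGD optimizer for the parameters not covered by the fused
# optimizer.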
+ +# pyre-strict + +import copy +import os +import unittest +from typing import Any, cast, Dict, List, Optional, Tuple, Type + +import torch +import torch.distributed as dist +from torch import nn, optim +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.optim import Optimizer +from torchrec.distributed import DistributedModelParallel +from torchrec.distributed.test_utils.test_model import ( + ModelInput, + TestEBCSharder, + TestEBCSharderMCH, + TestSparseNN, +) +from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist +from torchrec.distributed.types import ModuleSharder, ShardingEnv +from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig +from torchrec.test_utils import get_free_port, init_distributed_single_host + + +class TrainPipelineSparseDistTestBase(unittest.TestCase): + def setUp(self) -> None: + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + backend = "gloo" + if torch.cuda.is_available(): + backend = "nccl" + self.pg = init_distributed_single_host(backend=backend, rank=0, world_size=1) + + num_features = 4 + num_weighted_features = 2 + self.tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 100, + embedding_dim=(i + 1) * 4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + self.weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 100, + embedding_dim=(i + 1) * 4, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(num_weighted_features) + ] + + self.device = torch.device("cuda:0") + self.pipeline_class = TrainPipelineSparseDist + + def tearDown(self) -> None: + super().tearDown() + dist.destroy_process_group(self.pg) + + def _generate_data( + self, + num_batches: int = 5, + batch_size: int = 1, + max_feature_lengths: Optional[List[int]] = None, + ) -> List[ModelInput]: + return [ + ModelInput.generate( + tables=self.tables, + weighted_tables=self.weighted_tables, + batch_size=batch_size, + world_size=1, + num_float_features=10, + max_feature_lengths=max_feature_lengths, + )[0] + for i in range(num_batches) + ] + + def _set_table_weights_precision(self, dtype: DataType) -> None: + for i in range(len(self.tables)): + self.tables[i].data_type = dtype + + for i in range(len(self.weighted_tables)): + self.weighted_tables[i].data_type = dtype + + def _setup_model( + self, + model_type: Type[nn.Module] = TestSparseNN, + enable_fsdp: bool = False, + postproc_module: Optional[nn.Module] = None, + zch: bool = False, + ) -> nn.Module: + unsharded_model = model_type( + tables=self.tables, + weighted_tables=self.weighted_tables, + dense_device=self.device, + sparse_device=torch.device("meta"), + postproc_module=postproc_module, + zch=zch, + ) + if enable_fsdp: + unsharded_model.over.dhn_arch.linear0 = FSDP( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `dhn_arch`. + unsharded_model.over.dhn_arch.linear0 + ) + unsharded_model.over.dhn_arch.linear1 = FSDP( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `dhn_arch`. + unsharded_model.over.dhn_arch.linear1 + ) + # pyre-fixme[16]: `Module` has no attribute `dhn_arch`. + # pyre-fixme[16]: `Tensor` has no attribute `dhn_arch`. + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `dhn_arch`. 
+ unsharded_model.over.dhn_arch = FSDP(unsharded_model.over.dhn_arch) + + return unsharded_model + + def _generate_sharded_model_and_optimizer( + self, + model: nn.Module, + sharding_type: str, + kernel_type: str, + fused_params: Optional[Dict[str, Any]] = None, + ) -> Tuple[nn.Module, Optimizer]: + sharder = TestEBCSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + fused_params=fused_params, + ) + mc_sharder = TestEBCSharderMCH( + sharding_type=sharding_type, + kernel_type=kernel_type, + fused_params=fused_params, + ) + sharded_model = DistributedModelParallel( + module=copy.deepcopy(model), + env=ShardingEnv.from_process_group(self.pg), + init_data_parallel=False, + device=self.device, + sharders=[ + cast( + ModuleSharder[nn.Module], + sharder, + ), + cast( + ModuleSharder[nn.Module], + mc_sharder, + ), + ], + ) + # default fused optimizer is SGD w/ lr=0.1; we need to drop params + fused_named_parameters: List[str] = [ + x for x in DistributedModelParallel._sharded_parameter_names(sharded_model) + ] + optimizer = optim.SGD( + [ + y + for x, y in sharded_model.named_parameters() + if x not in fused_named_parameters + ], + lr=0.1, + ) + return sharded_model, optimizer diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py new file mode 100644 index 000000000..53fae9001 --- /dev/null +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import copy +import enum +import unittest +from typing import List +from unittest.mock import MagicMock + +import torch +from parameterized import parameterized + +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.test_utils.test_model import ModelInput, TestNegSamplingModule + +from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import ( + TrainPipelineSparseDistTestBase, +) +from torchrec.distributed.train_pipeline.utils import ( + _build_args_kwargs, + _get_node_args, + _rewrite_model, + ArgInfo, + PipelinedForward, + PipelinedPostproc, + TrainPipelineContext, +) +from torchrec.distributed.types import ShardingType +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class ModelType(enum.Enum): + VANILLA = "vanilla" + SHARDED = "sharded" + PIPELINED = "pipelined" + + +class TrainPipelineUtilsTest(TrainPipelineSparseDistTestBase): + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_rewrite_model(self) -> None: + sharding_type = ShardingType.TABLE_WISE.value + kernel_type = EmbeddingComputeKernel.FUSED.value + fused_params = {} + + extra_input = ModelInput.generate( + tables=self.tables, + weighted_tables=self.weighted_tables, + batch_size=10, + world_size=1, + num_float_features=10, + randomize_indices=False, + )[0].to(self.device) + + postproc_module = TestNegSamplingModule( + extra_input=extra_input, + ) + model = self._setup_model(postproc_module=postproc_module) + + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params + ) + + # Try to rewrite model without 
ignored_postproc_modules defined, EBC forwards not overwritten to PipelinedForward due to KJT modification + _rewrite_model( + model=sharded_model, + batch=None, + context=TrainPipelineContext(), + dist_stream=None, + ) + self.assertNotIsInstance( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. + sharded_model.module.sparse.ebc.forward, + PipelinedForward, + ) + self.assertNotIsInstance( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. + sharded_model.module.sparse.weighted_ebc.forward, + PipelinedForward, + ) + + # Now provide postproc module explicitly + _rewrite_model( + model=sharded_model, + batch=None, + context=TrainPipelineContext(), + dist_stream=None, + pipeline_postproc=True, + ) + + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `sparse`. + self.assertIsInstance(sharded_model.module.sparse.ebc.forward, PipelinedForward) + self.assertIsInstance( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. + sharded_model.module.sparse.weighted_ebc.forward, + PipelinedForward, + ) + self.assertEqual( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. + sharded_model.module.sparse.ebc.forward._args[0].postproc_modules[0], + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `postproc_module`. + sharded_model.module.postproc_module, + ) + self.assertEqual( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse`. + sharded_model.module.sparse.weighted_ebc.forward._args[0].postproc_modules[ + 0 + ], + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `postproc_module`. + sharded_model.module.postproc_module, + ) + state_dict = sharded_model.state_dict() + missing_keys, unexpected_keys = sharded_model.load_state_dict(state_dict) + self.assertEqual(missing_keys, []) + self.assertEqual(unexpected_keys, []) + + def test_pipelined_postproc_state_dict(self) -> None: + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("weight", torch.tensor(1.0)) + + def forward(self, x): + return x + + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.test_module = TestModule() + + def forward(self, x): + return self.test_module(x) + + model = TestModel() + + rewritten_model = copy.deepcopy(model) + # pyre-ignore[8] + rewritten_model.test_module = PipelinedPostproc( + postproc_module=rewritten_model.test_module, + fqn="test_module", + args=[], + context=TrainPipelineContext(), + default_stream=MagicMock(), + dist_stream=MagicMock(), + ) + # self-check - we want the state dict be the same between vanilla model and "rewritten model" + self.assertDictEqual(model.state_dict(), rewritten_model.state_dict()) + state_dict = rewritten_model.state_dict() + self.assertEqual(list(state_dict.keys()), ["test_module.weight"]) + + def _create_model_for_snapshot_test( + self, source_model_type: ModelType + ) -> torch.nn.Module: + if source_model_type == ModelType.VANILLA: + extra_input = ModelInput.generate( + tables=self.tables, + weighted_tables=self.weighted_tables, + batch_size=10, + world_size=1, + num_float_features=10, + randomize_indices=False, + )[0].to(self.device) + + postproc_module = TestNegSamplingModule( + extra_input=extra_input, + ) + model = self._setup_model(postproc_module=postproc_module) + model.to_empty(device=self.device) + return model + elif source_model_type == ModelType.SHARDED: + 
model = self._create_model_for_snapshot_test(ModelType.VANILLA) + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, + ShardingType.TABLE_WISE.value, + EmbeddingComputeKernel.FUSED.value, + {}, + ) + return sharded_model + elif source_model_type == ModelType.PIPELINED: + model = self._create_model_for_snapshot_test(ModelType.SHARDED) + _rewrite_model( + model=model, + batch=None, + context=TrainPipelineContext(), + dist_stream=None, + pipeline_postproc=True, + ) + return model + else: + raise ValueError(f"Unknown model type {source_model_type}") + + def _test_restore_from_snapshot( + self, source_model_type: ModelType, recipient_model_type: ModelType + ) -> None: + source_model = self._create_model_for_snapshot_test(source_model_type) + recipient_model = self._create_model_for_snapshot_test(recipient_model_type) + + # self-check - we want the state dict be the same between source and recipient + # although this is not strictly necessary + # Asserting only on keys since the asserting on entire state dict fails with + # "Boolean value of Tensor with more than one value is ambiguous" (not sure why) + self.assertEqual( + source_model.state_dict().keys(), recipient_model.state_dict().keys() + ) + + state_dict = source_model.state_dict() + self.assertTrue( + f"postproc_module.{TestNegSamplingModule.TEST_BUFFER_NAME}" + in state_dict.keys() + ) + + missing_keys, unexpected_keys = recipient_model.load_state_dict(state_dict) + # if both are empty, restoring the state dict was successful + self.assertEqual(missing_keys, []) + self.assertEqual(unexpected_keys, []) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_restore_from_snapshot(self) -> None: + # makeshift parameterized test - to avoid introducing new dependencies + variants = [ + # Self-consistency checks - model should be able to load it's own state + (ModelType.VANILLA, ModelType.VANILLA), + (ModelType.SHARDED, ModelType.SHARDED), + (ModelType.PIPELINED, ModelType.PIPELINED), + # Production case - saved from pipelined, restored to sharded + (ModelType.PIPELINED, ModelType.SHARDED), + # Nice-to-haves: + (ModelType.SHARDED, ModelType.PIPELINED), + (ModelType.VANILLA, ModelType.PIPELINED), + (ModelType.VANILLA, ModelType.SHARDED), + # Won't work - restoring sharded/pipelined into vanilla fails with + # "'Parameter' object has no attribute 'local_shards'" + # ... which is totally expected, as vanilla model is not sharded + # (ModelType.SHARDED, ModelType.VANILLA), + # (ModelType.PIPELINED, ModelType.VANILLA), + ] + for source_model_type, recipient_model_type in variants: + self._test_restore_from_snapshot(source_model_type, recipient_model_type) + + @parameterized.expand( + [ + ( + [ + # Empty attrs to ignore any attr based logic. + ArgInfo( + input_attrs=[ + "", + ], + is_getitems=[False], + postproc_modules=[None], + constants=[None], + name="id_list_features", + ), + ArgInfo( + input_attrs=[], + is_getitems=[], + postproc_modules=[], + constants=[], + name="id_score_list_features", + ), + ], + 0, + ["id_list_features", "id_score_list_features"], + ), + ( + [ + # Empty attrs to ignore any attr based logic. 
+ ArgInfo( + input_attrs=[ + "", + ], + is_getitems=[False], + postproc_modules=[None], + constants=[None], + name=None, + ), + ArgInfo( + input_attrs=[], + is_getitems=[], + postproc_modules=[], + constants=[], + name=None, + ), + ], + 2, + [], + ), + ( + [ + # Empty attrs to ignore any attr based logic. + ArgInfo( + input_attrs=[ + "", + ], + is_getitems=[False], + postproc_modules=[None], + constants=[None], + name=None, + ), + ArgInfo( + input_attrs=[], + is_getitems=[], + postproc_modules=[], + constants=[], + name="id_score_list_features", + ), + ], + 1, + ["id_score_list_features"], + ), + ] + ) + def test_build_args_kwargs( + self, + fwd_args: List[ArgInfo], + args_len: int, + kwarges_keys: List[str], + ) -> None: + args, kwargs = _build_args_kwargs("initial_input", fwd_args) + self.assertEqual(len(args), args_len) + self.assertEqual(list(kwargs.keys()), kwarges_keys) + + +class TestUtils(unittest.TestCase): + def test_get_node_args_helper_call_module_kjt(self) -> None: + graph = torch.fx.Graph() + kjt_args = [] + + kjt_args.append( + torch.fx.Node(graph, "values", "placeholder", "torch.Tensor", (), {}) + ) + kjt_args.append( + torch.fx.Node(graph, "lengths", "placeholder", "torch.Tensor", (), {}) + ) + kjt_args.append( + torch.fx.Node( + graph, "weights", "call_module", "PositionWeightedModule", (), {} + ) + ) + + kjt_node = torch.fx.Node( + graph, + "keyed_jagged_tensor", + "call_function", + KeyedJaggedTensor, + tuple(kjt_args), + {}, + ) + + num_found = 0 + _, num_found = _get_node_args( + MagicMock(), kjt_node, set(), TrainPipelineContext(), False + ) + + # Weights is call_module node, so we should only find 2 args unmodified + self.assertEqual(num_found, len(kjt_args) - 1) diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py new file mode 100644 index 000000000..4685fae9c --- /dev/null +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -0,0 +1,1771 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
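# This module implements the TrainPipeline variants exercised by the tests above:
#   * TrainPipelineBase - overlaps the host-to-device copy (memcpy stream) with
#     forward/backward on the default stream.
#   * TrainPipelinePT2 - single-stream pipeline that compiles the model with
#     torch.compile after a configurable warm-up iteration.
#   * TrainPipelineSparseDist - additionally overlaps ShardedModule.input_dist()
#     (data_dist stream) with the forward/backward of the previous batch.
#   * TrainPipelineSemiSync - semi-synchronous variant that overlaps the
#     embedding all-to-all of batch B with the forward pass of batch B-1.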
+ +# pyre-strict + +import abc +import contextlib +import logging +from collections import deque +from dataclasses import dataclass +from typing import ( + Any, + Callable, + cast, + ContextManager, + Deque, + Dict, + Generic, + Iterable, + Iterator, + List, + Optional, + Tuple, + Type, + Union, +) + +import torch +from torch.autograd.profiler import record_function +from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable +from torchrec.distributed.model_parallel import ShardedModule +from torchrec.distributed.train_pipeline.utils import ( + _override_input_dist_forwards, + _pipeline_detach_model, + _prefetch_embeddings, + _rewrite_model, + _start_data_dist, + _start_embedding_lookup, + _to_device, + _wait_for_batch, + _wait_for_events, + DataLoadingThread, + EmbeddingPipelinedForward, + EmbeddingTrainPipelineContext, + In, + Out, + PipelinedForward, + PipelinedPostproc, + PipelineStage, + PrefetchPipelinedForward, + PrefetchTrainPipelineContext, + RunnableType, + StageOut, + StageOutputWithEvent, + TrainPipelineContext, +) +from torchrec.distributed.types import Awaitable +from torchrec.pt2.checks import is_torchdynamo_compiling +from torchrec.pt2.utils import default_pipeline_input_transformer +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.streamable import Pipelineable + +logger: logging.Logger = logging.getLogger(__name__) + +# This is required to support older torch package export for older models +try: + from torchrec.distributed.comm_ops import torchrec_use_sync_collectives +except ImportError: + logger.warning("torchrec_use_sync_collectives is not available") + +if not torch._running_with_deploy(): + torch.ops.import_module("fbgemm_gpu.sparse_ops") + + +class ModelDetachedException(Exception): + pass + + +class TrainPipeline(abc.ABC, Generic[In, Out]): + @abc.abstractmethod + def progress(self, dataloader_iter: Iterator[In]) -> Out: + pass + + +@dataclass +class TorchCompileConfig: + """ + Configs for torch.compile + + fullgraph: bool = False, whether to compile the whole graph or not + dynamic: Optional[bool] = None, whether to use dynamic shapes or not, if None, automatic_dynamic_shapes will apply + backend: str = "inductor", which compiler to use (either inductor or aot) + compile_on_iter: int = 3, compile the model on which iteration + this is useful when we want to profile the first few iterations of training + and then start using compiled model from iteration #3 onwards + """ + + fullgraph: bool = False + dynamic: Optional[bool] = None + backend: str = "inductor" + compile_on_iter: int = 3 + + +class TrainPipelineBase(TrainPipeline[In, Out]): + """ + This class runs training iterations using a pipeline of two stages, each as a CUDA + stream, namely, the current (default) stream and `self._memcpy_stream`. For each + iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU + memory, and the default stream runs forward, backward, and optimization. 
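    A minimal usage sketch (illustrative only; ``model``, ``optimizer`` and
    ``dataloader`` are assumed to exist and are not defined by this module)::

        pipeline = TrainPipelineBase(model, optimizer, device=torch.device("cuda"))
        data_iter = iter(dataloader)
        while True:
            try:
                output = pipeline.progress(data_iter)
            except StopIteration:
                break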
+ """ + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: torch.device, + custom_model_fwd: Optional[ + Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]] + ] = None, + ) -> None: + self._model = model + self._optimizer = optimizer + self._device = device + self._memcpy_stream: Optional[torch.Stream] = ( + torch.get_device_module(device).Stream() + if device.type in ["cuda", "mtia"] + else None + ) + + # pyre-ignore + self._stream_context = ( + torch.get_device_module(self._device).stream + if self._device.type in ["cuda", "mtia"] + else torch.cuda.stream + ) + self._cur_batch: Optional[In] = None + self._connected = False + self._data_iter_stopped = False + + def _reset_data_iter(self) -> None: + self._connected = False + self._data_iter_stopped = False + self._cur_batch = None + + def _connect(self, dataloader_iter: Iterator[In]) -> None: + cur_batch = next(dataloader_iter) + self._cur_batch = cur_batch + if cur_batch is not None: + with self._stream_context(self._memcpy_stream): + self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True) + self._connected = True + + def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]: + with record_function("## next_batch ##"): + try: + next_batch = next(dataloader_iter) + except StopIteration: + self._data_iter_stopped = True + return None + + return next_batch + + def _wait_for_batch(self, cur_batch: In) -> None: + with record_function("## wait_for_batch ##"): + _wait_for_batch(cur_batch, self._memcpy_stream) + + def _backward(self, losses: torch.Tensor) -> None: + with record_function("## backward ##"): + torch.sum(losses, dim=0).backward() + + def _copy_batch_to_gpu(self, cur_batch: In) -> None: + with record_function("## copy_batch_to_gpu ##"): + with self._stream_context(self._memcpy_stream): + self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True) + + def progress(self, dataloader_iter: Iterator[In]) -> Out: + if not self._connected: + self._connect(dataloader_iter) + if self._data_iter_stopped: + raise StopIteration() + + # Fetch next batch, if depleted, raise at start of next progress + next_batch = self._next_batch(dataloader_iter) + cur_batch = self._cur_batch + + # for exhaustive data iter, some ranks will first depletes data, + # but we still need progress the train pipeline for other ranks; + # cur_batch could be None + + if self._model.training: + with record_function("## zero_grad ##"): + self._optimizer.zero_grad() + + if cur_batch is not None: + self._wait_for_batch(cur_batch) + + # model will need to handle if cur_batch is empty; this is needed if there's + # communicative ops + with record_function("## forward ##"): + (losses, output) = self._model(cur_batch) + + if self._model.training: + self._backward(losses) + + # Copy the next batch to GPU + self._cur_batch = cur_batch = next_batch + if cur_batch is not None: + self._copy_batch_to_gpu(cur_batch) + + # Update + if self._model.training: + with record_function("## optimizer ##"): + self._optimizer.step() + + return output + + +class TrainPipelinePT2(TrainPipelineBase[In, Out]): + """ + This pipeline uses PT2 compiler to compile the model and run it in a single stream (default) + Args: + model (torch.nn.Module): model to pipeline. + optimizer (torch.optim.Optimizer): optimizer to use. 
+ device (torch.device): device where the model is run + compile_configs (TorchCompileConfig): configs for compling the model + pre_compile_fn (Callable[[torch.nn.Module], [None]]): Optional callable to execute before compiling the model + post_compile_fn (Callable[[torch.nn.Module], [None]]): Optional callable to execute after compiling the model + input_transformer (Callable[[In], In]): transforms the input before passing it to the model. + This is useful when we want to transform KJT parameters for PT2 tracing + """ + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: torch.device, + compile_configs: Optional[TorchCompileConfig] = None, + pre_compile_fn: Optional[Callable[[torch.nn.Module], None]] = None, + post_compile_fn: Optional[Callable[[torch.nn.Module], None]] = None, + input_transformer: Optional[Callable[[In], In]] = None, + ) -> None: + self._model = model + self._optimizer = optimizer + self._device = device + self._compile_configs: TorchCompileConfig = ( + compile_configs or TorchCompileConfig() + ) + self._pre_compile_fn = pre_compile_fn + self._post_compile_fn = post_compile_fn + # pyre-ignore + self._input_transformer = ( + input_transformer or default_pipeline_input_transformer + ) + self._iter = 0 + self._cur_batch: Optional[In] = None + + def progress(self, dataloader_iter: Iterator[In]) -> Out: + if self._iter == 0: + # Turn on sync collectives for PT2 pipeline. + # To have similar logic between compiled/graph_break ranks. + # TODO(ivankobzarev): Call torchrec.distributed.comm_ops.set_use_sync_collectives(True) when torch package issue on import of comm_ops is fixed + pass + + cc = self._compile_configs + + with record_function("## load_batch ##"): + cur_batch = next(dataloader_iter) + + with record_function("## copy_batch_to_gpu ##"): + self._cur_batch = _to_device(cur_batch, self._device, non_blocking=False) + + # Input transformer here is used also for pt2 hints to compiler, that should happen on exact object passed to model.compile. + # Do not move it before _to_device + if self._input_transformer: + self._cur_batch = self._input_transformer(self._cur_batch) + + if self._model.training: + with record_function("## zero_grad ##"): + self._optimizer.zero_grad() + + with record_function("## forward ##"): + if self._iter == cc.compile_on_iter: + logger.info("Compiling model...") + if self._pre_compile_fn: + self._pre_compile_fn(self._model) + + # Mandatory dynamo configuration for Torchrec PT2 compilation + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt = ( + True + ) + torch._dynamo.config.skip_torchrec = False + + # Importing only before compilation to not slow-done train_pipelines import + torch.ops.import_module("fbgemm_gpu.sparse_ops") + + self._model.compile( + fullgraph=cc.fullgraph, dynamic=cc.dynamic, backend=cc.backend + ) + if self._post_compile_fn: + self._post_compile_fn(self._model) + + (losses, output) = self._model(self._cur_batch) + self._iter += 1 + + if self._model.training: + with record_function("## backward ##"): + torch.sum(losses).backward() + + with record_function("## optimizer ##"): + self._optimizer.step() + + return output + + +class TrainPipelineSparseDist(TrainPipeline[In, Out]): + """ + This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with + forward and backward. 
This helps hide the all2all latency while preserving the + training forward / backward ordering. + + stage 3: forward, backward - uses default CUDA stream + stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream + stage 1: device transfer - uses memcpy CUDA stream + + `ShardedModule.input_dist()` is only done for top-level modules in the call graph. + To be considered a top-level module, a module can only depend on 'getattr' calls on + input. + + Input model must be symbolically traceable with the exception of `ShardedModule` and + `DistributedDataParallel` modules. + + Args: + model (torch.nn.Module): model to pipeline. + optimizer (torch.optim.Optimizer): optimizer to use. + device (torch.device): device where device transfer, sparse data dist, and + forward/backward pass will happen. + execute_all_batches (bool): executes remaining batches in pipeline after + exhausting dataloader iterator. + apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules. + """ + + # The PipelinedForward class that is used in _rewrite_model + _pipelined_forward_type = PipelinedForward + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: torch.device, + execute_all_batches: bool = True, + apply_jit: bool = False, + context_type: Type[TrainPipelineContext] = TrainPipelineContext, + # keep for backward compatibility + pipeline_postproc: bool = False, + custom_model_fwd: Optional[ + Callable[[Optional[In]], Tuple[torch.Tensor, Out]] + ] = None, + ) -> None: + self._model = model + self._optimizer = optimizer + self._device = device + self._execute_all_batches = execute_all_batches + self._apply_jit = apply_jit + + if device.type == "cuda": + # use two data streams to support two concurrent batches + # Dynamo does not support cuda stream specificaiton, + # this freedom is left for compiler pipelining optimizations. 
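            # The side streams created below use priority=-1 (a higher CUDA priority
            # than the default stream) so the host-to-device copy and the input_dist
            # collectives are scheduled ahead of the compute kernels they overlap
            # with; the assert below guards against tracing this stream usage under
            # Dynamo.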
+ assert ( + not is_torchdynamo_compiling() + ), "Train Pipelines rely on cuda streams, which is not supported by Dynamo" + + # pyre-ignore + self._stream_context = ( + torch.get_device_module(self._device).stream + if self._device.type in ["cuda", "mtia"] + else torch.cuda.stream + ) + + self._memcpy_stream: Optional[torch.Stream] = ( + (torch.get_device_module(device).Stream(priority=-1)) + if device.type in ["cuda", "mtia"] + else None + ) + self._data_dist_stream: Optional[torch.Stream] = ( + (torch.get_device_module(device).Stream(priority=-1)) + if device.type in ["cuda", "mtia"] + else None + ) + + # pyre-ignore + self._original_forwards: List[Callable[..., Any]] = [] + + self._original_kjt_dist_forwards: List[ + Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]] + ] = [] + + self._model_attached = True + self._pipeline_postproc = pipeline_postproc + + self._next_index: int = 0 + self.contexts: Deque[TrainPipelineContext] = deque() + self._pipelined_modules: List[ShardedModule] = [] + self._pipelined_postprocs: List[PipelinedPostproc] = [] + self.batches: Deque[Optional[In]] = deque() + self._dataloader_iter: Optional[Iterator[In]] = None + self._dataloader_exhausted: bool = False + self._context_type: Type[TrainPipelineContext] = context_type + + self._model_fwd: Callable[[Optional[In]], Tuple[torch.Tensor, Out]] = ( + custom_model_fwd if custom_model_fwd else model + ) + + # DEPRECATED FIELDS + self._batch_i: Optional[In] = None + self._batch_ip1: Optional[In] = None + self._batch_ip2: Optional[In] = None + self._context: TrainPipelineContext = context_type(version=0) + + def detach(self) -> torch.nn.Module: + """ + Detaches the model from sparse data dist (SDD) pipeline. A user might want to get + the original model back after training. The original model.forward was previously + modified by the train pipeline. for more please see: + https://github.com/pytorch/torchrec/pull/2076 + + To use the pipeline after detaching the model, pipeline.attach(model) + needs to be called. + Inflight batches are kept so pipeline.progress(data_iter) can be resumed normally. + + Returns the original model. + """ + if self._pipelined_modules: + _pipeline_detach_model( + model=self._model, + pipelined_modules=self._pipelined_modules, + original_forwards=self._original_forwards, + original_kjt_dist_forwards=self._original_kjt_dist_forwards, + pipelined_postprocs=self._pipelined_postprocs, + ) + + self._model_attached = False + return self._model + + def attach( + self, model: Optional[torch.nn.Module] = None, sparse_dist: bool = True + ) -> None: + """ + should be used with detach function. these functions should only be used from user code, + when user want to switch the train pipeline. for more please see: + https://github.com/pytorch/torchrec/pull/2076 + """ + if model: + self._model = model + + self._model_attached = True + if self.contexts: + self._pipeline_model( + batch=self.batches[0] if sparse_dist else None, + context=self.contexts[0], + pipelined_forward=self._pipelined_forward_type, + ) + else: + # attaching the model after end of train pipeline + # model rewrite for SDD needs context but self.contexts is empty + # reset _pipelined_modules so _fill_pipeline will rewrite model on progress() + self._pipelined_modules = [] + self._pipelined_postprocs = [] + + def _set_module_context(self, context: TrainPipelineContext) -> None: + """ + pipelined modules are the TorchRec's sparse modules like shardedEBC, shardedEC, etc. 
+ the forward function is swapped with a PipelinedForward in the _rewrite_model call. + The PipelinedForward needs a context to correctly perform the forward behavior. + please check PipelinedForward for details. + """ + for module in self._pipelined_modules: + module.forward.set_context(context) + + for postproc_module in self._pipelined_postprocs: + # This ensures that next iter model fwd uses cached results + postproc_module.set_context(context) + + def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool: + """ + load a data batch from dataloader, and copy it from cpu to gpu + also create the context for this batch. + """ + batch, context = self.copy_batch_to_gpu(dataloader_iter) + if batch is None: + return False + self.batches.append(batch) + # pyre-ignore [6] + self.contexts.append(context) + + return True + + def dequeue_batch(self) -> None: + """ + remove a processed batch from the batch queue, also set the module context if applicable + """ + self.batches.popleft() + self.contexts.popleft() + + # update PipelinedForward context to match next forward pass + if len(self.batches) >= 1: + self._set_module_context(self.contexts[0]) + + def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None: + """ + This function is called in self.progress (one of the main APIs for running train pipeline) + Here we assume the max pipelined len(batches) == 2 (capacity), which will be the most common + scenario during the full training job, when this function is effectively doing nothing. + There would only be two other scenarios: + len(batches) == 0: + initialize the pipeline, fill in two batches, start input_dist for the first batch. + len(batches) == 1: + dataloader_iter stops, the last batch, do nothing + """ + + # pipeline is already filled with max capacity (2) + if len(self.batches) >= 2: + return + + # executes last batch in pipeline, when there is only one batch in the pipeline + # TODO: this _execute_all_batches doesn't really work here D43546239. it will + # just throw an exception at copy_to_gpu when the dataloader is exhausted + if self.batches and self._execute_all_batches: + return + + # batch i, data (batch) and context + if not self.enqueue_batch(dataloader_iter): + return + + # modify the (sharded) sparse module forward, and invoke the first part of input_dist + self._init_pipelined_modules( + # pyre-ignore [6] + self.batches[0], + self.contexts[0], + self._pipelined_forward_type, + ) + # doing the second part of input_dist, the first part is invoked in _init_pipelined_modules + self.wait_sparse_data_dist(self.contexts[0]) + + # batch i+1 + if not self.enqueue_batch(dataloader_iter): + return + + def _wait_for_batch(self) -> None: + with record_function("## wait_for_batch ##"): + _wait_for_batch(cast(In, self.batches[0]), self._data_dist_stream) + + def _backward(self, losses: torch.Tensor) -> None: + with record_function("## backward ##"): + torch.sum(losses, dim=0).backward() + + def progress(self, dataloader_iter: Iterator[In]) -> Out: + """ + For TrainPipelineSparseDist, we assume the max pipelined batches == 3 (capacity): + batches[0]: current batch, for emb_lookup, output_dist, and fwd/bwd/opt (expecting input_dist) + batches[1]: next batch, for input_dist (expecting copied to device) + batches[2]: i+2 batch, for copy_batch_to_gpu (expecting non-exhausted dataloader iter) + """ + + # attach the model just in case the user forgets to call it, especially when the user + # pauses the pipeline.progress and detach the model for other purpose. 
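        # Per-iteration schedule, overlapping comms with compute:
        #   1) zero_grad, then wait for batches[0] (its input_dist was started in
        #      the previous iteration) on the data_dist stream
        #   2) start the splits all_to_all for batches[1]
        #   3) enqueue batch i+2: copy it to the device on the memcpy stream
        #   4) forward of batches[0] on the default stream
        #   5) wait on batches[1]'s tensors all_to_all
        #   6) backward of batches[0], optimizer step, dequeue batches[0]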
+ if not self._model_attached: + self.attach(self._model) + + # fill the pipeline is only needed for the beginning when the pipeline (batches) is empty + self.fill_pipeline(dataloader_iter) + + # here is the expected stop after exhausting all batches + if not self.batches: + raise StopIteration + + # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only) + self._set_module_context(self.contexts[0]) + + if self._model.training: + with record_function("## zero_grad ##"): + self._optimizer.zero_grad() + + # wait for batches[0] being available on device, this should always be completed since + # the input_dist of batches[0] has be invoked in previous iter. TODO: fact check + self._wait_for_batch() + + if len(self.batches) >= 2: + # invoke splits all_to_all comms (first part of input_dist) + self.start_sparse_data_dist(self.batches[1], self.contexts[1]) + + # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here + self.enqueue_batch(dataloader_iter) + + # forward + with record_function("## forward ##"): + losses, output = self._model_fwd(self.batches[0]) + + if len(self.batches) >= 2: + # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist) + self.wait_sparse_data_dist(self.contexts[1]) + + if self._model.training: + # backward + self._backward(losses) + + # update + with record_function("## optimizer ##"): + self._optimizer.step() + + self.dequeue_batch() + return output + + def _create_context(self) -> TrainPipelineContext: + context = self._context_type(index=self._next_index, version=1) + self._next_index += 1 + return context + + def _pipeline_model( + self, + batch: Optional[In], + context: TrainPipelineContext, + pipelined_forward: Type[PipelinedForward] = PipelinedForward, + ) -> None: + ( + self._pipelined_modules, + self._model, + self._original_forwards, + self._pipelined_postprocs, + _, + ) = _rewrite_model( + model=self._model, + context=context, + dist_stream=self._data_dist_stream, + default_stream=torch.get_device_module(self._device).current_stream(), + batch=batch, + apply_jit=self._apply_jit, + pipelined_forward=pipelined_forward, + pipeline_postproc=self._pipeline_postproc, + ) + # initializes input dist, so we can override input dist forwards + self.start_sparse_data_dist(batch, context) + self._original_kjt_dist_forwards = _override_input_dist_forwards( + self._pipelined_modules + ) + + def _init_pipelined_modules( + self, + batch: In, + context: TrainPipelineContext, + pipelined_forward: Type[PipelinedForward] = PipelinedForward, + ) -> None: + """ + Retrieves the pipelined modules after overriding their forwards, initializes the + modules' input dists, and overrides the input dist forwards to support fusing + the splits collective in the input dist. + """ + if self._pipelined_modules: + self._set_module_context(context) + self.start_sparse_data_dist(batch, context) + return + + self._pipeline_model(batch, context, pipelined_forward) + + def copy_batch_to_gpu( + self, + dataloader_iter: Iterator[In], + ) -> Tuple[Optional[In], Optional[TrainPipelineContext]]: + """ + Retrieves batch from dataloader and moves it to the provided device. + + Raises: + StopIteration: if the dataloader iterator is exhausted; unless + `self._execute_all_batches=True`, then returns None. 
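        Returns:
            A tuple of the batch moved to ``self._device`` (or ``None`` when the
            dataloader is exhausted and ``execute_all_batches=True``) and a newly
            created ``TrainPipelineContext``. The copy is issued on
            ``self._memcpy_stream`` so it can overlap with compute on the default
            stream.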
+ """ + context = self._create_context() + with record_function(f"## copy_batch_to_gpu {context.index} ##"): + with self._stream_context(self._memcpy_stream): + batch = self._next_batch(dataloader_iter) + if batch is not None: + batch = _to_device(batch, self._device, non_blocking=True) + elif not self._execute_all_batches: + raise StopIteration + return batch, context + + def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]: + """ + Retrieves next batch from dataloader and prevents calling `next` on an already + exhausted dataloader, which can cause hanging. + """ + if dataloader_iter is not self._dataloader_iter: + self._dataloader_iter = dataloader_iter + self._dataloader_exhausted = False + + if self._dataloader_exhausted: + batch = None + else: + with record_function("## next_batch ##"): + batch = next(dataloader_iter, None) + if batch is None: + self._dataloader_exhausted = True + return batch + + def start_sparse_data_dist( + self, batch: Optional[In], context: TrainPipelineContext + ) -> None: + """ + Waits for batch to finish getting copied to GPU, then starts the input dist. + """ + if batch is None: + return + with record_function(f"## start_sparse_data_dist {context.index} ##"): + with self._stream_context(self._data_dist_stream): + _wait_for_batch(batch, self._memcpy_stream) + + original_contexts = [p.get_context() for p in self._pipelined_postprocs] + + # Temporarily set context for next iter to populate cache + for postproc_mod in self._pipelined_postprocs: + postproc_mod.set_context(context) + + _start_data_dist(self._pipelined_modules, batch, context) + + # Restore context for model fwd + for module, context in zip( + self._pipelined_postprocs, original_contexts + ): + module.set_context(context) + + def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None: + """ + Waits on the input dist splits requests to get the input dist tensors requests, + and populates the context with them. + """ + with record_function(f"## wait_sparse_data_dist {context.index} ##"): + with self._stream_context(self._data_dist_stream): + for names, awaitable in context.fused_splits_awaitables: + for name, request in zip(names, awaitable.wait()): + context.input_dist_tensors_requests[name] = request + context.input_dist_splits_requests.clear() + context.fused_splits_awaitables.clear() + + def _copy_batch_to_gpu(self, dataloader_iter: Iterator[In]) -> Optional[In]: + """ + DEPRECATED: exists for backward compatibility on TrainPipelineContext.version 0 + """ + self._set_module_context(self._context) + batch, _ = self.copy_batch_to_gpu(dataloader_iter) + return batch + + def _start_sparse_data_dist(self, batch: Optional[In]) -> None: + """ + DEPRECATED: exists for backward compatibility + Waits for batch to finish getting copied to GPU, then starts the input dist. + """ + self._set_module_context(self._context) + self.start_sparse_data_dist(batch, self._context) + + def _wait_sparse_data_dist(self) -> None: + """ + DEPRECATED: exists for backward compatibility + Waits on the input dist splits requests to get the input dist tensors requests, + and populates the context with them. 
+ """ + self._set_module_context(self._context) + with record_function("## wait_sparse_data_dist ##"): + with self._stream_context(self._data_dist_stream): + self._context.module_contexts = ( + self._context.module_contexts_next_batch.copy() + ) + self._context.input_dist_tensors_requests.clear() + for names, awaitable in self._context.fused_splits_awaitables: + for name, request in zip(names, awaitable.wait()): + self._context.input_dist_tensors_requests[name] = request + + def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None: + """ + DEPRECATED: exists for backward compatibility + """ + # pipeline is already filled + if self._batch_i and self._batch_ip1: + return + # executes last batch in pipeline + if self._batch_i and self._execute_all_batches: + return + + # batch 1 + self._batch_i = self._copy_batch_to_gpu(dataloader_iter) + if self._batch_i is None: + raise StopIteration + + self._init_pipelined_modules(self._batch_i, self._context) + self._start_sparse_data_dist(self._batch_i) + self._wait_sparse_data_dist() + + # batch 2 + self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter) + + +class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]): + """ + Novel method for RecSys model training by leveraging "Semi-Synchronous" training, + where the model is still synchronous but each batch prediction is calculated + on parameters which were last updated B-2, instead of the batch prior (ie. B-1). This + allows the Embedding All-to-All from B to be fully overlapped with forward pass of B-1; dramatically + improving peak training performance. + + + Args: + model (torch.nn.Module): model to pipeline. + optimizer (torch.optim.Optimizer): optimizer to use. + device (torch.device): device where device transfer, sparse data dist, and + forward/backward pass will happen. + execute_all_batches (bool): executes remaining batches in pipeline after + exhausting dataloader iterator. + apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules. + start_batch (int): batch to begin semi-sync training. Typically small period of synchronous training reduces early stage NEX. + stash_gradients (bool): if True, will store gradients for each parameter to insure true "Semi-Sync" + training. 
If False, will update dense optimizer as soon as gradients available (naive "Semi-Sync) + """ + + # The PipelinedForward class that is used in _rewrite_model + _pipelined_forward_type = EmbeddingPipelinedForward # pyre-ignore + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: torch.device, + execute_all_batches: bool = True, + apply_jit: bool = False, + start_batch: int = 900, + stash_gradients: bool = False, + pipeline_postproc: bool = True, + custom_model_fwd: Optional[ + Callable[[Optional[In]], Tuple[torch.Tensor, Out]] + ] = None, + strict: bool = False, + ) -> None: + super().__init__( + model=model, + optimizer=optimizer, + device=device, + execute_all_batches=execute_all_batches, + apply_jit=apply_jit, + context_type=EmbeddingTrainPipelineContext, + pipeline_postproc=pipeline_postproc, + custom_model_fwd=custom_model_fwd, + ) + self._start_batch = start_batch + self._stash_gradients = stash_gradients + logger.debug(f"Starting semi-sync run at batch: {self._start_batch}") + self._gradients: Dict[str, torch.Tensor] = {} + self._strict = strict + + def _grad_swap(self) -> None: + for name, param in self._model.named_parameters(): + grad = self._gradients.get(name, None) + if param.grad is not None: + self._gradients[name] = param.grad.clone() + param.grad = grad + + def _validate_optimizer(self) -> None: + for pipelined_module in self._pipelined_modules: + pipelined_params = set(pipelined_module.parameters()) + for group in self._optimizer.param_groups: + if not set(group["params"]).isdisjoint(pipelined_params): + error_msg = f"SemiSync pipelined {type(pipelined_module)} and optimizer share parameters. This could lead to convergence issues." + if self._strict: + raise Exception(error_msg) + else: + logger.warning(error_msg) + + def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None: + # pipeline is already filled + if len(self.batches) >= 3: + return + # executes last batch in pipeline + if self.batches and self._execute_all_batches: + return + + # batch i + if not self.enqueue_batch(dataloader_iter): + return + + self._init_pipelined_modules( + # pyre-ignore [6] + self.batches[0], + self.contexts[0], + # pyre-ignore [6] + self._pipelined_forward_type, + ) + self.wait_sparse_data_dist(self.contexts[0]) + self._validate_optimizer() + # pyre-ignore [6] + self.start_embedding_lookup(self.batches[0], self.contexts[0]) + + # batch i+1 + if not self.enqueue_batch(dataloader_iter): + return + self.start_sparse_data_dist(self.batches[1], self.contexts[1]) + self.wait_sparse_data_dist(self.contexts[1]) + + # batch i+2 + if not self.enqueue_batch(dataloader_iter): + return + + def is_semi_sync(self) -> bool: + if len(self.batches) >= 1: + # pyre-ignore [58] + return self.contexts[0].index >= self._start_batch + return False + + def _mlp_optimizer_step(self, current_batch: int) -> None: + # special case: not all optimizers support optim.step() on null gradidents + if current_batch == self._start_batch and self._stash_gradients: + return + self._optimizer.step() + + def progress(self, dataloader_iter: Iterator[In]) -> Out: + # attach the model just in case the user forgets to call it, especially when the user + # pauses the pipeline.progress and detach the model for other purpose. 
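# Toy illustration of the gradient stashing performed by _grad_swap above: with
# stash_gradients=True, the dense optimizer step at iteration i is applied with the
# gradients computed at iteration i-1. The tiny Linear model below is illustrative.
import torch

toy_model = torch.nn.Linear(4, 1)
toy_opt = torch.optim.SGD(toy_model.parameters(), lr=0.1)
stashed: dict = {}

for step in range(3):
    toy_opt.zero_grad()
    loss = toy_model(torch.randn(8, 4)).sum()
    loss.backward()
    # swap: stash the freshly computed grads, restore the ones from the previous step
    for name, param in toy_model.named_parameters():
        previous = stashed.get(name)
        stashed[name] = param.grad.clone()
        param.grad = previous  # None on the very first step
    if step > 0:  # nothing to apply on the first step
        toy_opt.step()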
+ if not self._model_attached: + self.attach(self._model) + + self.fill_pipeline(dataloader_iter) + if not self.batches: + raise StopIteration + + if len(self.batches) >= 3: + self.start_sparse_data_dist( + self.batches[2], + self.contexts[2], + ) + + batch, context = self.batches[0], self.contexts[0] + is_semi_sync = context.index is not None and context.index >= self._start_batch + iteration: int = context.index or 0 + losses, output = self._mlp_forward(cast(In, batch), context) + + # After this point, pipelined postproc/module forward won't be called + # so we can advance their contexts to the context of the next batch already + # and also pop batch and context from self.batches and self.contexts + self.dequeue_batch() + + # batch no longer needed - delete to free up memory + del batch + + # cached postproc fwd results no longer needed - delete to free up memory + del context.postproc_fwd_results + + # batch i+3 + self.enqueue_batch(dataloader_iter) + + if len(self.batches) >= 1 and is_semi_sync: + # pyre-ignore [6] + self.start_embedding_lookup(self.batches[0], self.contexts[0]) + + if len(self.batches) >= 2: + self.wait_sparse_data_dist(self.contexts[1]) + + if self._model.training: + with record_function(f"## backward {iteration} ##"): + torch.sum(losses, dim=0).backward() + with record_function(f"## emb_backward {iteration} ##"): + # pyre-ignore [6] + self.embedding_backward(context) + + del context # context is no longer needed, deleting to free up memory + + with record_function(f"## optimizer {iteration - 1} ##"): + if is_semi_sync and self._stash_gradients: + self._grad_swap() + self._mlp_optimizer_step(iteration) + + with record_function(f"## zero_grad {iteration - 1} ##"): + self._optimizer.zero_grad() + else: + del context + + if len(self.batches) >= 1 and not is_semi_sync: + torch.cuda.synchronize() # needed to avoid race condition + # pyre-ignore [6] + self.start_embedding_lookup(self.batches[0], self.contexts[0]) + + return output + + def _mlp_forward( + self, batch: In, context: TrainPipelineContext + ) -> Tuple[torch.Tensor, Out]: + with record_function(f"## forward {context.index} ##"): + _wait_for_events( + batch, context, torch.get_device_module(self._device).current_stream() + ) + return self._model_fwd(batch) + + def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None: + assert len(context.embedding_features) == len(context.embedding_tensors) + for emb_tensors, embedding_features, detached_emb_tensors in zip( + context.embedding_tensors, + context.embedding_features, + context.detached_embedding_tensors, + ): + grads = [tensor.grad for tensor in detached_emb_tensors] + """ + Some embeddings may never get used in the final loss computation, + so the grads will be `None`. 
If we don't exclude these, it will fail + with error: "grad can be implicitly created only for scalar outputs" + Alternatively, if the tensor has only 1 element, pytorch can still + figure out how to do autograd + """ + embs_to_backprop, grads_to_use, invalid_features = [], [], [] + assert len(embedding_features) == len(emb_tensors) + for features, tensor, grad in zip(embedding_features, emb_tensors, grads): + if tensor.numel() == 1 or grad is not None: + embs_to_backprop.append(tensor) + grads_to_use.append(grad) + else: + if isinstance(features, str): + invalid_features.append(features) + elif isinstance(features, Iterable): + invalid_features.extend(features) + else: + invalid_features.append(features) + if invalid_features and context.index == 0: + logger.warning( + f"SemiSync, the following features have no gradients: {invalid_features}" + ) + torch.autograd.backward(embs_to_backprop, grads_to_use) + + def copy_batch_to_gpu( + self, + dataloader_iter: Iterator[In], + ) -> Tuple[Optional[In], Optional[TrainPipelineContext]]: + context = None + with record_function(f"## copy_batch_to_gpu {self._next_index} ##"): + with self._stream_context(self._memcpy_stream): + batch = self._next_batch(dataloader_iter) + if batch is not None: + batch = _to_device(batch, self._device, non_blocking=True) + context = self._create_context() + event = torch.get_device_module(self._device).Event() + event.record() + context.events.append(event) + return batch, context + + def extract_model_input_from_batch(self, batch: In) -> Pipelineable: + return batch + + def start_sparse_data_dist( + self, + batch: Optional[In], + context: TrainPipelineContext, + ) -> None: + """ + Waits for batch to finish getting copied to GPU, then starts the input dist. This is Event based version. + """ + if batch is None: + return + + # Temporarily set context for next iter to populate cache + original_contexts = [p.get_context() for p in self._pipelined_postprocs] + for postproc_mod in self._pipelined_postprocs: + postproc_mod.set_context(context) + + with record_function(f"## start_sparse_data_dist {context.index} ##"): + with self._stream_context(self._data_dist_stream): + _wait_for_events(batch, context, self._data_dist_stream) + model_input = self.extract_model_input_from_batch(batch) + _start_data_dist(self._pipelined_modules, model_input, context) + event = torch.get_device_module(self._device).Event() + event.record() + context.events.append(event) + + # Restore context for model forward + for module, context in zip(self._pipelined_postprocs, original_contexts): + module.set_context(context) + + def start_embedding_lookup( + self, + batch: Optional[In], + context: EmbeddingTrainPipelineContext, + ) -> None: + """ + Waits for batch to finish getting copied to GPU, then starts the input dist. This Event based vesrion. 
+ """ + if batch is None: + return + + with record_function(f"## start_embedding_lookup {context.index} ##"): + current_stream = torch.get_device_module(self._device).current_stream() + _wait_for_events(batch, context, current_stream) + for i, module in enumerate(self._pipelined_modules): + _start_embedding_lookup( + module, + context, + source_stream=self._data_dist_stream, + target_stream=current_stream, + stream_context=self._stream_context, + ) + event = torch.get_device_module(self._device).Event() + event.record() + context.events.append(event) + + +class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]): + """ + This pipeline overlaps device transfer, `ShardedModule.input_dist()`, and cache + prefetching with forward and backward. This helps hide the all2all latency while + preserving the training forward / backward ordering. + + stage 4: forward, backward - uses default CUDA stream + stage 3: prefetch - uses prefetch CUDA stream + stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream + stage 1: device transfer - uses memcpy CUDA stream + + `ShardedModule.input_dist()` is only done for top-level modules in the call graph. + To be considered a top-level module, a module can only depend on 'getattr' calls on + input. + + Input model must be symbolically traceable with the exception of `ShardedModule` and + `DistributedDataParallel` modules. + + Args: + model (torch.nn.Module): model to pipeline. + optimizer (torch.optim.Optimizer): optimizer to use. + device (torch.device): device where device transfer, sparse data dist, prefetch, + and forward/backward pass will happen. + execute_all_batches (bool): executes remaining batches in pipeline after + exhausting dataloader iterator. + apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules. 
+ """ + + # The PipelinedForward class that is used in _rewrite_model + _pipelined_forward_type = PrefetchPipelinedForward # pyre-ignore + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: torch.device, + execute_all_batches: bool = True, + apply_jit: bool = False, + pipeline_postproc: bool = True, + custom_model_fwd: Optional[ + Callable[[Optional[In]], Tuple[torch.Tensor, Out]] + ] = None, + ) -> None: + super().__init__( + model=model, + optimizer=optimizer, + device=device, + execute_all_batches=execute_all_batches, + apply_jit=apply_jit, + context_type=PrefetchTrainPipelineContext, + pipeline_postproc=pipeline_postproc, + custom_model_fwd=custom_model_fwd, + ) + self._context = PrefetchTrainPipelineContext(version=0) + self._prefetch_stream: Optional[torch.Stream] = ( + (torch.get_device_module(device).Stream()) + if self._device.type in ["cuda", "mtia"] + else None + ) + self._default_stream: Optional[torch.Stream] = ( + (torch.get_device_module(self._device).Stream()) + if self._device.type in ["cuda", "mtia"] + else None + ) + self._batch_ip3: Optional[In] = None + + def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None: + # pipeline is already filled + if self._batch_i and self._batch_ip1 and self._batch_ip2: + return + # executes last batch in pipeline + if self._execute_all_batches and (self._batch_i or self._batch_ip1): + return + + # batch 1 + self._batch_i = self._copy_batch_to_gpu(dataloader_iter) + if self._batch_i is None: + raise StopIteration + + self._init_pipelined_modules( + self._batch_i, + self._context, + # pyre-ignore + self._pipelined_forward_type, + ) + self._start_sparse_data_dist(self._batch_i) + self._wait_sparse_data_dist() + self._prefetch(self._batch_i) + + # batch 2 + self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter) + self._start_sparse_data_dist(self._batch_ip1) + + def progress(self, dataloader_iter: Iterator[In]) -> Out: + self._fill_pipeline(dataloader_iter) + + if self._model.training: + with record_function("## zero_grad ##"): + self._optimizer.zero_grad() + + with record_function("## wait_for_batch ##"): + _wait_for_batch(cast(In, self._batch_i), self._prefetch_stream) + + self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter) + + self._wait_sparse_data_dist() + # forward + with record_function("## forward ##"): + losses, output = self._model_fwd(self._batch_i) + + self._prefetch(self._batch_ip1) + + if self._model.training: + # backward + with record_function("## backward ##"): + torch.sum(losses, dim=0).backward() + + # update + with record_function("## optimizer ##"): + self._optimizer.step() + + self._start_sparse_data_dist(self._batch_ip2) + + self._batch_i = self._batch_ip1 + self._batch_ip1 = self._batch_ip2 + + return output + + def _prefetch(self, batch: Optional[In]) -> None: + """ + Waits for input dist to finish, then prefetches data. 
+ """ + if batch is None: + return + self._context.module_input_post_prefetch.clear() + self._context.module_contexts_post_prefetch.clear() + + with record_function("## sharded_module_prefetch ##"): + with self._stream_context(self._prefetch_stream): + batch.record_stream( + torch.get_device_module(self._device).current_stream() + ) + data_per_pipelined_module = _prefetch_embeddings( + batch, + self._context, + self._pipelined_modules, + self._device, + self._stream_context, + self._data_dist_stream, + self._default_stream, + ) + for sharded_module in self._pipelined_modules: + forward = sharded_module.forward + data = data_per_pipelined_module[forward._name] + self._context.module_input_post_prefetch[forward._name] = data + self._context.module_contexts_post_prefetch[forward._name] = ( + self._context.module_contexts.pop(forward._name) + ) + + +class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]): + """ + This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with + forward. This helps hide the all2all latency. We use a background thread to + perform device transfer to further reduce latency. + + stage 2: forward- uses default CUDA stream + stage 1: ShardedModule.input_dist() - uses data_dist CUDA stream + background: device transfer - uses memcpy CUDA stream + + `ShardedModule.input_dist()` is only done for top-level modules in the call graph. + To be considered a top-level module, a module can only depend on 'getattr' calls on + input. + + Input model must be symbolically traceable with the exception of `ShardedModule` and + `DistributedDataParallel` modules. + + Args: + model (torch.nn.Module): model to pipeline. + optimizer (torch.optim.Optimizer): optimizer to use. + device (torch.device): device where device transfer, sparse data dist, and + forward/backward pass will happen. + apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules. 
+ """ + + # The PipelinedForward class that is used in _rewrite_model + _pipelined_forward_type = PipelinedForward + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: torch.device, + apply_jit: bool = False, + ) -> None: + super().__init__(model, optimizer, device, True, apply_jit) + self._batch_loader: Optional[DataLoadingThread[In]] = None + + def __del__(self) -> None: + if self._batch_loader is not None: + self._batch_loader.stop() + + def progress(self, dataloader_iter: Iterator[In]) -> Out: + if not self._batch_loader: + self._batch_loader = DataLoadingThread( + device=self._device, + dataloader_iter=dataloader_iter, + to_device_non_blocking=True, + memcpy_stream_priority=-1, + memcpy_stream=self._memcpy_stream, + ) + self._batch_loader.start() + + # batch 0 + # pyre-ignore [16] + batch = self._batch_loader.get_next_batch() + if batch is None: + raise StopIteration + self.batches.append(batch) + self.contexts.append(self._create_context()) + + self._init_pipelined_modules( + # pyre-ignore + self.batches[0], + self.contexts[0], + self._pipelined_forward_type, + ) + self.start_sparse_data_dist(self.batches[0], self.contexts[0]) + self.wait_sparse_data_dist(self.contexts[0]) + + batch = self._batch_loader.get_next_batch() + if batch is not None: + self.batches.append(batch) + self.contexts.append(self._create_context()) + + if len(self.batches) == 0: + raise StopIteration + + with record_function("## wait_for_batch ##"): + _wait_for_batch(cast(In, self.batches[0]), self._data_dist_stream) + + if len(self.batches) >= 2: + self.start_sparse_data_dist(self.batches[1], self.contexts[1]) + + # forward + with record_function("## forward ##"): + losses, output = cast( + Tuple[torch.Tensor, Out], self._model(self.batches[0]) + ) + + if len(self.batches) >= 2: + self.wait_sparse_data_dist(self.contexts[1]) + self.dequeue_batch() + + return output + + +class StagedTrainPipeline(TrainPipeline[In, Optional[StageOut]]): + """ + StagedTrainPipeline orchestrates the pipelined execution of its constituent stages + from inputs of `dataloader_iter`. Namely scheduling the execution of stages before + model forward. + + NOTE: the SDD stage needs to be the final stage of the pipeline so that the + `ShardedModule` forward can properly consume the SDD output. + + Calling progress on a `StagedTrainPipeline` provides an output that is equivalent to + calling each of the pipeline stages in order. + + In the example below a fully synchronous will expose the `data_copy` and + `gpu_postproc` calls. After pipelining, the `data_copy` of batch i+2 can be + overlapped with the `gpu_postproc` of batch i+1 and the main model processing of + batch i. + + Args: + pipeline_stages (List[PipelineStage]): A list of stages to execute. + debug_mode (bool): Whether to enable debug mode. + compute_stream (Optional[torch.cuda.Stream]): The main compute stream in which + model forward is run, usually torch.cuda.default_stream(). Defaults to the + current cuda stream. + on_flush_end (Optional): Callback function that gets invoked after the pipeline + has been flushed. 
+ + Example:: + train_pipeline = StagedTrainPipeline( + pipeline=[ + PipelineStage( + name="data_copy", + runnable=get_h2d_func("cuda"), + stream=torch.cuda.Stream(), + ), + PipelineStage( + name="gpu_postproc", + runnable=gpu_postproc, + stream=torch.cuda.Stream(), + ), + ] + ) + + while batch_for_forward := train_pipeline.progress(dataloader_iter): + optimizer.zero_grad() + loss, pred = model(batch_for_forward) + loss.backward() + optimizer.step() + """ + + def __init__( + self, + pipeline_stages: List[PipelineStage], + debug_mode: bool = False, + compute_stream: Optional[Union[torch.cuda.Stream, torch.mtia.Stream]] = None, + on_flush_end: Optional[Callable[[], None]] = None, + ) -> None: + self._pipeline_stages = pipeline_stages + self._debug_mode = debug_mode + self._stage_outputs: List[Optional[StageOutputWithEvent]] = cast( + List[Optional[StageOutputWithEvent]], [None] * len(self._pipeline_stages) + ) + self._initialized = False + self._num_steps = 0 + self._dataloader_iter: Optional[Iterator[In]] = None + self._dataloader_exhausted: bool = False + self._compute_stream: torch.Stream = ( + compute_stream + or torch.get_device_module( + self._pipeline_stages[0].stream.device + ).current_stream() + ) + + # pyre-ignore + self._stream_context = ( + torch.get_device_module(self._compute_stream.device).stream + if self._compute_stream.device.type in ["cuda", "mtia"] + else torch.cuda.stream + ) + + self._flushing: bool = False + self.on_flush_end = on_flush_end + + @property + def num_stages(self) -> int: + return len(self._pipeline_stages) + + def _advance(self) -> Optional[StageOutputWithEvent]: + # left shifts all batch results. + out = self._stage_outputs[0] + for idx in range(self.num_stages - 1): + self._stage_outputs[idx] = self._stage_outputs[idx + 1] + self._stage_outputs[-1] = None + return out + + def _run_with_event( + self, + runnable: RunnableType, + event: Optional[torch.Event], + inputs: Optional[In], + stream: torch.Stream, + ) -> StageOutputWithEvent: + if inputs is None: + return (None, None) + with self._stream_context(stream): + # If there is no previous event, data is entering the pipeline + if event is not None: + event.wait(stream) + inputs.record_stream(stream) + + output = runnable(inputs) + new_event = torch.get_device_module(stream.device).Event() + new_event.record(stream) + return (output, new_event) + + def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]: + """ + Retrieves next batch from dataloader and prevents calling `next` on an already + exhausted dataloader, which can cause hanging. + """ + if dataloader_iter is not self._dataloader_iter: + self._dataloader_iter = dataloader_iter + self._dataloader_exhausted = False + + if self._dataloader_exhausted or self._flushing: + batch = None + else: + with record_function("## next_batch ##"): + batch = next(dataloader_iter, None) + if batch is None: + self._dataloader_exhausted = True + return batch + + def _run_stage( + self, + batch_offset: int, + stage_idx: int, + dataloader_iter: Iterator[In], + fill: bool = False, + ) -> StageOutputWithEvent: + """ + Each stage of the pipeline MUST have an input and output. If the input is None, + it means there is no more data to process. The stage will short circuit and NOT + execute the runnable. 
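# A small sketch of the event chaining used by _run_with_event above: a stage records
# an event on its stream after producing its output, and the next stage waits on that
# event before consuming it. Requires a CUDA device; stream and tensor names are
# illustrative.
import torch

stage_a_stream = torch.cuda.Stream()
stage_b_stream = torch.cuda.Stream()

with torch.cuda.stream(stage_a_stream):
    stage_a_out = torch.randn(1024, device="cuda")  # "stage a" work
    stage_a_done = torch.cuda.Event()
    stage_a_done.record(stage_a_stream)

with torch.cuda.stream(stage_b_stream):
    stage_a_done.wait(stage_b_stream)          # do not start until stage a has finished
    stage_a_out.record_stream(stage_b_stream)  # tensor now has a second consumer stream
    stage_b_out = stage_a_out * 2              # "stage b" work

torch.cuda.synchronize()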
+ """ + stage = self._pipeline_stages[stage_idx] + + with record_function( + f"## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##" + ): + if stage_idx == 0: + batch_to_wait = self._next_batch(dataloader_iter) + event = None + else: + batch_to_wait_with_event = self._stage_outputs[batch_offset] + assert batch_to_wait_with_event is not None + batch_to_wait, event = batch_to_wait_with_event + + new_result = self._run_with_event( + runnable=stage.runnable, + event=event, + inputs=batch_to_wait, + stream=stage.stream, + ) + + self._stage_outputs[batch_offset] = new_result + if self._debug_mode: + logger.info( + f"Running ## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##", + ) + + if fill and (fill_callback := stage.fill_callback) is not None: + if self._debug_mode: + logger.info(f"Finished callback for {stage.name}") + fill_callback() + + return new_result + + def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None: + """ + There should always be `self.num_stages` batches in flight. This function + initializes the pipeline by filling it with `self.num_stages` batches. + Intuitively, it does all the stages before the model forward. + + NOTE: + model forward should be executed outside the pipeline in the train loop, + using the output of `progress` as its input. + + For a 3 stage pipeline during `_fill_pipeline`: + batch 0: stages 0, 1, 2 will be run + batch 1: stages 0, 1 will be run + batch 2: stage 0 will be run + batch 3: will start in `progress()` + + In the initial `progress()` + batch 0: model forward will be run + batch 1: stage 2 will be run + batch 2: stage 1 will be run + batch 3: stage 0 will be run + """ + for batch_offset in range(self.num_stages): + stages_to_run = self.num_stages - batch_offset + for stage_idx in range(stages_to_run): + self._run_stage( + batch_offset=batch_offset, + stage_idx=stage_idx, + dataloader_iter=dataloader_iter, + fill=True, + ) + + self._initialized = True + if self._debug_mode: + logger.info("Finished fill pipeline") + + def set_flush(self, flush: bool) -> None: + """ + Sets whether the pipeline should be flushed. + + When the pipeline is in a flushing state, it will stop getting further data from the dataloader and will continue executing the pipeline until all remaining stages are finished. Afterwards, it will invoke a callback and resume the pipeline. + """ + self._flushing = flush + + def flush_end(self) -> None: + self.set_flush(False) + # Ensure pipeline gets filled again + self._initialized = False + + if self.on_flush_end is not None: + self.on_flush_end() + + def progress( + self, + dataloader_iter: Iterator[In], + ) -> Optional[StageOut]: + """ + The pipeline processes data in reverse order, so stage_0 processes the + newest data and stage_n processes the oldest. + + NOTE: + if SDD is enabled it must be the last stage in the pipeline. + + Args: + data_iter (Iterator[In]): An iterator that produces the inputs to + the pipeline. + + Returns: + Optional[StageOut]: Output of the final stage. `None` signifies that the + dataloader iterator is depleted. 
+ """ + if not self._initialized: + self._fill_pipeline(dataloader_iter) + + output_with_event = self._advance() + + if output_with_event is None: + # All data consumed, exit early + return None + + self._num_steps += 1 + + for stage_idx in range(self.num_stages): + stage_output_idx = self.num_stages - 1 - stage_idx + self._run_stage( + batch_offset=stage_output_idx, + stage_idx=stage_idx, + dataloader_iter=dataloader_iter, + ) + + out, event = output_with_event + if event is not None: + # Since model forward() is expected to run outside the pipeline, + # we need to explicitly wait for the last stage to finish + event.wait(self._compute_stream) + out.record_stream(self._compute_stream) + + if out is None and self._flushing: + # We have exhausted all stages due to flushing + self.flush_end() + return self.progress(dataloader_iter) + + return out + + +class TrainPipelineSparseDistCompAutograd(TrainPipelineSparseDist[In, Out]): + """ + This pipeline clone the TrainPipelineSparseDist, but execute the progress + method within compiled autograd context. + """ + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: torch.device, + execute_all_batches: bool = True, + apply_jit: bool = False, + context_type: Type[TrainPipelineContext] = TrainPipelineContext, + pipeline_postproc: bool = False, + custom_model_fwd: Optional[ + Callable[[Optional[In]], Tuple[torch.Tensor, Out]] + ] = None, + ) -> None: + super().__init__( + model, + optimizer, + device, + execute_all_batches, + apply_jit, + context_type, + pipeline_postproc, + custom_model_fwd, + ) + + torch._logging.set_logs(compiled_autograd_verbose=True) + + # it will check this path on model to inject configuration other than + # the default one. + # pyre-fixme[8]: Attribute has type `Dict[str, Union[bool, str]]`; used as + # `Union[Tensor, Module]`. + self.compiled_autograd_options: Dict[str, Union[str, bool]] = getattr( + model, + "_compiled_autograd_options", + { + "backend": "inductor", + "dynamic": True, + "fullgraph": True, + }, + ) + torch._dynamo.config.inline_inbuilt_nn_modules = True + torch._dynamo.config.skip_fsdp_hooks = False + torch._functorch.config.recompute_views = True + torch._functorch.config.cse = False + torch._inductor.config.reorder_for_compute_comm_overlap = True + torch._inductor.config.reorder_for_compute_comm_overlap_passes = [ + "sink_waits", + "raise_comms", + "reorder_compute_for_overlap", + ] + self.initialized = False + + def get_compiled_autograd_ctx( + self, + # pyre-fixme[24]: Generic type `ContextManager` expects 1 type parameter. + ) -> ContextManager: + # this allows for pipelining + # to avoid doing a sum on None + # when the pipeline is empty + if not self.initialized: + self.initialized = True + return contextlib.nullcontext() + + return torch._dynamo.compiled_autograd._enable( + # pyre-ignore + torch.compile(**self.compiled_autograd_options) + ) + + def progress(self, dataloader_iter: Iterator[In]) -> Out: + # attach the model just in case the user forgets to call it, especially when the user + # pauses the pipeline.progress and detach the model for other purpose. 
+ if not self._model_attached: + self.attach(self._model) + + self.fill_pipeline(dataloader_iter) + if not self.batches: + raise StopIteration + + # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only) + self._set_module_context(self.contexts[0]) + + if self._model.training: + with record_function("## zero_grad ##"): + self._optimizer.zero_grad() + + with record_function("## wait_for_batch ##"): + _wait_for_batch(cast(In, self.batches[0]), self._data_dist_stream) + + if len(self.batches) >= 2: + self.start_sparse_data_dist(self.batches[1], self.contexts[1]) + + # batch i+2 + self.enqueue_batch(dataloader_iter) + + # forward + ctx = self.get_compiled_autograd_ctx() + with ctx, torchrec_use_sync_collectives(), record_function("## forward ##"): + losses, output = self._model_fwd(self.batches[0]) + + if len(self.batches) >= 2: + self.wait_sparse_data_dist(self.contexts[1]) + + if self._model.training: + # backward + ctx = self.get_compiled_autograd_ctx() + with ctx, torchrec_use_sync_collectives(), record_function( + "## backward ##" + ): + torch.sum(losses, dim=0).backward() + + # update + with record_function("## optimizer ##"): + self._optimizer.step() + + self.dequeue_batch() + return output diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py new file mode 100644 index 000000000..0f8f5b937 --- /dev/null +++ b/torchrec/distributed/train_pipeline/utils.py @@ -0,0 +1,1962 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import copy +import itertools +import logging +from collections import defaultdict, OrderedDict +from contextlib import AbstractContextManager +from dataclasses import dataclass, field + +from itertools import chain +from threading import Event, Thread +from typing import ( + Any, + Callable, + cast, + Dict, + Generic, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +import torch +from torch import distributed as dist +from torchrec.distributed.types import LazyAwaitable + +if not torch._running_with_deploy(): + from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2 +else: + + class FSDP2: + pass + + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.fx.immutable_collections import ( + immutable_dict as fx_immutable_dict, + immutable_list as fx_immutable_list, +) +from torch.fx.node import Node +from torch.nn.modules.module import _IncompatibleKeys +from torch.profiler import record_function +from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable +from torchrec.distributed.embedding_sharding import ( + FusedKJTListSplitsAwaitable, + KJTListSplitsAwaitable, + KJTSplitsAllToAllMeta, +) +from torchrec.distributed.embedding_types import KJTList +from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule + +from torchrec.distributed.types import Awaitable, LazyNoWait + +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor +from torchrec.streamable import Multistreamable, Pipelineable + +logger: logging.Logger = logging.getLogger(__name__) + +import torch + +In = TypeVar("In", bound=Pipelineable) +StageOut = TypeVar("StageOut", bound=Pipelineable) +Out = TypeVar("Out") + +RunnableType = Callable[..., StageOut] 
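# Hedged sketch of what a custom batch type for these pipelines can look like: `In`
# above is bound to Pipelineable, which (as used in this file) asks a batch to move
# itself across devices via to() and to register itself with additional streams via
# record_stream(). Field names are illustrative, not part of the TorchRec API; the
# torch and Pipelineable imports already present in this module are assumed.
import dataclasses

@dataclasses.dataclass
class ExampleBatch(Pipelineable):
    dense: torch.Tensor
    labels: torch.Tensor

    def to(self, device: torch.device, non_blocking: bool = False) -> "ExampleBatch":
        return ExampleBatch(
            dense=self.dense.to(device=device, non_blocking=non_blocking),
            labels=self.labels.to(device=device, non_blocking=non_blocking),
        )

    def record_stream(self, stream: torch.Stream) -> None:
        self.dense.record_stream(stream)
        self.labels.record_stream(stream)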
+StageOutputWithEvent = Tuple[Optional[StageOut], Optional[torch.Event]] + + +@dataclass +class TrainPipelineContext: + """ + Context information for a `TrainPipelineSparseDist` instance. + + Attributes: + input_dist_splits_requests (Dict[str, Awaitable[Any]]): Stores input dist + requests in the splits awaitable stage, which occurs after starting the + input dist. + input_dist_tensors_requests (Dict[str, Awaitable[Any]]): Stores input dist + requests in the tensors awaitable stage, which occurs after calling `wait()` + on the splits awaitable. + module_contexts (Dict[str, Multistreamable]): Stores module contexts from the + input dist for the current batch. + module_contexts_next_batch (Dict[str, Multistreamable]): Stores module contexts + from the input dist for the next batch. (only for version 0) + fused_splits_awaitables (List[Tuple[List[str], FusedKJTListSplitsAwaitable]]): + List of fused splits input dist awaitable and the corresponding module names + of each awaitable. + event: Optional[torch.cuda.Event]: Event to record the completion of this stage + index: Optional[int]: Index of the current batch. + version: int = 0; support for backward compatiblity + """ + + # pyre-ignore [4] + input_dist_splits_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict) + # pyre-ignore [4] + input_dist_tensors_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict) + module_contexts: Dict[str, Multistreamable] = field(default_factory=dict) + module_contexts_next_batch: Dict[str, Multistreamable] = field( + default_factory=dict + ) # deprecated: to support legacy code + fused_splits_awaitables: List[Tuple[List[str], FusedKJTListSplitsAwaitable]] = ( + field(default_factory=list) + ) + events: List[torch.Event] = field(default_factory=list) + postproc_fwd_results: Dict[str, Any] = field(default_factory=dict) + index: Optional[int] = None + version: int = ( + 0 # 1 is current version, 0 is deprecated but supported for backward compatibility + ) + + +@dataclass +class PrefetchTrainPipelineContext(TrainPipelineContext): + module_input_post_prefetch: Dict[str, Multistreamable] = field(default_factory=dict) + module_contexts_post_prefetch: Dict[str, Multistreamable] = field( + default_factory=dict + ) + module_input_post_prefetch_next_batch: Dict[str, Multistreamable] = field( + default_factory=dict + ) + module_contexts_post_prefetch_next_batch: Dict[str, Multistreamable] = field( + default_factory=dict + ) + + +@dataclass +class EmbeddingTrainPipelineContext(TrainPipelineContext): + embedding_a2a_requests: Dict[ + str, + Union[ + LazyAwaitable[Multistreamable], + # ManagedCollisionEC/EBC returns tuple of awaitables + Tuple[ + LazyAwaitable[KeyedTensor], LazyAwaitable[Optional[KeyedJaggedTensor]] + ], + ], + ] = field(default_factory=dict) + embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list) + embedding_features: List[List[Union[str, List[str]]]] = field(default_factory=list) + detached_embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list) + + +@dataclass +class PipelineStage: + """ + A pipeline stage represents a transform to an input that is independent of the + backwards() of the model. Examples include batch H2D transfer, GPU postproc, or + gradient-less model processing. + + Args: + name (str): Name of the stage. + runnable (Callable[In, Out]): Function that performs a gradient-less + transform. + stream (torch.cuda.streams.Stream): Stream to run on. 
Often each stage has a + unique stream, but having different pipelines share a stream provides more + synchronization semantics. + """ + + name: str + runnable: RunnableType + stream: torch.Stream + fill_callback: Optional[Callable[[], None]] = None + + +@dataclass +class ArgInfo: + """ + Representation of args from a node. + + Attributes: + input_attrs (List[str]): attributes of input batch, + e.g. `batch.attr1.attr2` will produce ["attr1", "attr2"]. + is_getitems (List[bool]): `batch[attr1].attr2` will produce [True, False]. + postproc_modules (List[Optional[PipelinedPostproc]]): list of torch.nn.Modules that + transform the input batch. + constants: constant arguments that are passed to postproc modules. + name (Optional[str]): name for kwarg of pipelined forward() call or None for a + positional arg. + """ + + input_attrs: List[str] + is_getitems: List[bool] + # recursive dataclass as postproc_modules.args -> arginfo.postproc_modules -> so on + postproc_modules: List[Optional["PipelinedPostproc"]] + constants: List[Optional[object]] + name: Optional[str] + + +# pyre-ignore +def _build_args_kwargs( + # pyre-ignore + initial_input: Any, + fwd_args: List[ArgInfo], +) -> Tuple[List[Any], Dict[str, Any]]: + args = [] + kwargs = {} + for arg_info in fwd_args: + if arg_info.input_attrs: + arg = initial_input + for attr, is_getitem, postproc_mod, obj in zip( + arg_info.input_attrs, + arg_info.is_getitems, + arg_info.postproc_modules, + arg_info.constants, + ): + if obj is not None: + if isinstance(obj, list): + arg = [ + ( + v + if not isinstance(v, ArgInfo) + else _build_args_kwargs(initial_input, [v])[0][0] + ) + for v in obj + ] + elif isinstance(obj, dict): + arg = { + k: ( + v + if not isinstance(v, ArgInfo) + else _build_args_kwargs(initial_input, [v])[0][0] + ) + for k, v in obj.items() + } + else: + arg = obj + break + elif postproc_mod is not None: + # postproc will internally run the same logic recursively + # if its args are derived from other postproc modules + # we can get all inputs to postproc mod based on its recorded args_info + arg passed to it + arg = postproc_mod(arg) + else: + if is_getitem: + arg = arg[attr] + elif attr != "": + arg = getattr(arg, attr) + else: + # neither is_getitem nor valid attr, no-op + arg = arg + if arg_info.name: + kwargs[arg_info.name] = arg + else: + args.append(arg) + else: + if arg_info.name: + kwargs[arg_info.name] = None + else: + args.append(None) + return args, kwargs + + +def recursive_record_stream( + # pyre-fixme[2]: Parameter `re` must have a type that does not contain `Any` + res: Union[torch.Tensor, Pipelineable, Iterable[Any], Dict[Any, Any]], + stream: torch.Stream, +) -> None: + if isinstance(res, torch.Tensor) and res.device.type in ["cuda", "mtia"]: + res.record_stream(stream) + elif isinstance(res, Pipelineable): + res.record_stream(stream) + elif isinstance(res, (list, tuple)): + for v in res: + recursive_record_stream(v, stream) + elif isinstance(res, dict): + for v in res.values(): + recursive_record_stream(v, stream) + + +class NoOpStream: + """No-Op Context manager that takes in a stream""" + + def __init__(self, stream: Optional[torch.Stream]) -> None: + self._stream = stream + + def __enter__(self) -> "NoOpStream": + """Return `self` upon entering the runtime context.""" + return self + + # pyre-ignore + def __exit__(self, exc_type, exc_value, traceback) -> None: + return None + + +class PipelinedPostproc(torch.nn.Module): + """ + Wrapper around postproc module found during model graph traversal for sparse data dist + 
pipelining. In addition to the original module, it encapsulates information needed for + execution such as list of ArgInfo and the current training pipeline context. + + Args: + postproc_module (torch.nn.Module): postproc module to run + fqn (str): fqn of the postproc module in the model being pipelined + args (List[ArgInfo]): list of ArgInfo for the postproc module + context (TrainPipelineContext): Training context for the next iteration / batch + + Returns: + Any + + Example: + postproc = PipelinedPostproc(postproc_module, fqn, args, context) + # module-swap with pipeliend postproc + setattr(model, fqn, postproc) + """ + + _FORCE_STATE_DICT_LOAD = True + + def __init__( + self, + postproc_module: torch.nn.Module, + fqn: str, + args: List[ArgInfo], + context: TrainPipelineContext, + # TODO: make streams non-optional - skipping now to avoid ripple effect + default_stream: Optional[torch.Stream], + dist_stream: Optional[torch.Stream], + ) -> None: + super().__init__() + self._postproc_module = postproc_module + self._fqn = fqn + self._args = args + self._context = context + self._default_stream = default_stream + self._dist_stream = dist_stream + if not default_stream: + logger.warning( + f"Postproc module {fqn} has no default stream. This may cause race conditions and NaNs during training!" + ) + if not dist_stream: + logger.warning( + f"Postproc module {fqn} has no dist stream. This may cause race conditions and NaNs during training!" + ) + + if self._dist_stream: + device: torch.device = self._dist_stream.device + # pyre-ignore + self._stream_context = ( + torch.get_device_module(device).stream + if device.type in ["cuda", "mtia"] + else torch.cuda.stream + ) + else: + self._stream_context = NoOpStream + + @property + def postproc_module(self) -> torch.nn.Module: + return self._postproc_module + + @property + def fqn(self) -> str: + return self._fqn + + # pyre-ignore + def forward(self, *input, **kwargs) -> Any: + """ + Args: + Any args and kwargs during model fwd + During _start_data_dist, input[0] contains the current data + Returns: + Any + """ + if self._fqn in self._context.postproc_fwd_results: + # This should only be hit in two cases: + # 1) During model forward + # During model forward, avoid duplicate work + # by returning the cached result from previous + # iteration's _start_data_dist + # 2) During _start_data_dist when postproc module is + # shared by more than one args. e.g. 
if we have + # postproc_out_a = postproc_a(input) + # postproc_out_b = postproc_b(postproc_out_a) <- postproc_a shared + # postproc_out_c = postproc_c(postproc_out_a) <-^ + # When processing postproc_b, we cache value of postproc_a(input) + # so when processing postproc_c, we can reuse postproc_a(input) + res = self._context.postproc_fwd_results[self._fqn] + return res + + # Everything below should only be called during _start_data_dist stage + + # Build up arg and kwargs from recursive call to pass to postproc module + # Arguments to postproc module can be also be a derived product + # of another postproc module call, as long as module is pipelineable + + # Use input[0] as _start_data_dist only passes 1 arg + args, kwargs = _build_args_kwargs(input[0], self._args) + + with record_function(f"## sdd_input_postproc {self._context.index} ##"): + # should be no-op as we call this in dist stream + with self._stream_context(self._dist_stream): + res = self._postproc_module(*args, **kwargs) + + # Ensure postproc modules output is safe to use from default stream later + if self._default_stream and self._dist_stream: + self._default_stream.wait_stream(self._dist_stream) + + if isinstance(res, (torch.Tensor, Pipelineable, Iterable, Dict)): + # Result from module forward might be a complex type such as + # Tuple[KeyedJaggedTensor, Dict[str, torch.Tensor]] + # In this case, we need to first iterate over each element of tuple + # and call record_stream on first item as KJT is Pipelineable + # for the second item (Dict), we iterate over the values and call + # record_stream accordingly. + + # pyre-ignore[6] + recursive_record_stream(res, self._default_stream) + elif self._context.index == 0: + logger.warning( + f"Result of postproc module {self._fqn} is of type {type(res)}. We currently expect it to be a Tensor, Pipelineable, Iterable, or Dict to handle memory safety. If your output is not of this type, please add support for it above. Otherwise you might run into NaNs or CUDA Illegal Memory issues during training!" 
+ ) + + with self._stream_context(self._default_stream): + # Cache results, only during _start_data_dist + self._context.postproc_fwd_results[self._fqn] = res + + return res + + @property + def args(self) -> List[ArgInfo]: + return self._args + + def set_context(self, context: TrainPipelineContext) -> None: + self._context = context + + def get_context(self) -> TrainPipelineContext: + return self._context + + def named_modules( + self, + memo: Optional[Set[torch.nn.Module]] = None, + prefix: str = "", + remove_duplicate: bool = True, + ) -> Iterator[Tuple[str, torch.nn.Module]]: + if memo is None: + memo = set() + if self not in memo: + if remove_duplicate: + memo.add(self) + # This is needed because otherwise the rewrite won't find the existing postproc, and will create a new one + # Also, `named_modules` need to include self - see base implementation in the nn.modules.Module + yield prefix, self + # Difference from base implementation is here - the child name (_postproc_module) is not added to the prefix + yield from self._postproc_module.named_modules( + memo, prefix, remove_duplicate + ) + + def named_parameters( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, torch.nn.Parameter]]: + yield from self._postproc_module.named_parameters( + prefix, + recurse, + remove_duplicate, + ) + + def named_buffers( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + yield from self._postproc_module.named_buffers( + prefix, recurse, remove_duplicate + ) + + # pyre-ignore [14] + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + # super().state_dict(destination, prefix, keep_vars) + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + self._postproc_module.state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + return destination + + # pyre-ignore [14] + def load_state_dict( + self, + state_dict: OrderedDict[str, torch.Tensor], + strict: bool = True, + ) -> _IncompatibleKeys: + return self._postproc_module.load_state_dict(state_dict, strict=strict) + + +TForwardContext = TypeVar("TForwardContext", bound=TrainPipelineContext) + +EmbeddingModuleRetType = Union[Dict[str, JaggedTensor], KeyedTensor] + + +class BaseForward(Generic[TForwardContext]): + def __init__( + self, + name: str, + args: List[ArgInfo], + module: ShardedModule, + context: TForwardContext, + stream: Optional[torch.Stream] = None, + ) -> None: + self._name = name + self._args = args + self._module = module + self._context = context + self._stream = stream + self._device: torch.device = stream.device if stream else torch.device("cuda") + + @property + def name(self) -> str: + return self._name + + @property + def args(self) -> List[ArgInfo]: + return self._args + + def set_context(self, context: TForwardContext) -> None: + self._context = context + + def get_context(self) -> TForwardContext: + return self._context + + +class PipelinedForward(BaseForward[TrainPipelineContext]): + """ + This pipeline is used in TrainPipelineSparseDist + """ + + # pyre-ignore [2, 24] + def __call__(self, *input, **kwargs) -> Awaitable: + assert ( + self._name in self._context.input_dist_tensors_requests + ), "Invalid PipelinedForward usage, please do not directly call model.forward()" + request = self._context.input_dist_tensors_requests.pop(self._name) + 
assert isinstance(request, Awaitable) + with record_function("## wait_sparse_data_dist ##"): + # Finish waiting on the dist_stream, + # in case some delayed stream scheduling happens during the wait() call. + with torch.get_device_module(self._device).stream(self._stream): + data = request.wait() + + # Make sure that both result of input_dist and context + # are properly transferred to the current stream. + ctx = self._context.module_contexts.pop(self._name) + + if self._stream is not None: + torch.get_device_module(self._device).current_stream().wait_stream( + self._stream + ) + cur_stream = torch.get_device_module(self._device).current_stream() + + assert isinstance( + data, (torch.Tensor, Multistreamable) + ), f"{type(data)} must implement Multistreamable interface" + data.record_stream(cur_stream) + ctx.record_stream(cur_stream) + + return self._module.compute_and_output_dist(ctx, data) + + +class EmbeddingPipelinedForward(BaseForward[EmbeddingTrainPipelineContext]): + """ + This pipeline is used in TrainPipelineSemiSync + """ + + def __call__( + self, + # pyre-ignore + *input, + # pyre-ignore + **kwargs, + ) -> Union[ + Awaitable[EmbeddingModuleRetType], + Tuple[ + Awaitable[EmbeddingModuleRetType], Awaitable[Optional[KeyedJaggedTensor]] + ], + ]: + assert ( + self._name in self._context.embedding_a2a_requests + ), "Invalid EmbeddingPipelinedForward usage, please do not directly call model.forward()" + + ctx = self._context.module_contexts.pop(self._name) + cur_stream = torch.get_device_module(self._device).current_stream() + + if self._stream is not None: + torch.get_device_module(self._device).current_stream().wait_stream( + self._stream + ) + ctx.record_stream(cur_stream) + + awaitable = self._context.embedding_a2a_requests.pop(self._name) + # in case of MC modules + is_mc_module: bool = isinstance(awaitable, Iterable) + remapped_kjts: Optional[KeyedJaggedTensor] = None + + if is_mc_module: + embeddings = awaitable[0].wait() + remapped_kjts = awaitable[1].wait() + else: + assert isinstance(awaitable, Awaitable) + embeddings = ( + awaitable.wait() + ) # trigger awaitable manually for type checking + + self.detach_embeddings(embeddings=embeddings, cur_stream=cur_stream) + + if is_mc_module: + return (LazyNoWait(embeddings), LazyNoWait(remapped_kjts)) + else: + return LazyNoWait(embeddings) + + def detach_embeddings( + self, + embeddings: Union[Dict[str, JaggedTensor], KeyedTensor], + cur_stream: torch.Stream, + ) -> None: + """ + detach the grad from embeddings so that the backward/opt of the embeddings + won't be invoked by loss.backward(). Instead, there is a dedicated embedding_backward + call in semi-sync pipeline progress. 
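# Tiny self-contained illustration of the detach-and-backward-later trick described
# above: cut the autograd graph at the embedding output, run the dense backward first,
# then feed the gradient collected on the detached copy back into the embedding side
# (as embedding_backward() does). Tensor shapes and names are illustrative.
import torch

emb_weight = torch.randn(3, 4, requires_grad=True)   # stands in for embedding weights
emb_out = emb_weight * 2.0                            # "embedding lookup" output
detached = emb_out.detach().requires_grad_()
detached.retain_grad()

loss = (detached * torch.randn(3, 4)).sum()           # "dense" part of the model
loss.backward()                                       # fills detached.grad, not emb_weight.grad

torch.autograd.backward([emb_out], [detached.grad])   # embedding-side backward
print(emb_weight.grad is not None)                    # True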
+ """ + tensors = [] + detached_tensors = [] + # in case of EC, embeddings are Dict[str, JaggedTensor] + if isinstance(embeddings, Dict): + for jt in embeddings.values(): + assert isinstance(jt, JaggedTensor) + tensor = jt.values() + detached_tensor = tensor.detach().requires_grad_() + detached_tensor.retain_grad() + jt._values = detached_tensor + tensors.append(tensor) + detached_tensors.append(detached_tensor) + self._context.embedding_tensors.append(tensors) + self._context.embedding_features.append(list(embeddings.keys())) + self._context.detached_embedding_tensors.append(detached_tensors) + else: + # in case of EBC, embeddings are KeyedTensor + assert isinstance(embeddings, KeyedTensor) + embeddings.record_stream(cur_stream) + tensor = embeddings.values() + detached_tensor = tensor.detach().requires_grad_() + detached_tensor.retain_grad() + embeddings._values = detached_tensor + tensors.append(tensor) + detached_tensors.append(detached_tensor) + self._context.embedding_tensors.append(tensors) + """ + KeyedTensor is returned by EmbeddingBagCollections and its variants + KeyedTensor holds dense data from multiple features and .values() + returns a single concatenated dense tensor. To ensure that + context.embedding_tensors[i] has the same length as + context.embedding_features[i], we pass in a list with a single item: + a list containing all the embedding feature names. + """ + self._context.embedding_features.append([list(embeddings.keys())]) + self._context.detached_embedding_tensors.append(detached_tensors) + + +class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]): + """ + This pipeline is used in PrefetchTrainPipelineSparseDist + """ + + def __init__( + self, + name: str, + args: List[ArgInfo], + module: ShardedModule, + context: PrefetchTrainPipelineContext, + prefetch_stream: Optional[torch.Stream] = None, + ) -> None: + super().__init__( + name=name, + args=args, + module=module, + context=context, + stream=prefetch_stream, + ) + + # pyre-ignore [2, 24] + def __call__(self, *input, **kwargs) -> Awaitable: + assert ( + self._name in self._context.module_input_post_prefetch + ), "Invalid PrefetchPipelinedForward usage, please do not directly call model.forward()" + data = self._context.module_input_post_prefetch.pop(self._name) + ctx = self._context.module_contexts_post_prefetch.pop(self._name) + + # Make sure that both result of input_dist and context + # are properly transferred to the current stream. 
+ if self._stream is not None: + torch.get_device_module(self._device).current_stream().wait_stream( + self._stream + ) + cur_stream = torch.get_device_module(self._device).current_stream() + + assert isinstance( + data, (torch.Tensor, Multistreamable) + ), f"{type(data)} must implement Multistreamable interface" + data.record_stream(cur_stream) + + ctx.record_stream(cur_stream) + + return self._module.compute_and_output_dist(ctx, data) + + +class KJTAllToAllForward: + def __init__( + self, pg: dist.ProcessGroup, splits: List[int], stagger: int = 1 + ) -> None: + self._pg = pg + self._splits = splits + self._stagger = stagger + self._splits_cumsum: List[int] = [0] + list(itertools.accumulate(splits)) + + def __call__(self, input: KeyedJaggedTensor) -> KJTSplitsAllToAllMeta: + with torch.no_grad(): + assert len(input.keys()) == sum(self._splits) + rank = dist.get_rank(self._pg) + local_keys = input.keys()[ + self._splits_cumsum[rank] : self._splits_cumsum[rank + 1] + ] + input_splits = input.dist_splits(self._splits) + device = input.values().device + splits_tensors = [ + torch.tensor(splits, device=device) for splits in input_splits + ] + if not input.variable_stride_per_key(): + splits_tensors.append( + torch.tensor([input.stride()] * self._pg.size(), device=device) + ) + return KJTSplitsAllToAllMeta( + pg=self._pg, + _input=input, + splits=self._splits, + splits_tensors=splits_tensors, + input_splits=input_splits, + input_tensors=input.dist_tensors(), + labels=input.dist_labels(), + keys=local_keys, + device=device, + stagger=self._stagger, + ) + + +class Tracer(torch.fx.Tracer): + """ + Disables proxying buffers during tracing. Ideally, proxying buffers would be + disabled, but some models are currently mutating buffer values, which causes errors + during tracing. If those models can be rewritten to not do that, we can likely + remove this line. + """ + + proxy_buffer_attributes = False + + def __init__(self, leaf_modules: Optional[List[str]] = None) -> None: + super().__init__() + self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else [] + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + if ( + isinstance(m, ShardedModule) + or module_qualified_name in self._leaf_modules + or isinstance(m, FSDP) + or isinstance(m, FSDP2) + ): + return True + return super().is_leaf_module(m, module_qualified_name) + + +def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In: + assert isinstance( + batch, (torch.Tensor, Pipelineable) + ), f"{type(batch)} must implement Pipelineable interface" + return cast(In, batch.to(device=device, non_blocking=non_blocking)) + + +def _wait_for_batch(batch: In, stream: Optional[torch.Stream]) -> None: + """ + As mentioned in + https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html, PyTorch + uses the "caching allocator" for memory allocation for tensors. When a tensor is + freed, its memory is likely to be reused by newly constructed tenosrs. By default, + this allocator traces whether a tensor is still in use by only the CUDA stream where + it was created. When a tensor is used by additional CUDA streams, we need to call + `record_stream` to tell the allocator about these streams. Otherwise, the allocator + might free the underlying memory of the tensor once it is no longer used by the + creator stream. This is a notable programming trick when we write programs using + multiple CUDA streams. 
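+
+    Example (a minimal CUDA sketch of the pattern; `copy_stream` and `t` are
+    hypothetical and stand for the stream that produced the batch and one of its
+    tensors)::
+
+        cur_stream = torch.cuda.current_stream()
+        cur_stream.wait_stream(copy_stream)  # order subsequent work after the copy
+        t.record_stream(cur_stream)  # tell the allocator `t` is also used on cur_stream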
+ """ + if stream is None: + return + + device = stream.device + torch.get_device_module(device).current_stream().wait_stream(stream) + cur_stream = torch.get_device_module(device).current_stream() + assert isinstance( + batch, (torch.Tensor, Multistreamable) + ), f"{type(batch)} must implement Multistreamable interface" + batch.record_stream(cur_stream) + + +def _wait_for_events( + batch: In, + context: TrainPipelineContext, + stream: Optional[torch.Stream], +) -> None: + """ + Wait for any outstanding events for a given context + """ + + for event in context.events: + event.wait() + context.events.clear() + if stream: + assert isinstance( + batch, (torch.Tensor, Multistreamable) + ), f"{type(batch)} must implement Multistreamable interface" + batch.record_stream(stream) + + +def _start_data_dist( + pipelined_modules: List[ShardedModule], + batch: Pipelineable, + context: TrainPipelineContext, +) -> None: + if context.version == 0: + context.input_dist_splits_requests.clear() + context.module_contexts_next_batch.clear() + context.fused_splits_awaitables.clear() + + for module in pipelined_modules: + forward = module.forward + assert isinstance( + forward, + ( + PipelinedForward, + PrefetchPipelinedForward, + EmbeddingPipelinedForward, + ), + ) + + # Retrieve argument for the input_dist of EBC + # is_getitem True means this argument could be retrieved by a list + # False means this argument is getting while getattr + # and this info was done in the _rewrite_model by tracing the + # entire model to get the arg_info_list + args, kwargs = _build_args_kwargs(batch, forward.args) + + # Start input distribution. + module_ctx = module.create_context() + if context.version == 0: + context.module_contexts_next_batch[forward.name] = module_ctx + else: + context.module_contexts[forward.name] = module_ctx + context.input_dist_splits_requests[forward.name] = module.input_dist( + module_ctx, *args, **kwargs + ) + _fuse_input_dist_splits(context) + + +def _start_embedding_lookup( + module: ShardedModule, + context: EmbeddingTrainPipelineContext, + source_stream: Optional[torch.Stream], + target_stream: Optional[torch.Stream], + # pyre-ignore[2] + stream_context: Callable[..., AbstractContextManager[Any, Any]], +) -> None: + module_context = context.module_contexts[module.forward.name] + with stream_context(source_stream): + kjt = context.input_dist_tensors_requests[module.forward.name].wait() + + if target_stream is not None: + kjt.record_stream(target_stream) + module_context.record_stream(target_stream) + output_dist_out = module.compute_and_output_dist(module_context, kjt) + context.embedding_a2a_requests[module.forward.name] = output_dist_out + + +def _fuse_input_dist_splits(context: TrainPipelineContext) -> None: + names_per_pg = defaultdict(list) + for name, request in context.input_dist_splits_requests.items(): + pg = None + if isinstance(request, KJTListSplitsAwaitable): + for awaitable in request.awaitables: + if isinstance(awaitable, KJTSplitsAllToAllMeta): + pg = awaitable.pg + break + names_per_pg[pg].append(name) + + for pg, names in names_per_pg.items(): + context.fused_splits_awaitables.append( + ( + names, + FusedKJTListSplitsAwaitable( + # pyre-ignore[6] + requests=[ + context.input_dist_splits_requests[name] for name in names + ], + contexts=[ + ( + context.module_contexts_next_batch[name] + if context.version == 0 + else context.module_contexts[name] + ) + for name in names + ], + pg=pg, + ), + ) + ) + + +def _check_args_for_call_module( + node: torch.fx.Node, +) -> bool: + """ + 
Recursively checks if args to a node is the result of a call_module. + """ + if node.op == "call_module": + return True + + for arg in node.args: + if isinstance(arg, torch.fx.Node) and _check_args_for_call_module(arg): + return True + + return False + + +def _check_postproc_pipelineable( + module: torch.nn.Module, +) -> bool: + for _, _ in module.named_parameters(recurse=True): + # Cannot have any trainable params for it to be pipelined + logger.warning( + f"Module {module} cannot be pipelined as it has trainable parameters" + ) + return False + return True + + +def _find_postproc_module_recursive( + module: torch.nn.Module, + postproc_module_fqn: str, +) -> Optional[torch.nn.Module]: + """ + Finds the postproc module in the model. + """ + for name, child in module.named_modules(): + if name == postproc_module_fqn: + return child + return None + + +def _swap_postproc_module_recursive( + module: torch.nn.Module, + to_swap_module: torch.nn.Module, + postproc_module_fqn: str, + path: str = "", +) -> torch.nn.Module: + """ + Swaps the postproc module in the model. + """ + if isinstance(module, PipelinedPostproc): + return module + + if path == postproc_module_fqn: + return to_swap_module + + for name, child in module.named_children(): + child = _swap_postproc_module_recursive( + child, + to_swap_module, + postproc_module_fqn, + path + "." + name if path else name, + ) + setattr(module, name, child) + + return module + + +def _get_node_args_helper_inner( + model: torch.nn.Module, + # pyre-ignore + arg, + arg_info: ArgInfo, + num_found: int, + pipelined_postprocs: Set[PipelinedPostproc], + context: TrainPipelineContext, + pipeline_postproc: bool, + for_postproc_module: bool = False, + default_stream: Optional[torch.Stream] = None, + dist_stream: Optional[torch.Stream] = None, +) -> int: + num_found = 0 + while True: + if not isinstance(arg, torch.fx.Node): + if pipeline_postproc: + arg_info.input_attrs.insert(0, "") + arg_info.is_getitems.insert(0, False) + arg_info.postproc_modules.insert(0, None) + + if isinstance(arg, fx_immutable_dict): + fx_nested_dict = {} + + for k, v in arg.items(): + if isinstance(v, torch.fx.Node): + arg_info_nested = ArgInfo([], [], [], [], None) + _get_node_args_helper_inner( + model, + v, + arg_info_nested, + num_found, + pipelined_postprocs, + context, + pipeline_postproc, + for_postproc_module, + default_stream=default_stream, + dist_stream=dist_stream, + ) + fx_nested_dict[k] = arg_info_nested + else: + fx_nested_dict[k] = v + + arg_info.constants.insert(0, fx_nested_dict) + elif isinstance(arg, fx_immutable_list): + fx_nested_list = [] + for v in arg: + if isinstance(v, torch.fx.Node): + arg_info_nested = ArgInfo([], [], [], [], None) + _get_node_args_helper_inner( + model, + v, + arg_info_nested, + num_found, + pipelined_postprocs, + context, + pipeline_postproc, + for_postproc_module, + default_stream=default_stream, + dist_stream=dist_stream, + ) + fx_nested_list.append(arg_info_nested) + else: + fx_nested_list.append(v) + + arg_info.constants.insert(0, fx_nested_list) + else: + arg_info.constants.insert(0, arg) + num_found += 1 + break + child_node = arg + + if child_node.op == "placeholder": + if hasattr(child_node, "ph_key"): + # pyre-ignore[16] + ph_key: str = child_node.ph_key + # example: ph_key = 'event_id_list_features_seqs[marketplace]' + ph_key = ph_key.replace("[", ".") + ph_keys = ph_key.split(".") + for key in ph_keys: + if "]" in key: + arg_info.input_attrs.append(key[:-1]) + arg_info.is_getitems.append(True) + else: + 
arg_info.input_attrs.append(key) + arg_info.is_getitems.append(False) + arg_info.postproc_modules.append(None) + arg_info.constants.append(None) + else: + # no-op + arg_info.input_attrs.insert(0, "") + arg_info.is_getitems.insert(0, False) + arg_info.postproc_modules.insert(0, None) + arg_info.constants.insert(0, None) + + num_found += 1 + break + elif ( + child_node.op == "call_function" + and child_node.target.__module__ == "builtins" + # pyre-ignore[16] + and child_node.target.__name__ == "getattr" + ): + # pyre-fixme[6]: For 2nd argument expected `str` but got + # `Union[None, Dict[str, typing.Any], List[typing.Any], Node, bool, + # complex, float, int, range, slice, str, device, dtype, layout, + # memory_format, Tensor, typing.Tuple[typing.Any, ...]]`. + arg_info.input_attrs.insert(0, child_node.args[1]) + arg_info.is_getitems.insert(0, False) + arg_info.postproc_modules.insert(0, None) + arg_info.constants.insert(0, None) + arg = child_node.args[0] + elif ( + child_node.op == "call_function" + and child_node.target.__module__ == "_operator" + # pyre-ignore[16] + and child_node.target.__name__ == "getitem" + ): + # pyre-fixme[6]: For 2nd argument expected `str` but got + # `Union[None, Dict[str, typing.Any], List[typing.Any], Node, bool, + # complex, float, int, range, slice, str, device, dtype, layout, + # memory_format, Tensor, typing.Tuple[typing.Any, ...]]`. + arg_info.input_attrs.insert(0, child_node.args[1]) + arg_info.is_getitems.insert(0, True) + arg_info.postproc_modules.insert(0, None) + arg_info.constants.insert(0, None) + arg = child_node.args[0] + elif ( + child_node.op == "call_function" + and child_node.target.__module__ == "torch.utils._pytree" + # pyre-ignore[16] + and child_node.target.__name__ == "tree_unflatten" + ): + """ + This is for the PT2 export path where we unflatten the input to reconstruct + the structure with the recorded tree spec. + """ + assert arg_info.is_getitems[0] + # pyre-fixme[16] + arg = child_node.args[0][arg_info.input_attrs[0]] + elif ( + child_node.op == "call_function" + and child_node.target.__module__ == "torchrec.sparse.jagged_tensor" + # pyre-fixme[16] + and child_node.target.__name__ == "KeyedJaggedTensor" + ): + call_module_found = False + + for arg_node in chain(child_node.args, child_node.kwargs.values()): + if isinstance(arg_node, torch.fx.Node) and _check_args_for_call_module( + arg_node + ): + call_module_found = True + break + + if call_module_found: + break + + if "values" in child_node.kwargs: + arg = child_node.kwargs["values"] + else: + arg = child_node.args[1] + elif child_node.op == "call_method" and child_node.target == "get": + # pyre-ignore[6] + arg_info.input_attrs.insert(0, child_node.args[1]) + arg_info.is_getitems.insert(0, True) + arg_info.postproc_modules.insert(0, None) + arg_info.constants.insert(0, None) + arg = child_node.args[0] + elif child_node.op == "call_module": + postproc_module_fqn = str(child_node.target) + postproc_module = _find_postproc_module_recursive( + model, postproc_module_fqn + ) + + if not pipeline_postproc: + logger.warning( + f"Found module {postproc_module} that potentially modifies KJ. Train pipeline initialized with `pipeline_postproc=False` (default), so we assume KJT input modification. 
To allow torchrec to check if this module can be safely pipelined, please set `pipeline_postproc=True`" + ) + break + + if not postproc_module: + # Could not find such module, should not happen + break + + if isinstance(postproc_module, PipelinedPostproc): + # Already did module swap and registered args, early exit + arg_info.input_attrs.insert(0, "") # dummy value + arg_info.is_getitems.insert(0, False) + pipelined_postprocs.add(postproc_module) + arg_info.postproc_modules.insert(0, postproc_module) + arg_info.constants.insert(0, None) + num_found += 1 + break + + if not isinstance(postproc_module, torch.nn.Module): + logger.warning( + f"Expected postproc_module to be nn.Module but was {type(postproc_module)}" + ) + break + + # check if module is safe to pipeline i.e.no trainable param + if not _check_postproc_pipelineable(postproc_module): + break + + # For module calls, `self` isn't counted + total_num_args = len(child_node.args) + len(child_node.kwargs) + if total_num_args == 0: + # module call without any args, assume KJT modified + break + + # recursive call to check that all inputs to this postproc module + # is either made of postproc module or non-modifying train batch input + # transformations + postproc_args, num_found_safe_postproc_args = _get_node_args( + model, + child_node, + pipelined_postprocs, + context, + pipeline_postproc, + True, + default_stream=default_stream, + dist_stream=dist_stream, + ) + if num_found_safe_postproc_args == total_num_args: + logger.info( + f"""Module {postproc_module} is a valid postproc module (no + trainable params and inputs can be derived from train batch input + via a series of either valid postproc modules or non-modifying + transformations) and will be applied during sparse data dist + stage""" + ) + + pipelined_postproc_module = PipelinedPostproc( + postproc_module, + postproc_module_fqn, + postproc_args, + context, + default_stream=default_stream, + dist_stream=dist_stream, + ) + + # module swap + _swap_postproc_module_recursive( + model, pipelined_postproc_module, postproc_module_fqn + ) + + arg_info.input_attrs.insert(0, "") # dummy value + arg_info.is_getitems.insert(0, False) + pipelined_postprocs.add(pipelined_postproc_module) + arg_info.postproc_modules.insert(0, pipelined_postproc_module) + arg_info.constants.insert(0, None) + + num_found += 1 + + # we cannot set any other `arg` value here + # break to avoid infinite loop + break + else: + break + + return num_found + + +def _get_node_args_helper( + model: torch.nn.Module, + # pyre-ignore + arguments, + num_found: int, + pipelined_postprocs: Set[PipelinedPostproc], + context: TrainPipelineContext, + pipeline_postproc: bool, + # Add `None` constants to arg info only for postproc modules + # Defaults to False for backward compatibility + for_postproc_module: bool = False, + default_stream: Optional[torch.Stream] = None, + dist_stream: Optional[torch.Stream] = None, +) -> Tuple[List[ArgInfo], int]: + """ + Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. + It also counts the number of (args + kwargs) found. 
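+
+    Each resulting ArgInfo records, per traversal step, the attribute used to reach the
+    value from the train batch (input_attrs), whether that step is a __getitem__ access
+    (is_getitems), any pipelined postproc module involved (postproc_modules), and
+    literal constants (constants); the kwarg name, if any, is filled in afterwards by
+    _get_node_args.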
+ """ + arg_info_list = [ArgInfo([], [], [], [], None) for _ in range(len(arguments))] + for arg, arg_info in zip(arguments, arg_info_list): + if not for_postproc_module and arg is None: + num_found += 1 + continue + num_found += _get_node_args_helper_inner( + model, + arg, + arg_info, + num_found, + pipelined_postprocs, + context, + pipeline_postproc, + for_postproc_module, + default_stream=default_stream, + dist_stream=dist_stream, + ) + return arg_info_list, num_found + + +def _get_node_args( + model: torch.nn.Module, + node: Node, + pipelined_postprocs: Set[PipelinedPostproc], + context: TrainPipelineContext, + pipeline_postproc: bool, + for_postproc_module: bool = False, + default_stream: Optional[torch.Stream] = None, + dist_stream: Optional[torch.Stream] = None, +) -> Tuple[List[ArgInfo], int]: + num_found = 0 + + pos_arg_info_list, num_found = _get_node_args_helper( + model, + node.args, + num_found, + pipelined_postprocs, + context, + pipeline_postproc, + for_postproc_module, + default_stream=default_stream, + dist_stream=dist_stream, + ) + kwargs_arg_info_list, num_found = _get_node_args_helper( + model, + node.kwargs.values(), + num_found, + pipelined_postprocs, + context, + pipeline_postproc, + for_postproc_module, + default_stream=default_stream, + dist_stream=dist_stream, + ) + + # Replace with proper names for kwargs + for name, arg_info_list in zip(node.kwargs, kwargs_arg_info_list): + arg_info_list.name = name + + arg_info_list = pos_arg_info_list + kwargs_arg_info_list + + return (arg_info_list, num_found) + + +def _get_leaf_module_names_helper( + model: torch.nn.Module, + path: str, + leaf_module_names: Set[str], +) -> bool: + sharded_children = set() + for name, child in model.named_children(): + curr_path = path + name + if isinstance(child, ShardedModule): + sharded_children.add(name) + else: + child_sharded = _get_leaf_module_names_helper( + child, + curr_path + ".", + leaf_module_names, + ) + if child_sharded: + sharded_children.add(name) + + if len(sharded_children) > 0: + for name, child in model.named_children(): + if name in sharded_children: + continue + # assume module is leaf node unless annotated otherwise + if not getattr(child, "_is_pytorch_fx_traceable", False): + leaf_module_names.add(path + name) + return len(sharded_children) > 0 + + +def _get_leaf_module_names(model: torch.nn.Module) -> List[str]: + """ + Returns a list of top level modules to be used as leaf modules for FX tracing. + This is a shallow FX trace that only goes the minimum depth required to pipeline + the model unless child modules are explicitly tagged as `_is_pytorch_fx_traceable`. + """ + + leaf_module_names: Set[str] = set() + _get_leaf_module_names_helper( + model, + "", + leaf_module_names, + ) + return list(leaf_module_names) + + +def _jit_modules(module: torch.nn.Module, path: str, optional: bool = True) -> bool: + sharded_children = set() + for name, child in module.named_children(): + curr_path = path + name + if isinstance(child, ShardedModule): + sharded_children.add(name) + else: + child_sharded = _jit_modules(child, curr_path + ".", optional) + if child_sharded: + sharded_children.add(name) + + if len(sharded_children) > 0: + for name, child in module.named_children(): + if name not in sharded_children: + try: + jit_child = torch.jit.script(child) + setattr(module, name, jit_child) + logger.info(f"jit.script applied to {path + name}.") + except Exception as error: + if not optional: + raise + else: + logger.info( + f"Warning: failed to jit.script {path + name}: {error}." 
+ ) + + return len(sharded_children) > 0 + + +def _pipeline_detach_model( + model: torch.nn.Module, + pipelined_modules: List[ShardedModule], + # pyre-ignore[2] + original_forwards: List[Callable[..., Any]], + original_kjt_dist_forwards: List[ + Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]] + ], + pipelined_postprocs: List[PipelinedPostproc], +) -> None: + # Replace pipelined module forward and input dist forward with original forward + kjt_dists = [] + for mod, original_fwd in zip(pipelined_modules, original_forwards): + # pyre-ignore + mod.forward = original_fwd + + for _, child_module in mod.named_modules(): + if not hasattr(child_module, "_input_dists"): + continue + for input_dist in child_module._input_dists: + if hasattr(input_dist, "_dist"): + kjt_dists.append(input_dist._dist) + assert len(kjt_dists) == len( + original_kjt_dist_forwards + ), f"Number of KJT dists ({len(kjt_dists)}) does not match number of kjt dist forwards provided ({len(original_kjt_dist_forwards)})" + + for kjt_dist, original_kjt_dist_fwd in zip( + kjt_dists, + original_kjt_dist_forwards, + ): + kjt_dist.forward = original_kjt_dist_fwd + + # Get underlying nn.Module + if isinstance(model, DistributedModelParallel): + model = model.module + + # Replace pipelined postproc modules with original postproc modules + for postproc_mod in pipelined_postprocs: + setattr(model, postproc_mod.fqn, postproc_mod.postproc_module) + + +# pyre-ignore[3] +def _rewrite_model( # noqa C901 + model: torch.nn.Module, + context: TForwardContext, + dist_stream: Optional[torch.Stream], + batch: Optional[In] = None, + apply_jit: bool = False, + pipelined_forward: Type[BaseForward[TrainPipelineContext]] = PipelinedForward, + pipeline_postproc: bool = False, + default_stream: Optional[torch.Stream] = None, +) -> Tuple[ + List[ShardedModule], + torch.nn.Module, + List[Callable[..., Any]], + List[PipelinedPostproc], + List[str], +]: + input_model = model + # Get underlying nn.Module + if isinstance(model, DistributedModelParallel): + model = model.module + + # Collect a list of sharded modules. + sharded_modules = {} + for name, m in model.named_modules(): + if isinstance(m, ShardedModule): + sharded_modules[name] = m + + # Trace a model. + concrete_args = {} + if batch: + if hasattr(batch, "to_proxy"): + # for some special models, it requires using "input" + # as the key for input + # pyre-ignore[16]: Variable[In (bound to Pipelineable)] has no attribute to_proxy. + concrete_args["inputs"] = copy.copy(batch).to_proxy() + elif hasattr(batch, "to_proxy_tuple"): + # when the model is pre-fx traced or dynamo exported, the + # inputs are already flattened, and therefore we use + # tuple as concrete args that fx.trace will automatically + # match with the argument names. + # We pass in the model for the caller side to customize + # the batch + # pyre-ignore[16]: Variable[In (bound to Pipelineable)] has no attribute to_proxy_tuple. + concrete_args = batch.to_proxy_tuple(model) + + tracer = Tracer(leaf_modules=_get_leaf_module_names(model)) + graph = tracer.trace(model, concrete_args=concrete_args) + + # Select sharded modules, which are top-level in the forward call graph, + # i.e. don't have input transformations, i.e. rely only on 'builtins.getattr'. 
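+    # Illustrative example (module names are hypothetical): a call such as
+    # `ebc(batch.sparse_features)` qualifies, while
+    # `ebc(some_parameterized_module(batch.sparse_features))` does not, unless
+    # `pipeline_postproc=True` and _get_node_args finds the intermediate module to be
+    # safely pipelineable.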
+ pipelined_forwards = [] + original_forwards = [] + + pipelined_postprocs: Set[PipelinedPostproc] = set() + non_pipelined_sharded_modules = [] + + for node in graph.nodes: + if node.op == "call_module" and node.target in sharded_modules: + total_num_args = len(node.args) + len(node.kwargs) + if total_num_args == 0: + continue + arg_info_list, num_found = _get_node_args( + model, + node, + pipelined_postprocs, + context, + pipeline_postproc, + default_stream=default_stream, + dist_stream=dist_stream, + ) + + if num_found == total_num_args: + logger.info(f"Module '{node.target}' will be pipelined") + child = sharded_modules[node.target] + original_forwards.append(child.forward) + child.forward = pipelined_forward( + node.target, + arg_info_list, + child, + context, + dist_stream, + ) + pipelined_forwards.append(child) + else: + logger.warning( + f"Module '{node.target}'' will not be pipelined, due to input modifications" + ) + non_pipelined_sharded_modules.append(node.target) + + # JIT script unsharded modules if applicable. + if apply_jit: + graph_model = torch.fx.GraphModule(model, graph) + _jit_modules(graph_model, "") + if isinstance(input_model, DistributedModelParallel): + input_model.module = graph_model + + if non_pipelined_sharded_modules: + logger.warn( + "Sharded modules were not pipelined: %s. " + + "This should be fixed for pipelining to work to the full extent.", + ", ".join(non_pipelined_sharded_modules), + ) + + return ( + pipelined_forwards, + input_model, + original_forwards, + list(pipelined_postprocs), + non_pipelined_sharded_modules, + ) + + +def _override_input_dist_forwards( + pipelined_modules: List[ShardedModule], +) -> List[Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]]]: + """ + Overrides each input dist forward to support fusing the splits collective. + NOTE: this can only be called after the input dists are initialized. + """ + original_kjt_dist_forwards = [] + for module in pipelined_modules: + for child_fqn, child_module in module.named_modules(): + if hasattr(child_module, "_has_uninitialized_input_dist"): + assert ( + not child_module._has_uninitialized_input_dist + ), f"{child_fqn} has uninitialized input dist" + + if not hasattr(child_module, "_input_dists"): + continue + + for input_dist in child_module._input_dists: + if hasattr(input_dist, "_dist"): + assert isinstance(input_dist._dist, KJTAllToAll) + original_kjt_dist_forwards.append(input_dist._dist.forward) + input_dist._dist.forward = KJTAllToAllForward( + pg=input_dist._dist._pg, + splits=input_dist._dist._splits, + stagger=input_dist._dist._stagger, + ) + return original_kjt_dist_forwards + + +def get_h2d_func(batch: In, device: torch.device) -> Pipelineable: + return batch.to(device, non_blocking=True) + + +class DataLoadingThread(Thread, Generic[In]): + def __init__( + self, + device: torch.device, + dataloader_iter: Iterator[In], + to_device_non_blocking: bool, + memcpy_stream_priority: int = 0, + memcpy_stream: Optional[torch.Stream] = None, + ) -> None: + super().__init__(name="DataLoadingThread") + self._stop: bool = False + self.daemon = True # Mark as daemon thread so that Python will not wait for it at shutdown. 
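+        # The two events below implement a single-slot producer/consumer buffer:
+        # run() waits on _buffer_empty_event, fills _buffered on _memcpy_stream and
+        # sets _buffer_filled_event; get_next_batch() performs the inverse handshake.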
+ self._dataloader_iter = dataloader_iter + self._buffer_empty_event: Event = Event() + self._buffer_filled_event: Event = Event() + if memcpy_stream is None: + self._memcpy_stream: Optional[torch.Stream] = ( + torch.get_device_module(device).Stream(priority=memcpy_stream_priority) + if device.type in ["cuda", "mtia"] + else None + ) + else: + self._memcpy_stream = memcpy_stream + self._device = device + self._to_device_non_blocking = to_device_non_blocking + self._buffered: Optional[In] = None + self._buffer_empty_event.set() + + def run(self) -> None: + if self._device.type == "cuda" and torch.cuda.is_available(): + # set the current device the same as the one used in the main thread + torch.cuda.set_device(self._device) + elif self._device.type == "mtia" and torch.mtia.is_available(): + # set the current device the same as the one used in the main thread + torch.mtia.set_device(self._device) + + while not self._stop: + self._buffer_empty_event.wait() + # Set the filled event to unblock progress() and return. + if self._stop: + self._buffer_filled_event.set() + return + with record_function("## load_batch ##"): + try: + batch = next(self._dataloader_iter) + except StopIteration: + self._stop = True + self._buffer_filled_event.set() + return + with record_function("## copy_batch_to_gpu ##"): + with torch.get_device_module(self._device).stream(self._memcpy_stream): + self._buffered = cast( + In, + batch.to( + self._device, non_blocking=self._to_device_non_blocking + ), + ) + self._buffer_empty_event.clear() + self._buffer_filled_event.set() + + def stop(self) -> None: + logger.info("Stopping data loading thread...") + self._stop = True + # Unblock any thread that are waiting for these events. + self._buffer_filled_event.set() + self._buffer_empty_event.set() + logger.info("Data loading thread stopped.") + + def get_next_batch(self, none_throws: bool = False) -> Optional[In]: + """ + Get the next batch from the buffer if threading is enabled, otherwise + call load_next_batch directly. + + This function is not thread safe. We assume this is only invoked from + the main thread in the training loop. + """ + self._buffer_filled_event.wait() + batch = self._buffered + if batch is None: + if none_throws: + raise StopIteration + return None + self._buffered = None + self._buffer_filled_event.clear() + self._buffer_empty_event.set() + return batch + + +def _prefetch_embeddings( + batch: In, + context: PrefetchTrainPipelineContext, + pipelined_modules: List[ShardedModule], + device: torch.device, + stream_context: torch.Stream, + data_dist_stream: Optional[torch.Stream], + default_stream: Optional[torch.Stream], +) -> Dict[str, KJTList]: + data_per_sharded_module = {} + for sharded_module in pipelined_modules: + forward = sharded_module.forward + assert isinstance(forward, PrefetchPipelinedForward) + + assert forward._name in context.input_dist_tensors_requests + request = context.input_dist_tensors_requests.pop(forward._name) + assert isinstance(request, Awaitable) + with record_function("## wait_sparse_data_dist ##"): + # Finish waiting on the dist_stream, + # in case some delayed stream scheduling happens during the wait() call. + with stream_context(data_dist_stream): + data = request.wait() + + # Make sure that both result of input_dist and context + # are properly transferred to the current stream. 
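+        # The dist output is consumed both on the current stream and later on the
+        # default/compute stream (prefetch() below is passed
+        # forward_stream=default_stream), so record_stream is called for both streams.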
+ module_context = context.module_contexts[forward._name] + if data_dist_stream is not None: + torch.get_device_module(device).current_stream().wait_stream( + data_dist_stream + ) + cur_stream = torch.get_device_module(device).current_stream() + + assert isinstance( + data, (torch.Tensor, Multistreamable) + ), f"{type(data)} must implement Multistreamable interface" + data.record_stream(cur_stream) + data.record_stream(default_stream) + + module_context.record_stream(cur_stream) + module_context.record_stream(default_stream) + + sharded_module.prefetch( + ctx=module_context, + dist_input=data, + forward_stream=default_stream, + ) + data_per_sharded_module[forward._name] = data + return data_per_sharded_module + + +class SparseDataDistUtil(Generic[In]): + """ + Helper class exposing methods for sparse data dist and prefetch pipelining. + Currently used for `StagedTrainPipeline` pipeline stages + + Args: + model (torch.nn.Module): Model to pipeline + data_dist_stream (torch.cuda.Stream): Stream on which to run sparse data dist. + apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules. + prefetch_stream (Optional[torch.cuda.Stream]): Stream on which model prefetch runs + Defaults to `None`. This needs to be passed in to enable prefetch pipelining. + + Example:: + sdd = SparseDataDistUtil( + model=model, + data_dist_stream=torch.cuda.Stream(), + prefetch_stream=torch.cuda.Stream(), <-- required to enable prefetch pipeline + ) + pipeline = [ + PipelineStage( + name="data_copy", + runnable=lambda batch, context: batch.to( + self._device, non_blocking=True + ), + stream=torch.cuda.Stream(), + ), + PipelineStage( + name="start_sparse_data_dist", + runnable=sdd.start_sparse_data_dist, + stream=sdd.data_dist_stream, + fill_callback=sdd.wait_sparse_data_dist, + ), + PipelineStage( + name="prefetch", + runnable=sdd.prefetch, + stream=sdd.prefetch_stream, + fill_callback=sdd.load_prefetch, + ), + ] + + return StagedTrainPipeline(pipeline_stages=pipeline) + """ + + def __init__( + self, + model: torch.nn.Module, + data_dist_stream: torch.Stream, + apply_jit: bool = False, + prefetch_stream: Optional[torch.Stream] = None, + ) -> None: + super().__init__() + self.model = model + self.data_dist_stream = data_dist_stream + self.prefetch_stream = prefetch_stream + self.apply_jit = apply_jit + self.context = ( + PrefetchTrainPipelineContext(version=0) + if prefetch_stream + else TrainPipelineContext(version=0) + ) + self.initialized = False + self._pipelined_modules: List[ShardedModule] = [] + self._pipelined_postprocs: List[PipelinedPostproc] = [] + # pyre-ignore + self.fwd_hook = None + self._device: torch.device = data_dist_stream.device + + # pyre-ignore + self._original_forwards: List[Callable[..., Any]] = [] + self._original_kjt_dist_forwards: List[ + Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]] + ] = [] + + self._pipelined_forward: Type[BaseForward[TrainPipelineContext]] = cast( + Type[BaseForward[TrainPipelineContext]], + (PrefetchPipelinedForward if prefetch_stream else PipelinedForward), + ) + + self._default_stream: Optional[torch.Stream] = ( + (torch.get_device_module(self._device).Stream()) + if self._device.type in ["cuda", "mtia"] + else None + ) + + def detach(self) -> torch.nn.Module: + """ + Removes sparse data dist (SDD) pipelining from model forward and input dist. + Modifies existing model in place and returns the model. + + detach() can be called at any point, and inflight batches do not need to be + flushed before calling it. 
Calling pipeline.progress() will re-attach the model + to the pipeline and the pipeline will progress normally from the point it was detached (i.e. inflight batches will be kept when calling detach). + + While the model is detached, it is equivalent to the model before passing to + the pipeline, so forward and backward passes, and optimizer updates can be + carried out normally. + """ + if self.initialized: + assert self.fwd_hook is not None + self.fwd_hook.remove() + + _pipeline_detach_model( + model=self.model, + pipelined_modules=self._pipelined_modules, + original_forwards=self._original_forwards, + original_kjt_dist_forwards=self._original_kjt_dist_forwards, + pipelined_postprocs=self._pipelined_postprocs, + ) + + self.initialized = False + return self.model + + def start_sparse_data_dist(self, batch: In) -> In: + if not self.initialized: + # Step 1: Pipeline input dist in trec sharded modules + # TODO (yhshin): support postproc modules for `StagedTrainPipeline` + ( + self._pipelined_modules, + self.model, + self._original_forwards, + self._pipelined_postprocs, + _, + ) = _rewrite_model( + model=self.model, + context=self.context, + dist_stream=self.data_dist_stream, + batch=batch, + apply_jit=self.apply_jit, + pipelined_forward=self._pipelined_forward, + default_stream=self._default_stream, + ) + # initializes input dist, so we can override input dist forwards + _start_data_dist(self._pipelined_modules, batch, self.context) + self._original_kjt_dist_forwards = _override_input_dist_forwards( + self._pipelined_modules + ) + + # Step 2: Register post-forward hook to wait SDD + def forward_hook( + module: torch.nn.Module, + input: Union[torch.Tensor, Tuple[torch.Tensor]], + output: Union[torch.Tensor, Tuple[torch.Tensor]], + ) -> None: + if self.prefetch_stream is not None: + # Need to load prefetch before wait_sparse_data_dist + self.load_prefetch() + self.wait_sparse_data_dist() + + self.fwd_hook = self.model.register_forward_hook(forward_hook) + + self.initialized = True + + _start_data_dist(self._pipelined_modules, batch, self.context) + + return batch + + def wait_sparse_data_dist(self) -> None: + with record_function("## wait_sparse_data_dist ##"): + with torch.get_device_module(self._device).stream(self.data_dist_stream): + self.context.module_contexts = ( + self.context.module_contexts_next_batch.copy() + ) + self.context.input_dist_tensors_requests.clear() + for names, awaitable in self.context.fused_splits_awaitables: + for name, request in zip(names, awaitable.wait()): + self.context.input_dist_tensors_requests[name] = request + + def prefetch(self, batch: In) -> In: + """ + Waits for input dist to finish, then prefetches data. 
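+        Results are stashed in module_input_post_prefetch_next_batch and
+        module_contexts_post_prefetch_next_batch; load_prefetch() later promotes them to
+        the *_post_prefetch slots consumed by PrefetchPipelinedForward.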
+ """ + assert isinstance( + self.context, PrefetchTrainPipelineContext + ), "Pass prefetch_stream into SparseDataDistUtil to use prefetch() as a stage" + self.context.module_input_post_prefetch_next_batch.clear() + # pyre-ignore + self.context.module_contexts_post_prefetch_next_batch.clear() + + data_per_pipelined_module = _prefetch_embeddings( + batch, + # pyre-ignore + self.context, + self._pipelined_modules, + self._device, + torch.get_device_module(self._device).stream, + self.data_dist_stream, + self._default_stream, + ) + for sharded_module in self._pipelined_modules: + forward = sharded_module.forward + data = data_per_pipelined_module[forward._name] + # pyre-ignore [16] + self.context.module_input_post_prefetch_next_batch[forward._name] = data + self.context.module_contexts_post_prefetch_next_batch[forward._name] = ( + self.context.module_contexts.pop(forward._name) + ) + return batch + + def load_prefetch(self) -> None: + assert isinstance( + self.context, PrefetchTrainPipelineContext + ), "Pass prefetch_stream into SparseDataDistUtil to use load_prefetch()" + self.context.module_input_post_prefetch.clear() + # pyre-ignore + self.context.module_contexts_post_prefetch.clear() + + with record_function("## load_sharded_module_prefetch ##"): + with torch.get_device_module(self._device).stream(self.prefetch_stream): + for sharded_module in self._pipelined_modules: + forward = sharded_module.forward + assert isinstance(forward, PrefetchPipelinedForward) + self.context.module_input_post_prefetch[forward._name] = ( + self.context.module_input_post_prefetch_next_batch[ + forward._name + ] + ) + self.context.module_contexts_post_prefetch[forward._name] = ( + self.context.module_contexts_post_prefetch_next_batch[ + forward._name + ] + ) diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index c71260f34..ac7260d25 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -5,21 +5,46 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import abc import operator from dataclasses import dataclass from enum import Enum, unique -from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, Type, TypeVar +from typing import ( + Any, + Callable, + cast, + Dict, + Generic, + Iterator, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +from fbgemm_gpu.runtime_monitor import TBEStatsReporterConfig +from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( + BoundsCheckMode, + CacheAlgorithm, + MultiPassPrefetchConfig, +) from torch.autograd.profiler import record_function -from torchrec.types import ModuleNoCopyMixin +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.distributed_c10d import _get_pg_default_device +from torchrec.tensor_types import UInt2Tensor, UInt4Tensor +from torchrec.types import DataType, ModuleNoCopyMixin try: # For python 3.6 and below, GenericMeta will be used by # other metaclasses (i.e. 
AwaitableMeta) for customized # behaviors, as Generic is non-trival metaclass in # python 3.6 and below - from typing import GenericMeta # pyre-ignore: python 3.6 + from typing import GenericMeta except ImportError: # In python 3.7+, GenericMeta doesn't exist as it's no # longer a non-trival metaclass, @@ -51,10 +76,45 @@ class GenericMeta(type): ShardingSpec, ShardMetadata, ) -from torch.nn.modules.module import _addindent from torchrec.streamable import Multistreamable +def _tabulate( + table: List[List[Union[str, int]]], headers: Optional[List[str]] = None +) -> str: + """ + Format a table as a string. + Parameters: + table (list of lists or list of tuples): The data to be formatted as a table. + headers (list of strings, optional): The column headers for the table. If not provided, the first row of the table will be used as the headers. + Returns: + str: A string representation of the table. + """ + if headers is None: + headers = table[0] + table = table[1:] + headers = cast(List[str], headers) + rows = [] + # Determine the maximum width of each column + col_widths = [max([len(str(item)) for item in column]) for column in zip(*table)] + col_widths = [max(i, len(j)) for i, j in zip(col_widths, headers)] + # Format each row of the table + for row in table: + row_str = " | ".join( + [str(item).ljust(width) for item, width in zip(row, col_widths)] + ) + rows.append(row_str) + # Add the header row and the separator line + rows.insert( + 0, + " | ".join( + [header.center(width) for header, width in zip(headers, col_widths)] + ), + ) + rows.insert(1, " | ".join(["-" * width for width in col_widths])) + return "\n".join(rows) + + class ShardingType(Enum): """ Well-known sharding types, used by inter-module optimizations. @@ -74,6 +134,35 @@ class ShardingType(Enum): TABLE_ROW_WISE = "table_row_wise" # Column-wise on the same node and table-wise across nodes TABLE_COLUMN_WISE = "table_column_wise" + # Grid sharding, where each rank gets a subset of columns and rows in a CW and TWRW style + GRID_SHARD = "grid_shard" + + +class EmbeddingEvent(Enum): + """ + Events in sharded embedding module's forward, used for trace annotations + """ + + KJT_SPLITS_DIST = "splits_dist" + KJT_TENSORS_DIST = "tensors_dist" + LOOKUP = "lookup" + OUTPUT_DIST = "output_dist" + # When .wait() is called on output_dist awaitable + # Useful for linking backward comms event in trace to forward event + OUTPUT_DIST_WAIT = "output_dist_wait" + + +class PipelineType(Enum): + """ + Known pipeline types. + Check out //torchrec/distributed/train_pipeline/train_pipelines.py + for details about pipelines. + """ + + NONE = "none" + TRAIN_BASE = "train_base" + TRAIN_SPARSE_DIST = "train_sparse_dist" + TRAIN_PREFETCH_SPARSE_DIST = "train_prefetch_sparse_dist" class ParameterStorage(Enum): @@ -117,14 +206,13 @@ class QuantizedCommCodec(Generic[QuantizationContext]): def encode( self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... def decode( self, input_grad: torch.Tensor, ctx: Optional[QuantizationContext] = None - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... + @property def quantized_dtype(self) -> torch.dtype: """ tensor.dtype of the resultant encode(input_tensor) @@ -151,6 +239,18 @@ def create_context(self) -> Optional[QuantizationContext]: """ ... 
+ def padded_size( + self, + input_tensor: torch.Tensor, + dim_per_rank: List[int], + my_rank: int, + qcomm_ctx: QuantizationContext, + ) -> Tuple[int, int]: + """ + Return (padded_dim_sum, padding_size) of the input tensor for quantization. + """ + ... + class NoOpQuantizedCommCodec(Generic[QuantizationContext]): """ @@ -184,6 +284,18 @@ def calc_quantized_size( def create_context(self) -> Optional[QuantizationContext]: return None + def padded_size( + self, + input_tensor: torch.Tensor, + dim_per_rank: List[int], + my_rank: int, + qcomm_ctx: QuantizationContext, + ) -> Tuple[int, int]: + dim_sum = ( + input_tensor.shape[0] if input_tensor.dim() == 1 else input_tensor.shape[1] + ) + return dim_sum, 0 + @dataclass class QuantizedCommCodecs: @@ -238,8 +350,9 @@ def _wait_impl(self) -> W: return self._obj -# pyre-fixme[11]: Annotation `ProxyableClassMeta` is not defined as a type. -class _LazyAwaitableMeta(GenericMeta, abc.ABCMeta, torch.fx.ProxyableClassMeta): +class _LazyAwaitableMeta( + GenericMeta, abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta +): """ The _LazyAwaitableMeta class that inherits both ABCMeta and ProxyableClassMeta This is because ABCMeta/ProxyableClassMeta are both non-trival metaclasses @@ -269,7 +382,7 @@ class LazyAwaitable(Awaitable[W], metaclass=_LazyAwaitableMeta): Some caveats: * This works with Pytorch functions, but not any generic method, if - you would like to do arbitary python operations, you need to + you would like to do arbitrary python operations, you need to implement the corresponding magic methods * In the case that one function have two or more arguments are LazyAwaitable, @@ -298,8 +411,9 @@ def _wait_async(obj: Any) -> Any: else: return obj + @classmethod # pyre-ignore [2, 3] - def __torch_function__(self, func, types, args=(), kwargs=None): + def __torch_function__(cls, func, types, args=(), kwargs=None): """ The LazyAwaitable type has a `__torch_function__` implementation. This means when this type is seens as an argument to a PyTorch @@ -347,6 +461,48 @@ def _wait_impl(self) -> W: return self._obj +KT = TypeVar("KT") +VT_co = TypeVar("VT_co") +ParentW = TypeVar("ParentW") + + +class LazyGetItemMixin(Generic[KT, VT_co]): + """Augments the base LazyAwaitable with a lazy __getitem__ method. + + Instead of triggering a wait() on a __getitem__ call, KeyedLazyAwaitable + will return another awaitable. This can achieve better + communication/computation overlap by deferring the wait() until the + tensor data is actually needed. + + This is intended for Awaitables that model keyed collections, like + dictionaries or EmbeddingBagCollectionAwaitable. + + NOTE: if using this mixin, please include it before LazyAwaitable in the + inheritance list, so that Python MRO can properly select this __getitem__ + implementation. + """ + + def __getitem__(self, key: KT) -> LazyAwaitable[VT_co]: + return GetItemLazyAwaitable(self, key) + + +class GetItemLazyAwaitable(LazyAwaitable[W], Generic[W, ParentW, KT]): + """The LazyAwaitable returned from a __getitem__ call on `LazyGetItemMixin`. + + When the actual value of this awaitable is requested, wait on the parent and + then call __getitem__ on the result. 
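+
+    Example (illustrative sketch; `kt_awaitable` stands for any LazyAwaitable that also
+    mixes in `LazyGetItemMixin` over a key -> tensor mapping)::
+
+        per_feature = kt_awaitable["feature_0"]  # returns GetItemLazyAwaitable, no wait yet
+        tensor = per_feature.wait()  # waits on the parent, then indexes into the result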
+ """ + + def __init__(self, parent: LazyAwaitable[ParentW], key: KT) -> None: + super().__init__() + self._parent = parent + self._key = key + + def _wait_impl(self) -> W: + kt = LazyAwaitable._wait_async(self._parent) + return kt[self._key] + + # install magic methods for orig_method_name in torch.fx.graph.magic_methods: as_magic = f"__{orig_method_name}__" @@ -374,6 +530,7 @@ def impl(*args, **kwargs): # install reflective magic methods for orig_method_name in torch.fx.graph.reflectable_magic_methods: as_magic = f"__r{orig_method_name}__" + # pyre-ignore [2, 3] def scope(method): # pyre-ignore [2, 3, 53] @@ -391,21 +548,167 @@ def impl(self, rhs): scope(orig_method_name) +class ModuleShardingPlan: + pass + + +class CacheStatistics(abc.ABC): + @property + @abc.abstractmethod + def expected_lookups(self) -> float: + """Number of expected cache lookups per training step. + + This is the expected number of distinct values in a global training batch.""" + + @abc.abstractmethod + def expected_miss_rate(self, clf: float) -> float: + """Expected cache lookup miss rate for a given cache size. + + When clf (cache load factor) is 0, returns 1.0 (100% miss). When clf is 1.0, + returns 0 (100% hit). For values of clf between these extremes, returns the + estimated miss rate of the cache, e.g. based on knowledge of the statistical + properties of the training data set.""" + + @property + @abc.abstractmethod + def cacheability(self) -> float: + """Summarized measure of the difficulty to cache a dataset that is independent of + cache size. A score of 0 means the dataset is very cacheable (e.g. high locality + between accesses), a score of 1 is very difficult to cache.""" + + +@dataclass +class CacheParams: + """Caching related fused params for an embedding table. Most of these are + passed to FBGEMM's Split TBE. These are useful for when uvm caching is used. + + Attributes: + algorithm (Optional[CacheAlgorithm]): cache algorithm to use. Options + include LRU and LFU. + load_factor (Optional[float]): cache load factor per table. This decides + the size of the cache space for the table, and is crucial for + performance when using uvm caching. + reserved_memory (Optional[float]): reserved memory for the cache. + precision (Optional[DataType]): precision of the cache. Ideally this + should be the same as the data type of the weights (aka table). + prefetch_pipeline (Optional[bool]): whether to prefetch pipeline is + used. + stats (Optional[CacheStatistics]): cache statistics which has table + related metadata. Used to create a better plan and tune the load + factor. + """ + + algorithm: Optional[CacheAlgorithm] = None + load_factor: Optional[float] = None + reserved_memory: Optional[float] = None + precision: Optional[DataType] = None + prefetch_pipeline: Optional[bool] = None + stats: Optional[CacheStatistics] = None + multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None + + def __hash__(self) -> int: + return hash( + ( + self.algorithm, + self.load_factor, + self.reserved_memory, + self.precision, + self.prefetch_pipeline, + self.multipass_prefetch_config, + ) + ) + + +@dataclass +class KeyValueParams: + """ + Params for SSD TBE aka SSDTableBatchedEmbeddingBags. + + Attributes: + ssd_storage_directory (Optional[str]): Directory for SSD. If we want directory + to be f"data00_nvidia{local_rank}", pass in "data00_nvidia@local_rank". 
+ ssd_rocksdb_write_buffer_size: Optional[int]: rocksdb write buffer size per tbe, + relavant to rocksdb compaction frequency + ssd_rocksdb_shards: Optional[int]: rocksdb shards number + gather_ssd_cache_stats: bool: whether enable ssd stats collection, std reporter and ods reporter + stats_reporter_config + report_interval: int: report interval in train iteration if gather_ssd_cache_stats is enabled + use_passed_in_path: bool: whether to use passed in path for rocksdb shard or default one on SSD mount path + l2_cache_size: Optional[int]: size in GB for l2 cache size per tbe + max_l1_cache_size: Optional[int]: size in MB for max allocated l1 cache size per tbe + enable_async_update: Optional[bool]: whether to enable async update for l2 cache + bulk_init_chunk_size: int: number of rows to insert into rocksdb in each chunk + lazy_bulk_init_enabled: bool: whether to enable lazy(async) bulk init for SSD TBE + + # Parameter Server (PS) Attributes + ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses + and ports. Example: (("::1", 2000), ("::1", 2001), ("::1", 2002)). + Reason for using tuple is we want it hashable. + ps_client_thread_num (int): Number of threads to use for PS client + ps_max_key_per_request (int): Maximum number of keys to send per request + ps_max_local_index_length(int): Maximum local index length + """ + + ssd_storage_directory: Optional[str] = None + ssd_rocksdb_write_buffer_size: Optional[int] = None + ssd_rocksdb_shards: Optional[int] = None + gather_ssd_cache_stats: Optional[bool] = None + stats_reporter_config: Optional[TBEStatsReporterConfig] = None + use_passed_in_path: bool = True + l2_cache_size: Optional[int] = None # size in GB + max_l1_cache_size: Optional[int] = None # size in MB + enable_async_update: Optional[bool] = None # enable L2 cache async update + bulk_init_chunk_size: Optional[int] = None # number of rows + lazy_bulk_init_enabled: Optional[bool] = None # enable lazy bulk init + + # Parameter Server (PS) Attributes + ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None + ps_client_thread_num: Optional[int] = None + ps_max_key_per_request: Optional[int] = None + ps_max_local_index_length: Optional[int] = None + + def __hash__(self) -> int: + return hash( + ( + self.ssd_storage_directory, + self.ssd_rocksdb_write_buffer_size, + self.ssd_rocksdb_shards, + # Parameter Server (PS) Attributes + self.ps_hosts, + self.ps_client_thread_num, + self.ps_max_key_per_request, + self.ps_max_local_index_length, + # tbe attributes + self.gather_ssd_cache_stats, + self.stats_reporter_config, + self.l2_cache_size, + self.max_l1_cache_size, + self.enable_async_update, + self.bulk_init_chunk_size, + self.lazy_bulk_init_enabled, + ) + ) + + @dataclass class ParameterSharding: """ - Describes the sharding of the parameter. + Describes (configures) the sharding of a parameter, which usually corresponds to a (feature) table. - sharding_type (str): how this parameter is sharded. See ShardingType for well-known - types. + sharding_type (str): how this parameter is sharded. See ShardingType for well-known types. compute_kernel (str): compute kernel to be used by this parameter. ranks (Optional[List[int]]): rank of each shard. sharding_spec (Optional[ShardingSpec]): list of ShardMetadata for each shard. + cache_params (Optional[CacheParams]): cache params for embedding lookup. + enforce_hbm (Optional[bool]): whether to use HBM. + stochastic_rounding (Optional[bool]): whether to use stochastic rounding. 
+ bounds_check_mode (Optional[BoundsCheckMode]): bounds check mode. + output_dtype (Optional[DataType]): output dtype. + key_value_params (Optional[KeyValueParams]): key value params for SSD TBE or PS. NOTE: ShardingType.TABLE_WISE - rank where this embedding is placed - ShardingType.COLUMN_WISE - rank where the embedding shards are placed, seen as - individual tables + ShardingType.COLUMN_WISE - rank where the embedding shards are placed, seen as individual tables ShardingType.TABLE_ROW_WISE - first rank when this embedding is placed ShardingType.ROW_WISE, ShardingType.DATA_PARALLEL - unused @@ -415,22 +718,61 @@ class ParameterSharding: compute_kernel: str ranks: Optional[List[int]] = None sharding_spec: Optional[ShardingSpec] = None + cache_params: Optional[CacheParams] = None + enforce_hbm: Optional[bool] = None + stochastic_rounding: Optional[bool] = None + bounds_check_mode: Optional[BoundsCheckMode] = None + output_dtype: Optional[DataType] = None + key_value_params: Optional[KeyValueParams] = None -ModuleShardingPlan = Dict[str, ParameterSharding] -""" -Map of ParameterSharding per parameter (usually a table). This describes the sharding plan for a torchrec module (e.g. `EmbeddingBagCollection`) -""" +class EmbeddingModuleShardingPlan(ModuleShardingPlan, Dict[str, ParameterSharding]): + """ + Map of ParameterSharding per parameter (usually a table). This describes the sharding plan for a torchrec module (e.g. `EmbeddingBagCollection`) + """ + + def __str__(self) -> str: + out = "" + param_table = [] + shard_table = [] + for param_name, param_sharding in self.items(): + param_table.append( + [ + param_name, + param_sharding.sharding_type, + param_sharding.compute_kernel, + param_sharding.ranks, + ] + ) + if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec): + shards = param_sharding.sharding_spec.shards + if shards is not None: + for shard in shards: + shard_table.append( + [ + param_name, + shard.shard_offsets, + shard.shard_sizes, + shard.placement, + ] + ) + out += "\n\n" + _tabulate( + param_table, ["param", "sharding type", "compute kernel", "ranks"] + ) + out += "\n\n" + _tabulate( + shard_table, ["param", "shard offsets", "shard sizes", "placement"] + ) + return out @dataclass class ShardingPlan: """ Representation of sharding plan. This uses the FQN of the larger wrapped model (i.e the model that is wrapped using `DistributedModelParallel`) - ModuleShardingPlan should be used when TorchRec composability is desired. + EmbeddingModuleShardingPlan should be used when TorchRec composability is desired. Attributes: - plan (Dict[str, ModuleShardingPlan]): dict keyed by module path of + plan (Dict[str, EmbeddingModuleShardingPlan]): dict keyed by module path of dict of parameter sharding specs keyed by parameter name. 
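+
+    Example (illustrative sketch; the module path, table name and compute kernel string
+    are hypothetical)::
+
+        plan = ShardingPlan(
+            plan={
+                "sparse.ebc": EmbeddingModuleShardingPlan(
+                    {
+                        "table_0": ParameterSharding(
+                            sharding_type=ShardingType.TABLE_WISE.value,
+                            compute_kernel="fused",
+                            ranks=[0],
+                        )
+                    }
+                )
+            }
+        )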
""" @@ -447,14 +789,20 @@ def get_plan_for_module(self, module_path: str) -> Optional[ModuleShardingPlan]: return self.plan.get(module_path, None) def __str__(self) -> str: - return str(self.plan) + out = "" + for i, (module_path, module_plan) in enumerate(self.plan.items()): + if i > 0: + out += "\n\n" + out += "module: " + module_path + out += str(module_plan) + return out ShardedModuleContext = Multistreamable class NullShardedModuleContext(Multistreamable): - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + def record_stream(self, stream: Optional[torch.Stream]) -> None: pass # pyre-ignore [2] @@ -473,10 +821,20 @@ def __init__( world_size: int, rank: int, pg: Optional[dist.ProcessGroup] = None, + output_dtensor: bool = False, ) -> None: self.world_size = world_size self.rank = rank self.process_group: Optional[dist.ProcessGroup] = pg + self.device_mesh: Optional[DeviceMesh] = ( + init_device_mesh( + device_type=_get_pg_default_device(pg).type, + mesh_shape=(dist.get_world_size(pg),), + ) + if pg + else None + ) + self.output_dtensor: bool = output_dtensor @classmethod def from_process_group(cls, pg: dist.ProcessGroup) -> "ShardingEnv": @@ -499,42 +857,89 @@ def from_local(cls, world_size: int, rank: int) -> "ShardingEnv": return cls(world_size, rank, None) -class FeatureShardingMixIn: - """ - Feature Sharding Interface to provide sharding-aware feature metadata. +class ShardingEnv2D(ShardingEnv): """ + Creates a sharding environment for 2D parallelism, enables usage of 2D parallelism in sharding + by seamlessly switching to the sub process group (sharding_pg) for a rank. This class is used + as source of truth for TorchRec to understand if we're in a 2D parallel environment. - def id_list_feature_names(self) -> List[str]: - raise NotImplementedError - - def id_score_list_feature_names(self) -> List[str]: - raise NotImplementedError + NOTE: + - global pg is part of `process_group` attribute to keep the same API as ShardingEnv, + some parts of TorchRec require the global pg to work appropriately (ie: `DDPWrapper` in `DistributedModelParallel`) + - `world_size` and `rank` attributes return values relative to `sharding_pg`, this is different + from default ShardingEnv returning values relative to `global_pg` - def id_list_feature_names_per_rank(self) -> List[List[str]]: - raise NotImplementedError + Attributes: + sharding_pg: The process group containing the ranks to shard on. + global_pg: The process group representing global ranks. + device_mesh: A 2D device mesh representing the topology of the global world size + on "replicate" and "shard" dimensions. + node_group_size (Optional[int]): The size of each node group. If not provided, it will be inferred + from env var `LOCAL_WORLD_SIZE`. + """ - def id_score_list_feature_names_per_rank(self) -> List[List[str]]: - raise NotImplementedError + def __init__( + self, + sharding_pg: dist.ProcessGroup, + global_pg: dist.ProcessGroup, + device_mesh: DeviceMesh, + node_group_size: Optional[int] = None, + use_inter_host_allreduce: bool = False, + ) -> None: + assert device_mesh.ndim == 2, "DeviceMesh must be two dimensional!" 
+ self.world_size: int = dist.get_world_size(sharding_pg) + self.global_world_size: int = dist.get_world_size(global_pg) + self.rank: int = dist.get_rank(sharding_pg) + self.global_rank: int = dist.get_rank(global_pg) + self.process_group: dist.ProcessGroup = ( + global_pg # to keep consistent naming between ShardingEnv and ShardingEnv2D + ) + self.sharding_pg: dist.ProcessGroup = sharding_pg + self.device_mesh: DeviceMesh = device_mesh + self.node_group_size: Optional[int] = node_group_size + self.output_dtensor: bool = True + self.use_inter_host_allreduce: bool = use_inter_host_allreduce + + def num_sharding_groups(self) -> int: + """ + Return number of sharding groups, also known as the number of times model parallel is replicated + """ + return self.global_world_size // self.world_size - def id_list_features_per_rank(self) -> List[int]: - raise NotImplementedError + def remap_rank(self, rank: int, sharding_type: ShardingType) -> int: + """ + Remap from current rank to the appropriate rank in a continuous [0, ..., world size] array for the given sharding type. - def id_score_list_features_per_rank(self) -> List[int]: - raise NotImplementedError + Args: + rank (int): rank to remap. + sharding_type (ShardingType): sharding type to remap to. + Returns: + int: remapped rank. + """ + if sharding_type in ( + ShardingType.COLUMN_WISE, + ShardingType.TABLE_WISE, + ShardingType.GRID_SHARD, + ): + return ( + rank % self.world_size + if self.use_inter_host_allreduce + else rank // self.num_sharding_groups() + ) + else: + raise ValueError( + f"Do not need 2D specific remapping logic for sharding type: {sharding_type}" + ) -class ModuleShardingMixIn: - """ - The interface to access a sharded module's sharding scheme. - """ - @property - def shardings(self) -> Dict[str, FeatureShardingMixIn]: - raise NotImplementedError +class NullShardingContext(Multistreamable): + def record_stream(self, stream: torch.Stream) -> None: + pass Out = TypeVar("Out") -CompIn = TypeVar("CompIn", bound=Multistreamable) +CompIn = TypeVar("CompIn") DistOut = TypeVar("DistOut") ShrdCtx = TypeVar("ShrdCtx", bound=Multistreamable) @@ -544,7 +949,6 @@ class ShardedModule( nn.Module, Generic[CompIn, DistOut, Out, ShrdCtx], ModuleNoCopyMixin, - ModuleShardingMixIn, ): """ All model-parallel modules implement this interface. @@ -558,6 +962,8 @@ class ShardedModule( from data-parallel to model parallel and vise-versa. """ + _FORCE_STATE_DICT_LOAD = True + @abc.abstractmethod def __init__( self, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None @@ -570,10 +976,6 @@ def __init__( qcomm_codecs_registry = {} self._qcomm_codecs_registry = qcomm_codecs_registry - self._input_dists: List[nn.Module] = [] - self._lookups: List[nn.Module] = [] - self._output_dists: List[nn.Module] = [] - @abc.abstractmethod def create_context(self) -> ShrdCtx: pass @@ -590,7 +992,7 @@ def input_dist( *input, # pyre-ignore[2] **kwargs, - ) -> Awaitable[CompIn]: + ) -> Awaitable[Awaitable[CompIn]]: pass @abc.abstractmethod @@ -625,42 +1027,41 @@ def forward(self, *input, **kwargs) -> LazyAwaitable[Out]: LazyAwaitable[Out]: awaitable of output from output dist. 
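The `ShardingEnv2D.remap_rank` logic shown above can be illustrated with plain arithmetic, no process groups needed. The group sizes below are assumptions for the example, not values from the diff.

```python
# Assume 8 global ranks split into 2 sharding groups of 4 ranks each.
global_world_size, world_size = 8, 4
num_sharding_groups = global_world_size // world_size  # 2, as in num_sharding_groups()

def remap(rank: int, use_inter_host_allreduce: bool) -> int:
    # Mirrors ShardingEnv2D.remap_rank for TABLE_WISE / COLUMN_WISE / GRID_SHARD.
    return rank % world_size if use_inter_host_allreduce else rank // num_sharding_groups

print([remap(r, False) for r in range(8)])  # [0, 0, 1, 1, 2, 2, 3, 3]
print([remap(r, True) for r in range(8)])   # [0, 1, 2, 3, 0, 1, 2, 3]
```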
""" ctx = self.create_context() - dist_input = self.input_dist(ctx, *input, **kwargs).wait() + dist_input = self.input_dist(ctx, *input, **kwargs).wait().wait() return self.compute_and_output_dist(ctx, dist_input) - def sparse_grad_parameter_names( - self, - destination: Optional[List[str]] = None, - prefix: str = "", - ) -> List[str]: - destination = [] if destination is None else destination - return destination - def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: for key, _ in self.named_parameters(prefix): yield key - def extra_repr(self) -> str: + @property + @abc.abstractmethod + def unsharded_module_type(self) -> Type[nn.Module]: """ - Pretty prints representation of the module's lookup modules, input_dists and output_dists + This property is added as part of dynamic sharding implementation. + + When resharding an already-sharded module wrapped in DMP, the unsharded + module type is needed to identify the proper sharder to reshard. This is + due to DistributedModelParellel (DMP) references module Sharders based + on the unsharded module type. """ + ... - def loop(key: str, modules: List[nn.Module]) -> List[str]: - child_lines = [] - if len(modules) > 0: - child_lines.append("(" + key + "): ") - for module in modules: - mod_str = repr(module) - mod_str = _addindent(mod_str, 2) - child_lines.append(mod_str) - return child_lines - rep = [] - rep.extend(loop("lookups", self._lookups)) - rep.extend(loop("_input_dists", self._input_dists)) - rep.extend(loop("_output_dists", self._output_dists)) +def get_tensor_size_bytes(t: torch.Tensor) -> int: + b: int = t.numel() * t.element_size() + if isinstance(t, UInt4Tensor): + assert ( + b % 2 == 0 + ), f"UInt4Tensor must have number of elements that is divisible by 2, got {t.numel()}" + b = b // 2 + elif isinstance(t, UInt2Tensor): + assert ( + b % 4 == 0 + ), f"UInt2Tensor must have number of elements that is divisible by 4, got {t.numel()}" + b = b // 4 - return "\n ".join(rep) + return b class ModuleSharder(abc.ABC, Generic[M]): @@ -678,14 +1079,18 @@ def __init__( torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}") self._qcomm_codecs_registry = qcomm_codecs_registry + # pyre-fixme[56]: Pyre doesn't yet support decorators with ParamSpec applied to + # generic functions. Consider using a context manager instead of a decorator, if + # possible. @abc.abstractclassmethod # pyre-ignore [3] def shard( self, module: M, - params: ModuleShardingPlan, + params: EmbeddingModuleShardingPlan, env: ShardingEnv, device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, ) -> ShardedModule[Any, Any, Any, Any]: """ Does the actual sharding. It will allocate parameters on the requested locations @@ -695,10 +1100,11 @@ def shard( Args: module (M): module to shard. - params (ModuleShardingPlan): dict of fully qualified parameter names + params (EmbeddingModuleShardingPlan): dict of fully qualified parameter names (module path + parameter name, '.'-separated) to its sharding spec. env (ShardingEnv): sharding environment that has the process group. - device (torch.device): compute device. + device (Optional[torch.device]): compute device. + path (Optional[str]): fully qualified name of the module. used for trace annotations in embedding modules Returns: ShardedModule[Any, Any, Any]: sharded module implementation. @@ -707,8 +1113,7 @@ def shard( @property @abc.abstractmethod - def module_type(self) -> Type[M]: - ... + def module_type(self) -> Type[M]: ... 
@property def qcomm_codecs_registry(self) -> Optional[Dict[str, QuantizedCommCodecs]]: @@ -743,12 +1148,14 @@ def storage_usage( compute kernel. """ - assert compute_device_type in {"cuda", "cpu"} - storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR} - return { - storage_map[compute_device_type].value: tensor.element_size() - * tensor.nelement() + assert compute_device_type in {"cuda", "cpu", "mtia"} + storage_map = { + "cuda": ParameterStorage.HBM, + "cpu": ParameterStorage.DDR, + # TODO: Update it later. Setting for MTIA is same as CPU's for now. + "mtia": ParameterStorage.DDR, } + return {storage_map[compute_device_type].value: get_tensor_size_bytes(tensor)} class ShardingPlanner(abc.ABC): @@ -792,3 +1199,43 @@ def collective_plan( ShardingPlan: the computed sharding plan. """ ... + + +def rank_device(device_type: str, rank: int) -> torch.device: + if device_type == "cpu": + return torch.device("cpu") + + return torch.device(f"{device_type}:{rank}") + + +class ObjectPoolShardingType(Enum): + """ + Sharding type for object pool + """ + + ROW_WISE = "row_wise" + # across nodes, state will be replicated. On lookup, all2alls will happen intranode. + # State is synced via update a2a being global internode. + REPLICATED_ROW_WISE = "replicated_row_wise" + + +@dataclass +class ObjectPoolShardingPlan(ModuleShardingPlan): + sharding_type: ObjectPoolShardingType + inference: bool = False + + +@dataclass +class ShardingBucketMetadata: + """ + If a table is row-wise sharded with bucketization, this class contains the bucket information for the table. + + Attributes: + num_buckets_per_shard (List[int]): Number of buckets in each shard of the table. + bucket_offsets_per_shard (List[int]): Index of the first bucket in each shard. + bucket_size (int): No. of rows in one bucket. + """ + + num_buckets_per_shard: List[int] + bucket_offsets_per_shard: List[int] + bucket_size: int diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py index 7dbcf1d9d..04a5afe0a 100644 --- a/torchrec/distributed/utils.py +++ b/torchrec/distributed/utils.py @@ -5,22 +5,39 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import copy +import logging +import pdb # noqa +import sys from collections import OrderedDict -from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union +from contextlib import AbstractContextManager, nullcontext +from dataclasses import asdict +from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union import torch from fbgemm_gpu.split_embedding_configs import EmbOptimType from torch import nn +from torch.autograd.profiler import record_function from torchrec import optim as trec_optim -from torchrec.distributed.types import ShardedModule +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.types import ( + DataType, + EmbeddingEvent, + ParameterSharding, + ShardedModule, + ShardingBucketMetadata, + ShardingType, + ShardMetadata, +) +from torchrec.modules.embedding_configs import data_type_to_sparse_type from torchrec.types import CopyMixIn - +logger: logging.Logger = logging.getLogger(__name__) _T = TypeVar("_T") - """ torch.package safe functions from pyre_extensions. 
However, pyre_extensions is not safe to use in code that will be torch.packaged, as it requires sys for @@ -64,7 +81,7 @@ def filter_state_dict( filtered_state_dict = OrderedDict() for key, value in state_dict.items(): - if key.startswith(name): + if key.startswith(name + "."): # + 1 to length is to remove the '.' after the key filtered_state_dict[key[len(name) + 1 :]] = value return filtered_state_dict @@ -228,7 +245,26 @@ def _copy_if_device_match(tensor: torch.Tensor) -> torch.Tensor: # if this is a sharded module, customize the copy if isinstance(copy_module, CopyMixIn): return copy_module.copy(to_device) - + copied_param = { + name: torch.nn.Parameter( + _copy_if_device_match(param.data), requires_grad=param.requires_grad + ) + for name, param in copy_module.named_parameters(recurse=False) + } + copied_buffer = { + name: _copy_if_device_match(buffer) + for name, buffer in copy_module.named_buffers(recurse=False) + } + for name, param in copied_param.items(): + m = copy_module + if "." in name: + continue + m.register_parameter(name, param) + for name, buffer in copied_buffer.items(): + m = copy_module + if "." in name: + continue + m.register_buffer(name, buffer) for child_name, child in copy_module.named_children(): if not any([isinstance(submodule, CopyMixIn) for submodule in child.modules()]): child_copy = child._apply(_copy_if_device_match) @@ -238,6 +274,33 @@ def _copy_if_device_match(tensor: torch.Tensor) -> torch.Tensor: return copy_module +class CopyableMixin(nn.Module): + """ + Allows copying of module to a target device. + + Example:: + + class MyModule(CopyableMixin): + ... + + Args: + device : torch.device to copy to + + Returns + nn.Module on new device + """ + + def copy( + self, + device: torch.device, + ) -> nn.Module: + return copy_to_device( + self, + current_device=torch.device("cpu"), + to_device=device, + ) + + def optimizer_type_to_emb_opt_type( optimizer_class: Type[torch.optim.Optimizer], ) -> Optional[EmbOptimType]: @@ -269,7 +332,6 @@ def merge_fused_params( fused_params: Optional[Dict[str, Any]] = None, param_fused_params: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: - """ Configure the fused_params including cache_precision if the value is not preset. @@ -294,3 +356,255 @@ def merge_fused_params( _fused_params = copy.deepcopy(fused_params) _fused_params.update(param_fused_params) return _fused_params + + +def add_params_from_parameter_sharding( + fused_params: Optional[Dict[str, Any]], + parameter_sharding: ParameterSharding, +) -> Dict[str, Any]: + """ + Extract params from parameter sharding and then add them to fused_params. + + Params from parameter sharding will override the ones in fused_params if they + exist already. + + Args: + fused_params (Optional[Dict[str, Any]]): the existing fused_params + parameter_sharding (ParameterSharding): the parameter sharding to use + + Returns: + [Dict[str, Any]]: the fused_params dictionary with params from parameter + sharding added. 
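The `filter_state_dict` fix above (matching on `name + "."` rather than the bare prefix) keeps keys from sibling modules whose names merely share a prefix from leaking in. A small sketch with made-up keys:

```python
from collections import OrderedDict

import torch
from torchrec.distributed.utils import filter_state_dict

sd = OrderedDict(
    [
        ("embedding.weight", torch.zeros(1)),
        ("embedding_bags.weight", torch.zeros(1)),  # must NOT match prefix "embedding"
    ]
)
filtered = filter_state_dict(sd, "embedding")
assert list(filtered.keys()) == ["weight"]  # only "embedding.weight", prefix stripped
```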
+ + """ + if fused_params is None: + fused_params = {} + + # update fused_params using params from parameter_sharding + # this will take precidence over the fused_params provided from sharders + if parameter_sharding.cache_params is not None: + cache_params = parameter_sharding.cache_params + if cache_params.algorithm is not None: + fused_params["cache_algorithm"] = cache_params.algorithm + if cache_params.load_factor is not None: + fused_params["cache_load_factor"] = cache_params.load_factor + if cache_params.reserved_memory is not None: + fused_params["cache_reserved_memory"] = cache_params.reserved_memory + if cache_params.precision is not None: + fused_params["cache_precision"] = cache_params.precision + if cache_params.prefetch_pipeline is not None: + fused_params["prefetch_pipeline"] = cache_params.prefetch_pipeline + if cache_params.multipass_prefetch_config is not None: + fused_params["multipass_prefetch_config"] = ( + cache_params.multipass_prefetch_config + ) + + if parameter_sharding.enforce_hbm is not None: + fused_params["enforce_hbm"] = parameter_sharding.enforce_hbm + + if parameter_sharding.stochastic_rounding is not None: + fused_params["stochastic_rounding"] = parameter_sharding.stochastic_rounding + + if parameter_sharding.bounds_check_mode is not None: + fused_params["bounds_check_mode"] = parameter_sharding.bounds_check_mode + + if parameter_sharding.output_dtype is not None: + fused_params["output_dtype"] = parameter_sharding.output_dtype + + if ( + parameter_sharding.compute_kernel in {EmbeddingComputeKernel.KEY_VALUE.value} + and parameter_sharding.key_value_params is not None + ): + kv_params = parameter_sharding.key_value_params + key_value_params_dict = asdict(kv_params) + key_value_params_dict = { + k: v for k, v in key_value_params_dict.items() if v is not None + } + if kv_params.stats_reporter_config: + key_value_params_dict["stats_reporter_config"] = ( + kv_params.stats_reporter_config + ) + fused_params.update(key_value_params_dict) + + # print warning if sharding_type is data_parallel or kernel is dense + if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value: + logger.warning( + f"Sharding Type is {parameter_sharding.sharding_type}, " + "caching params will be ignored" + ) + elif parameter_sharding.compute_kernel == EmbeddingComputeKernel.DENSE.value: + logger.warning( + f"Compute Kernel is {parameter_sharding.compute_kernel}, " + "caching params will be ignored" + ) + + # calling `get_additional_fused_params` for customized kernel + # it will be updated to the `fused_params` dict + if hasattr( + parameter_sharding, "get_additional_fused_params" + ) and parameter_sharding.compute_kernel in { + EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value + }: + # type: ignore[attr-defined] + fused_params.update(parameter_sharding.get_additional_fused_params()) + + return fused_params + + +def convert_to_fbgemm_types(fused_params: Dict[str, Any]) -> Dict[str, Any]: + if "cache_precision" in fused_params: + if isinstance(fused_params["cache_precision"], DataType): + fused_params["cache_precision"] = data_type_to_sparse_type( + fused_params["cache_precision"] + ) + + if "weights_precision" in fused_params: + if isinstance(fused_params["weights_precision"], DataType): + fused_params["weights_precision"] = data_type_to_sparse_type( + fused_params["weights_precision"] + ) + + if "output_dtype" in fused_params: + if isinstance(fused_params["output_dtype"], DataType): + fused_params["output_dtype"] = data_type_to_sparse_type( + fused_params["output_dtype"] + ) + 
+ return fused_params + + +def init_parameters(module: nn.Module, device: torch.device) -> None: + with torch.no_grad(): + has_meta_param = any(t.is_meta for t in module.parameters()) + not_on_target_device = any(t.device != device for t in module.parameters()) + if not_on_target_device: + module.to_empty(device=device) if has_meta_param else module.to(device) + + def maybe_reset_parameters(m: nn.Module) -> None: + if hasattr(m, "reset_parameters"): + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + m.reset_parameters() + + module.apply(maybe_reset_parameters) + + +def maybe_annotate_embedding_event( + event: EmbeddingEvent, + module_fqn: Optional[str], + sharding_type: Optional[str], + # pyre-fixme[24]: Generic type `AbstractContextManager` expects 2 type parameters, + # received 1. +) -> AbstractContextManager[None]: + if module_fqn and sharding_type: + annotation = f"[{event.value}]_[{module_fqn}]_[{sharding_type}]" + return record_function(annotation) + else: + return nullcontext() + + +class ForkedPdb(pdb.Pdb): + """A Pdb subclass that may be used from a forked multiprocessing child. + Useful in debugging multiprocessed code + + Example:: + + from torchrec.multiprocessing_utils import ForkedPdb + + if dist.get_rank() == 0: + ForkedPdb().set_trace() + dist.barrier() + """ + + # pyre-ignore + def interaction(self, *args, **kwargs) -> None: + _stdin = sys.stdin + try: + sys.stdin = open("/dev/stdin") # noqa + pdb.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + +def create_global_tensor_shape_stride_from_metadata( + parameter_sharding: ParameterSharding, devices_per_node: Optional[int] = None +) -> Tuple[torch.Size, Tuple[int, int]]: + """ + Create a global tensor shape and stride from shard metadata. + + Returns: + torch.Size: global tensor shape. + tuple: global tensor stride. + """ + size = None + if parameter_sharding.sharding_type == ShardingType.COLUMN_WISE.value: + # pyre-ignore[16] + row_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[0] + col_dim = 0 + for shard in parameter_sharding.sharding_spec.shards: + col_dim += shard.shard_sizes[1] + size = torch.Size([row_dim, col_dim]) + elif ( + parameter_sharding.sharding_type == ShardingType.ROW_WISE.value + or parameter_sharding.sharding_type == ShardingType.TABLE_ROW_WISE.value + ): + row_dim = 0 + col_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[1] + for shard in parameter_sharding.sharding_spec.shards: + row_dim += shard.shard_sizes[0] + size = torch.Size([row_dim, col_dim]) + elif parameter_sharding.sharding_type == ShardingType.TABLE_WISE.value: + size = torch.Size(parameter_sharding.sharding_spec.shards[0].shard_sizes) + elif parameter_sharding.sharding_type == ShardingType.GRID_SHARD.value: + # we need node group size to appropriately calculate global shape from shard + assert devices_per_node is not None + row_dim, col_dim = 0, 0 + num_cw_shards = len(parameter_sharding.sharding_spec.shards) // devices_per_node + for _ in range(num_cw_shards): + col_dim += parameter_sharding.sharding_spec.shards[0].shard_sizes[1] + for _ in range(devices_per_node): + row_dim += parameter_sharding.sharding_spec.shards[0].shard_sizes[0] + size = torch.Size([row_dim, col_dim]) + # pyre-ignore[7] + return size, (size[1], 1) if size else (torch.Size([0, 0]), (0, 1)) + + +def get_bucket_metadata_from_shard_metadata( + shards: List[ShardMetadata], + num_buckets: int, +) -> ShardingBucketMetadata: + """ + Calculate the bucket metadata from shard metadata. 
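A usage sketch (values are illustrative) for `add_params_from_parameter_sharding` above: per-table options carried on the sharding plan are folded into, and take precedence over, the sharder-level `fused_params`.

```python
from torchrec.distributed.types import CacheParams, ParameterSharding, ShardingType
from torchrec.distributed.utils import add_params_from_parameter_sharding

ps = ParameterSharding(
    sharding_type=ShardingType.TABLE_WISE.value,
    compute_kernel="fused_uvm_caching",
    ranks=[0],
    cache_params=CacheParams(load_factor=0.3),
    stochastic_rounding=False,
)
fused = add_params_from_parameter_sharding({"cache_load_factor": 0.1}, ps)
assert fused["cache_load_factor"] == 0.3  # plan-level value wins
assert fused["stochastic_rounding"] is False
```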
+ + This function assumes the table is to be row-wise sharded in equal sized buckets across bucket boundaries. + It computes the number of buckets per shard and the bucket size. + + Args: + shards (List[ShardMetadata]): Shard metadata for all shards of a table. + num_buckets (int): The number of buckets to divide the table into. + + Returns: + ShardingBucketMetadata: An object containing the number of buckets per shard and the bucket size. + """ + assert len(shards) > 0, "Shards cannot be empty" + table_size = shards[-1].shard_offsets[0] + shards[-1].shard_sizes[0] + assert ( + table_size % num_buckets == 0 + ), f"Table size '{table_size}' must be divisible by num_buckets '{num_buckets}'" + bucket_size = table_size // num_buckets + bucket_metadata: ShardingBucketMetadata = ShardingBucketMetadata( + num_buckets_per_shard=[], bucket_offsets_per_shard=[], bucket_size=bucket_size + ) + current_bucket_offset = 0 + for shard in shards: + assert ( + len(shard.shard_offsets) == 1 or shard.shard_offsets[1] == 0 + ), f"Shard shard_offsets[1] '{shard.shard_offsets[1]}' is not 0. Table should be only row-wise sharded for bucketization" + assert ( + shard.shard_sizes[0] % bucket_size == 0 + ), f"Shard size[0] '{shard.shard_sizes[0]}' is not divisible by bucket size '{bucket_size}'" + num_buckets_in_shard = shard.shard_sizes[0] // bucket_size + bucket_metadata.num_buckets_per_shard.append(num_buckets_in_shard) + bucket_metadata.bucket_offsets_per_shard.append(current_bucket_offset) + current_bucket_offset += num_buckets_in_shard + + return bucket_metadata diff --git a/torchrec/fx/__init__.py b/torchrec/fx/__init__.py index 2b6059fb8..7f86c74e1 100644 --- a/torchrec/fx/__init__.py +++ b/torchrec/fx/__init__.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + """Torchrec Tracer Custom FX tracer for torchrec diff --git a/torchrec/fx/tests/test_tracer.py b/torchrec/fx/tests/test_tracer.py index b0a0bdb9d..cf6f49049 100644 --- a/torchrec/fx/tests/test_tracer.py +++ b/torchrec/fx/tests/test_tracer.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest import torch diff --git a/torchrec/fx/tracer.py b/torchrec/fx/tracer.py index 51537675a..c219cea4a 100644 --- a/torchrec/fx/tracer.py +++ b/torchrec/fx/tracer.py @@ -5,13 +5,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
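A worked example (shard sizes assumed, not taken from the diff) for `get_bucket_metadata_from_shard_metadata` above: a 1000-row table split into two row-wise shards with 10 buckets yields a bucket size of 100 and 5 buckets per shard.

```python
from torchrec.distributed.types import ShardMetadata
from torchrec.distributed.utils import get_bucket_metadata_from_shard_metadata

shards = [
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[500, 64], placement="rank:0/cuda:0"),
    ShardMetadata(shard_offsets=[500, 0], shard_sizes=[500, 64], placement="rank:1/cuda:1"),
]
meta = get_bucket_metadata_from_shard_metadata(shards, num_buckets=10)
assert meta.bucket_size == 100
assert meta.num_buckets_per_shard == [5, 5]
assert meta.bucket_offsets_per_shard == [0, 5]
```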
-from typing import Any, Callable, Dict, Optional, Union +# pyre-strict + +import typing +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch.fx._compatibility import compatibility from torch.fx.graph import Graph from torch.fx.node import Argument -from torchrec.distributed.types import NoWait +from torchrec.distributed.types import LazyAwaitable, NoWait +from torchrec.fx.utils import dmp_fx_trace_forward _is_fx_tracing_flag = False @@ -31,8 +35,17 @@ class Tracer(torch.fx.Tracer): TorchScript if needed """ - def __init__(self) -> None: + def __init__(self, leaf_modules: Optional[List[str]] = None) -> None: super().__init__() + self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else [] + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + """ + Override FX definition to include quantized embedding bags + """ + if type(m).__name__ in self._leaf_modules: + return True + return super().is_leaf_module(m, module_qualified_name) @compatibility(is_backward_compatible=True) def trace( @@ -44,11 +57,37 @@ def trace( global _is_fx_tracing_flag old_is_fx_tracing_flag = _is_fx_tracing_flag _is_fx_tracing_flag = True + try: - graph = super().trace( - root, - concrete_args, - ) + # TODO(ivankobzarev): support DMP not only on the root level + from torchrec.distributed.model_parallel import DistributedModelParallel + + if isinstance(root, DistributedModelParallel): + # In the case where the module is wrapped in DMP, you need to replace DMP's forward + # call with a new signature, one with explicit args, because fx can't handle variable args. + # Furthermore, we need to provide the `fn_root` argument because when tracing a function, + # fx uses an empty module as the root (unless one is explicitly provided), which leads to + # issues with path_of_module and named_buffers. + + # TODO(shababayub): This can be removed if we either stop supporting dmp wrapping + # for fx trace or strip dmp name in named_modules path (much like named_buffers). + if isinstance(root, torch.nn.Module): + for prefix, module in root.named_modules(): + # TODO(T140754678): Remove this workaround to _fx_path + module._fx_path = prefix + + dmp = root + graph = super().trace( + root=dmp_fx_trace_forward(dmp, self), + concrete_args=concrete_args, + ) + self.root._dmp_wrapped_module = dmp._dmp_wrapped_module + else: + # Unwrapped dmp modules and composibility api will enter here. + graph = super().trace( + root=root, + concrete_args=concrete_args, + ) finally: _is_fx_tracing_flag = old_is_fx_tracing_flag return graph @@ -71,17 +110,42 @@ def create_arg(self, a: Any) -> Argument: return self.create_node( "call_function", target=NoWait, - args=self.create_arg((a._obj,)), + # Ugh. This line seems to be triggering some bug in pyre - so + # cast instead of fixme. + args=typing.cast( + Tuple[torch.fx.node.Argument, ...], self.create_arg((a._obj,)) + ), kwargs={}, type_expr=NoWait, ) + + # Not equivalent to when LazyAwaitable.wait() is called in eager. Here can be called earlier, as attr was not requested and this is not guranteed to be torch function + # TODO(ivankobzarev): support equivalent timing of LazyAwaitable + if isinstance(a, LazyAwaitable): + if a._result is None: + a._result = a.wait() + return super().create_arg(a._result) + return super().create_arg(a) + def path_of_module(self, mod: torch.nn.Module) -> str: + """ + Allows trace-ability of non registered modules. 
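A small sketch of the `leaf_modules` hook added to `Tracer` above: any module whose class name appears in the list is kept as a single opaque `call_module` node instead of being traced through. The toy module names here are made up.

```python
import torch
from torchrec.fx.tracer import Tracer

class Opaque(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.opaque = Opaque()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.opaque(x) * 2

m = Model()
tracer = Tracer(leaf_modules=["Opaque"])
gm = torch.fx.GraphModule(m, tracer.trace(m))
print(gm.graph)  # "opaque" appears as one call_module node; its body is not traced
```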
This is typically used for Table Batched Embeddings + made to look like nn.EmbeddingBags + """ + + if hasattr(mod, "_fx_path"): + # pyre-fixme[7]: Expected `str` but got `Union[Tensor, Module]`. + return mod._fx_path + else: + return super().path_of_module(mod) + def symbolic_trace( # pyre-ignore[24] root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None, + leaf_modules: Optional[List[str]] = None, ) -> torch.fx.GraphModule: """ Symbolic tracing API @@ -99,6 +163,8 @@ def symbolic_trace( Returns: GraphModule: a Module created from the recorded operations from ``root``. """ - tracer = Tracer() + tracer = Tracer(leaf_modules) graph = tracer.trace(root, concrete_args) - return torch.fx.GraphModule(root, graph) + # Ugh. This line seems to be triggering some bug in pyre - so cast instead + # of fixme. + return torch.fx.GraphModule(typing.cast(torch.nn.Module, root), graph) diff --git a/torchrec/fx/utils.py b/torchrec/fx/utils.py new file mode 100644 index 000000000..55d206e33 --- /dev/null +++ b/torchrec/fx/utils.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import inspect +from typing import Any, Dict, List, Set + +import torch + +from torch.fx._symbolic_trace import is_fx_tracing + +# Not importing DistributedModelParallel here to avoid circular dependencies as DMP depends on torchrec.fx.tracer +# def dmp_fx_trace_forward(dmp: DistributedModelParallel) + + +# pyre-ignore +def fake_range(): + # pyre-fixme[16]: Module `_C` has no attribute `_jit_tree_views`. + return torch._C._jit_tree_views.SourceRangeFactory("", None, 0, 0).make_raw_range( + 0, 1 + ) + + +# pyre-ignore +def dmp_fx_trace_forward( # noqa: C901 + # pyre-ignore + dmp, + tracer: torch.fx.Tracer, +): + func = dmp._dmp_wrapped_module.forward + sign: inspect.Signature = inspect.signature(func) + + module_to_type_str: Dict[str, Set[str]] = {} + + def add_if_missing(module: str, type_str: str) -> None: + if module not in module_to_type_str: + _set = set() + _set.add(type_str) + module_to_type_str[module] = _set + else: + s = module_to_type_str[module] + if type_str not in s: + s.add(type_str) + + def torch_no_import(t: torch.Type) -> bool: + return isinstance( + t, (torch.FloatType, torch.IntType, torch.ComplexType, torch.StringType) + ) + + def torch_typing(t: torch.Type) -> bool: + return isinstance( + t, + ( + torch.TupleType, + torch.ListType, + torch.DictType, + torch.OptionalType, + torch.AnyType, + ), + ) + + exec_imports = [] + args_call = ", ".join([f"{p.name}" for p in sign.parameters.values()]) + + types = [] + try: + args_decls: List[str] = [] + for p in sign.parameters.values(): + pann = p.annotation + + ptype = torch.jit.annotations.try_ann_to_type(pann, fake_range()) + types.append(ptype) + args_decls.append(f"{p.name}: {ptype}") + + while len(types) > 0: + t = types.pop() + if torch_no_import(t): + continue + + t_base_name = f"{t}".split("[")[0] + if torch_typing(t): + add_if_missing("typing", t_base_name) + else: + if hasattr(t, "__module__") and not torch_no_import(t): + m = t.__module__ + add_if_missing(f"{m}", f"{t}".split("[")[0]) + + if hasattr(t, "containedTypes"): + contained_types = getattr(t, "containedTypes", None)() + for ctype in contained_types: + types.append(ctype) + + if hasattr(t, "getElementType"): + el_type = getattr(t, "getElementType", 
None)() + + args_decl = ", ".join(args_decls) + + for m, s in module_to_type_str.items(): + ts = ", ".join(s) + exec_imports.append(f"from {m} import {ts}") + except Exception as e: + print(f"Exception:{e}") + # Catching here if source is not available to proceed hoping that jit will infer correct types without annotations. + # Often it fails here when can not access to dataclass generated __init__ + args_decl = args_call + + exec_def_fn_name = "__fx_forward" + exec_dmp_wrapper_local_name = "_dmp_wrapped_module_local" + _dmp_wrapped_module_local = dmp + locals_dict = locals() + exec_def = f"def {exec_def_fn_name}({args_decl}):\n return {exec_dmp_wrapper_local_name}({args_call})" + + exec_imports_str = "\n".join(exec_imports) + pycode = f"{exec_imports_str}\n{exec_def}" + + exec(pycode, locals_dict) # noqa: P204 Allow use of exec + + wrapper = locals_dict[exec_def_fn_name] + wrapper.__signature__ = sign + + return wrapper + + +@torch.fx.wrap +# pyre-ignore +def _fx_marker(s: str, any_proxy_unused: Any) -> None: + pass + + +# pyre-ignore +def fx_marker(s: str, any_proxy_unused: Any) -> None: + if is_fx_tracing(): + _fx_marker(s, any_proxy_unused) + + +def is_marker_node(node: torch.fx.Node, marker_name: str) -> bool: + # bool() syntax for pyre + return bool( + node.op == "call_function" + and node.target == _fx_marker + and isinstance(node.args[0], str) + and node.args[0] == marker_name + ) + + +@torch.jit.ignore +def assert_fx_safe(condition: bool, message: str) -> None: + if not is_fx_tracing(): + assert condition, message diff --git a/torchrec/inference/CMakeLists.txt b/torchrec/inference/CMakeLists.txt index 794bb1490..3b3ac2a8d 100644 --- a/torchrec/inference/CMakeLists.txt +++ b/torchrec/inference/CMakeLists.txt @@ -4,110 +4,60 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -cmake_minimum_required(VERSION 3.13 FATAL_ERROR) -project(inference) +cmake_minimum_required(VERSION 3.8) -# This step is crucial to ensure that the -# _REFLECTION, _GRPC_GRPCPP and _PROTOBUF_LIBPROTOBUF variables are set. -# e.g. 
~/gprc/examples/cpp/cmake/common.cmake -include(${GRPC_COMMON_CMAKE_PATH}/common.cmake) +project(inference C CXX) +include(/home/paulzhan/grpc/examples/cpp/cmake/common.cmake) -# abi and other flags -if(DEFINED GLIBCXX_USE_CXX11_ABI) - if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") - set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=1") - endif() -endif() +# Proto file +get_filename_component(hw_proto "/home/paulzhan/torchrec/torchrec/inference/protos/predictor.proto" ABSOLUTE) +get_filename_component(hw_proto_path "${hw_proto}" PATH) -# keep it static for now since folly-shared version is broken -# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") -# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") +# Generated sources +set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/predictor.pb.cc") +set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/predictor.pb.h") +set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/predictor.grpc.pb.cc") +set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/predictor.grpc.pb.h") +add_custom_command( + OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${hw_proto_path}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${hw_proto}" + DEPENDS "${hw_proto}") -# dependencies -find_package(Boost REQUIRED) -find_package(Torch REQUIRED) -find_package(folly REQUIRED) -find_package(gflags REQUIRED) - -include_directories(${Torch_INCLUDE_DIRS}) -include_directories(${folly_INCLUDE_DIRS}) -include_directories(${PYTORCH_FMT_INCLUDE_PATH}) +# Include generated *.pb.h files +include_directories("${CMAKE_CURRENT_BINARY_DIR}") set(CMAKE_CXX_STANDARD 17) -# torch deploy library -add_library(torch_deploy_internal STATIC - ${DEPLOY_INTERPRETER_PATH}/libtorch_deployinterpreter.o - ${DEPLOY_SRC_PATH}/deploy.cpp - ${DEPLOY_SRC_PATH}/loader.cpp - ${DEPLOY_SRC_PATH}/path_environment.cpp - ${DEPLOY_SRC_PATH}/elf_file.cpp) - -# For python builtins. caffe2_interface_library properly -# makes use of the --whole-archive option. -target_link_libraries(torch_deploy_internal PRIVATE - crypt pthread dl util m z ffi lzma readline nsl ncursesw panelw -) -target_link_libraries(torch_deploy_internal - PUBLIC shm torch ${PYTORCH_LIB_FMT} -) -caffe2_interface_library(torch_deploy_internal torch_deploy) - -# inference library - -# for our own header files -include_directories(include/) -include_directories(gen/) - -# define our library target -add_library(inference STATIC - src/Batching.cpp - src/BatchingQueue.cpp - src/GPUExecutor.cpp - src/ResultSplit.cpp - src/Exception.cpp - src/ResourceManager.cpp -) - -# -rdynamic is needed to link against the static library -target_link_libraries(inference "-Wl,--no-as-needed -rdynamic" - dl torch_deploy "${TORCH_LIBRARIES}" ${FBGEMM_LIB} ${FOLLY_LIBRARIES} -) - -# for generated protobuf - -# grpc headers. e.g. 
~/.local/include -include_directories(${GRPC_HEADER_INCLUDE_PATH}) - -set(pred_grpc_srcs "gen/torchrec/inference/predictor.grpc.pb.cc") -set(pred_grpc_hdrs "gen/torchrec/inference/predictor.grpc.pb.h") -set(pred_proto_srcs "gen/torchrec/inference/predictor.pb.cc") -set(pred_proto_hdrs "gen/torchrec/inference/predictor.pb.h") +# Torch + FBGEMM +find_package(Torch REQUIRED) +add_library( fbgemm SHARED IMPORTED GLOBAL ) +set_target_properties(fbgemm PROPERTIES IMPORTED_LOCATION ${FBGEMM_LIB}) -add_library(pred_grpc_proto STATIC - ${pred_grpc_srcs} - ${pred_grpc_hdrs} - ${pred_proto_srcs} - ${pred_proto_hdrs}) -target_link_libraries(pred_grpc_proto +add_library(hw_grpc_proto STATIC + ${hw_grpc_srcs} + ${hw_grpc_hdrs} + ${hw_proto_srcs} + ${hw_proto_hdrs}) +target_link_libraries(hw_grpc_proto ${_REFLECTION} ${_GRPC_GRPCPP} ${_PROTOBUF_LIBPROTOBUF}) -# server +# Targets greeter_[async_](client|server) add_executable(server server.cpp) target_link_libraries(server - inference - torch_deploy - pred_grpc_proto "${TORCH_LIBRARIES}" - ${FOLLY_LIBRARIES} - ${PYTORCH_LIB_FMT} - ${FBGEMM_LIB} + fbgemm + hw_grpc_proto ${_REFLECTION} ${_GRPC_GRPCPP} - ${_PROTOBUF_LIBPROTOBUF}) + ${_PROTOBUF_LIBPROTOBUF} +) diff --git a/torchrec/inference/README.md b/torchrec/inference/README.md index 8958fc72d..0a8e1cc17 100644 --- a/torchrec/inference/README.md +++ b/torchrec/inference/README.md @@ -1,116 +1,65 @@ # TorchRec Inference Library (**Experimental** Release) ## Overview -TorchRec Inference is a C++ library that supports **multi-gpu inference**. The Torchrec library is used to shard models written and packaged in Python via [torch.package](https://pytorch.org/docs/stable/package.html) (an alternative to TorchScript). The [torch.deploy](https://pytorch.org/docs/stable/deploy.html) library is used to serve inference from C++ by launching multiple Python interpreters carrying the packaged model, thus subverting the GIL. +TorchRec Inference is a C++ library that supports **gpu inference**. Previously, the TorchRec inference library was authored with torch.package and torch.deploy, which are old and deprecated. All the previous files live under the directory inference_legacy for reference. -Follow the instructions below to package a DLRM model in Python, run a C++ inference server with the model on a GPU and send requests to said server via a python client. +TorchRec inference was reauthored with simplicity in mind, while also reflecting the current production environment for RecSys models, namely torch.fx for graph capturing/tracing and TorchScript for model inference in a C++ environment. The inference solution here is meant to serve as a simple reference and example, not a fully scaled out solution for production use cases. The current solution demonstrates converting the DLRM model in Python to TorchScript, running a C++ inference server with the model on a GPU, and sending requests to said server via a python client. -## Example +## Requirements -C++ 17 is a requirement. +C++ 17 is a requirement. GCC version has to be >= 9, with initial testing done on GCC 9.
### **1. Install Dependencies** - -Follow the instructions at: https://github.com/pytorch/pytorch/blob/master/docs/source/deploy.rst to ensure torch::deploy -is working in your environment. Use the Dockerfile in the docker directory to install all dependencies. Run it via: - +1. [GRPC for C++][https://grpc.io/docs/languages/cpp/quickstart/] needs to be installed, with the resulting installation directory being `$HOME/.local` +2. Ensure that **the protobuf compiler (protoc) binary being used is from the GRPC installation above**. The protoc binary will live in `$HOME/.local/bin`, which may not match with the system protoc binary, can check with `which protoc`. +3. Install PyTorch, FBGEMM, and TorchRec (ideally in a virtual environment): ``` -sudo nvidia-docker build -t torchrec . -sudo nvidia-docker run -it torchrec:latest +pip install torch --index-url https://download.pytorch.org/whl/cu121 +pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/cu121 +pip install torchmetrics==1.0.3 +pip install torchrec --index-url https://download.pytorch.org/whl/cu121 ``` + ### **2. Set variables** Replace these variables with the relevant paths in your system. Check `CMakeLists.txt` and `server.cpp` to see how they're used throughout the build and runtime. ``` -# provide the cmake prefix path of pytorch, folly, and fmt. -# fmt and boost are pulled from folly's installation in this example. -export FOLLY_CMAKE_DIR="~/folly-build/installed/folly/lib/cmake/folly" -export FMT_CMAKE_DIR="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/cmake/fmt" -export BOOST_CMAKE_DIR="~/folly-build/installed/boost-4M2ZnvEM4UWTqpsEJRQTB4oejmX3LmgYC9pcBiuVlmA/lib/cmake/Boost-1.78.0" - -# provide fmt from pytorch for torch deploy -export PYTORCH_FMT_INCLUDE_PATH="~/pytorch/third_party/fmt/include/" -export PYTORCH_LIB_FMT="~/pytorch/build/lib/libfmt.a" - -# provide necessary info to link to torch deploy -export DEPLOY_INTERPRETER_PATH="/pytorch/build/torch/csrc/deploy" -export DEPLOY_SRC_PATH="~/pytorch/torch/csrc/deploy" - -# provide common.cmake from grpc/examples, makes linking to grpc easier -export GRPC_COMMON_CMAKE_PATH="~/grpc/examples/cpp/cmake" -export GRPC_HEADER_INCLUDE_PATH="~/.local/include/" - -# provide libfbgemm_gpu_py.so to enable fbgemm_gpu c++ operators -export FBGEMM_LIB="~/anaconda3/envs/inference/lib/python3.8/site-packages/fbgemm_gpu-0.1.0-py3.8-linux-x86_64.egg/fbgemm_gpu/libfbgemm_gpu_py.so" - -# provide path to python packages for torch deploy runtime -export PYTHON_PACKAGES_PATH="~/anaconda3/envs/inference/lib/python3.8/site-packages/" -``` - -Update `$LD_LIBRARY_PATH` and `$LIBRARY_PATH` to enable linker to locate libraries. 
+# provide fbgemm_gpu_py.so to enable fbgemm_gpu c++ operators +find $HOME -name fbgemm_gpu_py.so +# Use path from correct virtual environment above and set environment variable $FBGEMM_LIB to it +export FBGEMM_LIB="" ``` -# double-conversion, fmt and gflags are pulled from folly's installation in this example -export DOUBLE_CONVERSION_LIB_PATH="~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib" -export FMT_LIB_PATH="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/" -export GFLAGS_LIB_PATH="~/folly-build/installed/gflags-KheHQBqQ3_iL3yJBFwWe5M5f8Syd-LKAX352cxkhQMc/lib" -export PYTORCH_LIB_PATH="~/pytorch/build/lib/" -export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$DOUBLE_CONVERSION_LIB_PATH:$FMT_LIB_PATH:$GFLAGS_LIB_PATH:$PYTORCH_LIB_PATH" -export LIBRARY_PATH="$PYTORCH_LIB_PATH" -``` +### **3. Generate TorchScripted DLRM model** -### **3. Package DLRM model** - -The `PredictFactoryPackager` class in `model_packager.py` can be used to implement your own packager class. Implement -`set_extern_modules` to specify the dependencies of your predict module that should be accessed from the system and -implement `set_mocked_modules` to specify dependencies that should be mocked (necessary to import but not use). Read -more about extern and mock modules in the `torch.package` documentation: https://pytorch.org/docs/stable/package.html. - -`/torchrec/examples/inference/dlrm_package.py` provides an example of packaging a module for inference (`/torchrec/examples/inference/dlrm_predict.py`). -`DLRMPredictModule` is packaged for inference in the following example. +Here, we generate the DLRM model in Torchscript and save it for model loading later on. ``` git clone https://github.com/pytorch/torchrec.git -cd ~/torchrec/examples/inference/ -python dlrm_packager.py --output_path /tmp/model_package.zip +cd ~/torchrec/torchrec/inference/ +python3 dlrm_packager.py --output_path /tmp/model.pt ``` - ### **4. Build inference library and example server** -Generate protobuf C++ and Python code from protobuf +Generate Python code from protobuf for client and build the server. ``` -cd ~/torchrec/inference/ -mkdir -p gen/torchrec/inference - -# C++ (server) -protoc -I protos/ --grpc_out=gen/torchrec/inference --plugin=protoc-gen-grpc=/home/shabab/.local/bin/grpc_cpp_plugin protos/predictor.proto - -protoc -I protos/ --cpp_out=gen/torchrec/inference protos/predictor.proto - - # Python (client) -python -m grpc_tools.protoc -I protos --python_out=gen/torchrec/inference --grpc_python_out=gen/torchrec/inference protos/predictor.proto +python -m grpc_tools.protoc -I protos --python_out=. --grpc_python_out=. protos/predictor.proto ``` -Build inference library and example server +Build server and C++ protobufs ``` -cmake -S . -B build/ -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)');$FOLLY_CMAKE_DIR;$BOOST_CMAKE_DIR;$BOOST_CMAKE_DIR;" --DPYTORCH_FMT_INCLUDE_PATH="$PYTORCH_FMT_INCLUDE_PATH" \ --DPYTORCH_LIB_FMT="$PYTORCH_LIB_FMT" \ --DDEPLOY_INTERPRETER_PATH="$DEPLOY_INTERPRETER_PATH" \ --DDEPLOY_SRC_PATH="$DEPLOY_SRC_PATH" \ --DGRPC_COMMON_CMAKE_PATH="$GRPC_COMMON_CMAKE_PATH" \ -DGRPC_HEADER_INCLUDE_PATH="$GRPC_HEADER_INCLUDE_PATH" \ --DFBGEMM_LIB="$FBGEMM_LIB" +cmake -S . -B build/ -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)');" -DFBGEMM_LIB="$FBGEMM_LIB" cd build make -j @@ -119,28 +68,20 @@ make -j ### **5. Run server and client** -Run server. 
Update `CUDA_VISABLE_DEVICES` depending on the world size. +Start the server, loading in the model saved previously ``` -CUDA_VISABLE_DEVICES="0" ./server --package_path="/tmp/model_package.zip" --python_packages_path $PYTHON_PACKAGES_PATH +./server /tmp/model.pt ``` **output** -In the logs, a plan should be outputted by the Torchrec planner: - -``` -INFO:.torchrec.distributed.planner.stats:# --- Planner Statistics --- # -INFO:.torchrec.distributed.planner.stats:# --- Evalulated 1 proposal(s), found 1 possible plan(s) --- # -INFO:.torchrec.distributed.planner.stats:# ----------------------------------------------------------------------------------------------- # -INFO:.torchrec.distributed.planner.stats:# Rank HBM (GB) DDR (GB) Perf (ms) Input (MB) Output (MB) Shards # -INFO:.torchrec.distributed.planner.stats:# ------ ---------- ---------- ----------- ------------ ------------- -------- # -INFO:.torchrec.distributed.planner.stats:# 0 0.2 (1%) 0.0 (0%) 0.08 0.1 1.02 TW: 26 # -INFO:.torchrec.distributed.planner.stats:# # -INFO:.torchrec.distributed.planner.stats:# Input: MB/iteration, Output: MB/iteration, Shards: number of tables # -INFO:.torchrec.distributed.planner.stats:# HBM: est. peak memory usage for shards - parameter, comms, optimizer, and gradients # -INFO:.torchrec.distributed.planner.stats:# # -INFO:.torchrec.distributed.planner.stats:# Compute Kernels: # -INFO:.torchrec.distributed.planner.stats:# quant: 26 # +In the logs, you should see: + +``` +Loading model... +Sanity Check with dummy inputs + Model Forward Completed, Output: 0.489247 +Server listening on 0.0.0.0:50051 ```` `nvidia-smi` output should also show allocation of the model onto the gpu: @@ -155,7 +96,7 @@ INFO:.torchrec.distributed.planner.stats:# quant: 26 +-----------------------------------------------------------------------------+ ``` -Make a request to the server via the client: +In another terminal instance, make a request to the server via the client: ``` python client.py @@ -166,74 +107,3 @@ python client.py ``` Response: [0.13199582695960999, -0.1048036441206932, -0.06022112816572189, -0.08765199035406113, -0.12735335528850555, -0.1004377081990242, 0.05509107559919357, -0.10504599660634995, 0.1350800096988678, -0.09468207508325577, 0.24013587832450867, -0.09682435542345047, 0.0025023818016052246, -0.09786031395196915, -0.26396819949150085, -0.09670191258192062, 0.2691854238510132, -0.10246685892343521, -0.2019493579864502, -0.09904996305704117, 0.3894067406654358, ...] ``` - -
- -## Planned work - -- Provide benchmarks for torch deploy vs TorchScript and cpu, single gpu and multi-gpu inference -- In-code documentation -- Simplify installation process - -
- -## Potential issues and solutions - -Skip this section if you had no issues with installation or running the example. - -**Missing header files during pytorch installation** - -If your environment is missing a speicfic set of header files such as `nvml.h` and `cuda_profiler_api.h`, the pytorch installation will fail with error messages similar to the code snippet below: - -``` -~/nvml_lib.h:13:10: fatal error: nvml.h: No such file or directory - #include - ^~~~~~~~ -compilation terminated. -[80/2643] Building CXX object third_party/ideep/mkl-dnn/third_party/oneDNN/src/cpu/CMakeFiles/dnnl_cpu.dir/cpu_convolution_list.cpp.o -ninja: build stopped: subcommand failed. -``` - -To get these header files, install `cudatoolkit-dev`: -``` -conda install -c conda-forge cudatoolkit-dev -``` - -Re-run the installation after this. - -**libdouble-conversion missing** -``` -~/torchrec/torchrec/inference/build$ ./example -./example: error while loading shared libraries: libdouble-conversion.so.3: cannot open shared object file: No such file or directory -``` - -If this issue persists even after adding double-conversion's path to $LD_LIBRARY_PATH (step 2) then solve by creating a symlink to `libdouble-conversion.so.3` with folly's installation of double-conversion: - -``` -sudo ln -s ~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib/libdouble-conversion.so.3.1.4 \ -libdouble-conversion.so.3 -``` - -**Two installations of glog** -``` -~/torchrec/torchrec/inference/build$ ./example -ERROR: flag 'logtostderr' was defined more than once (in files '/home/shabab/glog/src/logging.cc' and -'/home/shabab/folly-build/extracted/glog-v0.4.0.tar.gz/glog-0.4.0/src/logging.cc'). -``` -The above issue, along with a host of others during building, can potentially occur if libinference is pointing to two different versions of glog (if one was -previously installed in your system). You can find this out by running `ldd` on your libinference shared object within the build path. The issue can be solved by using the glog version provided by folly. - -To use the glog version provided by folly, add the glog install path (in your folly-build directory) to your LD_LIBRARY_PATH much like step 2. - -**Undefined symbols with std::string or cxx11** - -If you get undefined symbol errors and the errors mention `std::string` or `cxx11`, it's likely -that your dependencies were compiled with different ABI values. Re-compile your dependencies -and ensure they all have the same value for `_GLIBCXX_USE_CXX11_ABI` in their build. - -The ABI value of pytorch can be checked via: - -``` -import torch -torch._C._GLIBCXX_USE_CXX11_ABI -``` diff --git a/torchrec/inference/__init__.py b/torchrec/inference/__init__.py index 670f2af78..a0ce8680a 100644 --- a/torchrec/inference/__init__.py +++ b/torchrec/inference/__init__.py @@ -5,21 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Torchrec Inference - -Torchrec inference provides a Torch.Deploy based library for GPU inference. - -These includes: - - Model packaging in Python - - `PredictModule` and `PredictFactory` are the contracts between the Python model authoring and the C++ model serving. - - `PredictFactoryPackager` can be used to package a PredictFactory class using torch.package. - - Model serving in C++ - - `BatchingQueue` is a generalized config-based request tensor batching implementation. 
- - `GPUExecutor` handles the forward call into the inference model inside Torch.Deploy. - -We implemented an example of how to use this library with the TorchRec DLRM model. - - `examples/dlrm/inference/dlrm_packager.py`: this demonstrates how to export the DLRM model as a torch.package. - - `examples/dlrm/inference/dlrm_predict.py`: this shows how to use `PredictModule` and `PredictFactory` based on an existing model. -""" +# pyre-strict from . import model_packager, modules # noqa # noqa diff --git a/torchrec/inference/client.py b/torchrec/inference/client.py index a9126f96c..50bdc09ea 100644 --- a/torchrec/inference/client.py +++ b/torchrec/inference/client.py @@ -7,12 +7,10 @@ import argparse import logging -import os import grpc +import predictor_pb2, predictor_pb2_grpc import torch -from gen.torchrec.inference import predictor_pb2, predictor_pb2_grpc -from torch.utils.data import DataLoader from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES from torchrec.datasets.random import RandomRecDataset from torchrec.datasets.utils import Batch @@ -21,18 +19,13 @@ def create_training_batch(args: argparse.Namespace) -> Batch: return next( iter( - DataLoader( - RandomRecDataset( - keys=DEFAULT_CAT_NAMES, - batch_size=args.batch_size, - hash_size=args.num_embedding_features, - ids_per_feature=1, - num_dense=len(DEFAULT_INT_NAMES), - ), - batch_sampler=None, - pin_memory=False, - num_workers=0, - ) + RandomRecDataset( + keys=DEFAULT_CAT_NAMES, + batch_size=args.batch_size, + hash_size=args.num_embedding_features, + ids_per_feature=1, + num_dense=len(DEFAULT_INT_NAMES), + ), ) ) diff --git a/torchrec/inference/dlrm_packager.py b/torchrec/inference/dlrm_packager.py new file mode 100644 index 000000000..560e31c16 --- /dev/null +++ b/torchrec/inference/dlrm_packager.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# @nolint +# pyre-ignore-all-errors + + +import argparse +import sys +from typing import List + +from dlrm_predict import create_training_batch, DLRMModelConfig, DLRMPredictFactory +from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES + + +def parse_args(argv: List[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="torchrec dlrm model packager") + parser.add_argument( + "--num_embeddings", + type=int, + default=100_000, + help="max_ind_size. The number of embeddings in each embedding table. Defaults" + " to 100_000 if num_embeddings_per_feature is not supplied.", + ) + parser.add_argument( + "--num_embeddings_per_feature", + type=str, + default="45833188,36746,17245,7413,20243,3,7114,1441,62,29275261,1572176,345138," + "10,2209,11267,128,4,974,14,48937457,11316796,40094537,452104,12606,104,35", + help="Comma separated max_ind_size per sparse feature. The number of embeddings" + " in each embedding table. 
26 values are expected for the Criteo dataset.", + ) + parser.add_argument( + "--sparse_feature_names", + type=str, + default=",".join(DEFAULT_CAT_NAMES), + help="Comma separated names of the sparse features.", + ) + parser.add_argument( + "--dense_arch_layer_sizes", + type=str, + default="512,256,64", + help="Comma separated layer sizes for dense arch.", + ) + parser.add_argument( + "--over_arch_layer_sizes", + type=str, + default="512,512,256,1", + help="Comma separated layer sizes for over arch.", + ) + parser.add_argument( + "--embedding_dim", + type=int, + default=64, + help="Size of each embedding.", + ) + parser.add_argument( + "--num_dense_features", + type=int, + default=len(DEFAULT_INT_NAMES), + help="Number of dense features.", + ) + parser.add_argument( + "--output_path", + type=str, + help="Output path of model package.", + ) + return parser.parse_args(argv) + + +def main(argv: List[str]) -> None: + """ + Use torch.package to package the torchrec DLRM Model. + + Args: + argv (List[str]): command line args. + + Returns: + None. + """ + + args = parse_args(argv) + + args.batch_size = 10 + args.num_embedding_features = 26 + batch = create_training_batch(args) + + model_config = DLRMModelConfig( + dense_arch_layer_sizes=list(map(int, args.dense_arch_layer_sizes.split(","))), + dense_in_features=args.num_dense_features, + embedding_dim=args.embedding_dim, + id_list_features_keys=args.sparse_feature_names.split(","), + num_embeddings_per_feature=list( + map(int, args.num_embeddings_per_feature.split(",")) + ), + num_embeddings=args.num_embeddings, + over_arch_layer_sizes=list(map(int, args.over_arch_layer_sizes.split(","))), + sample_input=batch, + ) + + script_module = DLRMPredictFactory(model_config).create_predict_module(world_size=1) + + script_module.save(args.output_path) + print(f"Package is saved to {args.output_path}") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/torchrec/inference/dlrm_predict.py b/torchrec/inference/dlrm_predict.py new file mode 100644 index 000000000..78c2ea3fe --- /dev/null +++ b/torchrec/inference/dlrm_predict.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# @nolint +# pyre-ignore-all-errors + + +import logging +from dataclasses import dataclass +from typing import Dict, List, Optional + +import fbgemm_gpu.sparse_ops # noqa: F401 + +import torch +from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES +from torchrec.datasets.random import RandomRecDataset +from torchrec.datasets.utils import Batch +from torchrec.distributed.global_settings import set_propogate_device +from torchrec.fx.tracer import Tracer +from torchrec.inference.modules import ( + PredictFactory, + PredictModule, + quantize_inference_model, + shard_quant_model, +) +from torchrec.models.dlrm import DLRM +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +logger: logging.Logger = logging.getLogger(__name__) + + +def create_training_batch(args) -> Batch: + return RandomRecDataset( + keys=DEFAULT_CAT_NAMES, + batch_size=args.batch_size, + hash_size=args.num_embedding_features, + ids_per_feature=1, + num_dense=len(DEFAULT_INT_NAMES), + ).batch_generator._generate_batch() + + +# OSS Only + + +@dataclass +class DLRMModelConfig: + """ + Model Config for specifying DLRM model parameters. + """ + + dense_arch_layer_sizes: List[int] + dense_in_features: int + embedding_dim: int + id_list_features_keys: List[str] + num_embeddings_per_feature: List[int] + num_embeddings: int + over_arch_layer_sizes: List[int] + sample_input: Batch + + +class DLRMPredictModule(PredictModule): + """ + nn.Module to wrap DLRM model to use for inference. + + Args: + embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags + used to define SparseArch. + dense_in_features (int): the dimensionality of the dense input features. + dense_arch_layer_sizes (List[int]): the layer sizes for the DenseArch. + over_arch_layer_sizes (List[int]): the layer sizes for the OverArch. NOTE: The + output dimension of the InteractionArch should not be manually specified + here. + id_list_features_keys (List[str]): the names of the sparse features. Used to + construct a batch for inference. + dense_device: (Optional[torch.device]). + """ + + def __init__( + self, + embedding_bag_collection: EmbeddingBagCollection, + dense_in_features: int, + dense_arch_layer_sizes: List[int], + over_arch_layer_sizes: List[int], + id_list_features_keys: List[str], + dense_device: Optional[torch.device] = None, + ) -> None: + module = DLRM( + embedding_bag_collection=embedding_bag_collection, + dense_in_features=dense_in_features, + dense_arch_layer_sizes=dense_arch_layer_sizes, + over_arch_layer_sizes=over_arch_layer_sizes, + dense_device=dense_device, + ) + super().__init__(module, dense_device) + + self.id_list_features_keys: List[str] = id_list_features_keys + + def predict_forward( + self, batch: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """ + Args: + batch (Dict[str, torch.Tensor]): currently expects input dense features + to be mapped to the key "float_features" and input sparse features + to be mapped to the key "id_list_features". + + Returns: + Dict[str, torch.Tensor]: output of inference. 
+ """ + + try: + logits = self.predict_module( + batch["float_features"], + KeyedJaggedTensor( + keys=self.id_list_features_keys, + lengths=batch["id_list_features.lengths"], + values=batch["id_list_features.values"], + ), + ) + predictions = logits.sigmoid() + except Exception as e: + logger.info(e) + raise e + + # Flip predictions tensor to be 1D. TODO: Determine why prediction shape + # can be 2D at times (likely due to input format?) + predictions = predictions.reshape( + [ + predictions.size()[0], + ] + ) + + return { + "default": predictions.to(torch.device("cpu"), non_blocking=True).float() + } + + +class DLRMPredictFactory(PredictFactory): + """ + Factory Class for generating TorchScript DLRM Model for C++ inference. + + Args: + model_config (DLRMModelConfig): model config + + """ + + def __init__(self, model_config: DLRMModelConfig) -> None: + self.model_config = model_config + + def create_predict_module(self, world_size: int, device: str) -> torch.nn.Module: + logging.basicConfig(level=logging.INFO) + set_propogate_device(True) + + eb_configs = [ + EmbeddingBagConfig( + name=f"t_{feature_name}", + embedding_dim=self.model_config.embedding_dim, + num_embeddings=( + self.model_config.num_embeddings_per_feature[feature_idx] + if self.model_config.num_embeddings is None + else self.model_config.num_embeddings + ), + feature_names=[feature_name], + ) + for feature_idx, feature_name in enumerate( + self.model_config.id_list_features_keys + ) + ] + ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta")) + + module = DLRMPredictModule( + embedding_bag_collection=ebc, + dense_in_features=self.model_config.dense_in_features, + dense_arch_layer_sizes=self.model_config.dense_arch_layer_sizes, + over_arch_layer_sizes=self.model_config.over_arch_layer_sizes, + id_list_features_keys=self.model_config.id_list_features_keys, + dense_device=device, + ) + + quant_model = quantize_inference_model(module) + sharded_model, _ = shard_quant_model( + quant_model, compute_device=device, sharding_device=device + ) + + batch = {} + batch["float_features"] = self.model_config.sample_input.dense_features.to( + device + ) + batch["id_list_features.lengths"] = ( + self.model_config.sample_input.sparse_features.lengths().to(device) + ) + batch["id_list_features.values"] = ( + self.model_config.sample_input.sparse_features.values().to(device) + ) + + sharded_model(batch) + + tracer = Tracer(leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"]) + + graph = tracer.trace(sharded_model) + gm = torch.fx.GraphModule(sharded_model, graph) + + gm(batch) + scripted_gm = torch.jit.script(gm) + scripted_gm(batch) + return scripted_gm + + def batching_metadata(self) -> Dict[str, str]: + return { + "float_features": "dense", + "id_list_features": "sparse", + } + + def result_metadata(self) -> str: + return "dict_of_tensor" + + def run_weights_independent_tranformations( + self, predict_module: torch.nn.Module + ) -> torch.nn.Module: + return predict_module + + def run_weights_dependent_transformations( + self, predict_module: torch.nn.Module + ) -> torch.nn.Module: + """ + Run transformations that depends on weights of the predict module. e.g. lowering to a backend. 
+ """ + return predict_module diff --git a/torchrec/inference/include/torchrec/inference/Batching.h b/torchrec/inference/include/torchrec/inference/Batching.h index c73240d90..2c5e8837c 100644 --- a/torchrec/inference/include/torchrec/inference/Batching.h +++ b/torchrec/inference/include/torchrec/inference/Batching.h @@ -30,7 +30,7 @@ class BatchingFunc { public: virtual ~BatchingFunc() = default; - virtual std::unordered_map batch( + virtual std::unordered_map batch( const std::string& /* featureName */, const std::vector>& /* requests */, const int64_t& /* totalNumBatch */, @@ -52,21 +52,23 @@ C10_DECLARE_REGISTRY(TorchRecBatchingFuncRegistry, BatchingFunc); REGISTER_TORCHREC_BATCHING_FUNC_WITH_PIORITY( \ name, c10::REGISTRY_DEFAULT, __VA_ARGS__); -std::unordered_map combineFloat( +std::unordered_map combineFloat( const std::string& featureName, const std::vector>& requests); -std::unordered_map combineSparse( +std::unordered_map combineSparse( const std::string& featureName, const std::vector>& requests, bool isWeighted); -std::unordered_map combineEmbedding( +std::unordered_map combineEmbedding( const std::string& featureName, const std::vector>& requests); -std::unordered_map moveToDevice( - std::unordered_map combined, +void moveIValueToDevice(c10::IValue& val, const c10::Device& device); + +std::unordered_map moveToDevice( + std::unordered_map combined, const c10::Device& device); } // namespace torchrec diff --git a/torchrec/inference/include/torchrec/inference/Exception.h b/torchrec/inference/include/torchrec/inference/Exception.h index c42da1700..5667d1ad2 100644 --- a/torchrec/inference/include/torchrec/inference/Exception.h +++ b/torchrec/inference/include/torchrec/inference/Exception.h @@ -7,21 +7,43 @@ */ #pragma once +#include -#include -#include - -#include +namespace torchrec { -#include "torchrec/inference/Types.h" +// We have different error code defined for different kinds of exceptions in +// fblearner/sigrid predictor. (Code pointer: +// fblearner/predictor/if/prediction_service.thrift.) We define different +// exception type here so that in fblearner/sigrid predictor we can detect the +// exception type and return the corresponding error code to reflect the right +// info. 
+class TorchrecException : public std::runtime_error { + public: + explicit TorchrecException(const std::string& error) + : std::runtime_error(error) {} +}; -namespace torchrec { +// GPUOverloadException maps to +// PredictionExceptionCode::GPU_BATCHING_QUEUE_TIMEOUT +class GPUOverloadException : public TorchrecException { + public: + explicit GPUOverloadException(const std::string& error) + : TorchrecException(error) {} +}; -void handleRequestException( - folly::Promise>& promise, - const std::string& msg); -void handleBatchException( - std::vector& contexts, - const std::string& msg); +// GPUExecutorOverloadException maps to +// PredictionExceptionCode::GPU_EXECUTOR_QUEUE_TIMEOUT +class GPUExecutorOverloadException : public TorchrecException { + public: + explicit GPUExecutorOverloadException(const std::string& error) + : TorchrecException(error) {} +}; +// TorchDeployException maps to +// PredictorUserErrorCode::TORCH_DEPLOY_ERROR +class TorchDeployException : public TorchrecException { + public: + explicit TorchDeployException(const std::string& error) + : TorchrecException(error) {} +}; } // namespace torchrec diff --git a/torchrec/inference/src/Exception.cpp b/torchrec/inference/include/torchrec/inference/ExceptionHandler.h similarity index 76% rename from torchrec/inference/src/Exception.cpp rename to torchrec/inference/include/torchrec/inference/ExceptionHandler.h index 908c1c30b..491acf48f 100644 --- a/torchrec/inference/src/Exception.cpp +++ b/torchrec/inference/include/torchrec/inference/ExceptionHandler.h @@ -6,28 +6,33 @@ * LICENSE file in the root directory of this source tree. */ -#include "torchrec/inference/Exception.h" +#pragma once #include #include +#include + +#include "torchrec/inference/Exception.h" #include "torchrec/inference/Types.h" namespace torchrec { - +template void handleRequestException( folly::Promise>& promise, const std::string& msg) { - auto ex = folly::make_exception_wrapper(msg); + auto ex = folly::make_exception_wrapper(msg); auto response = std::make_unique(); response->exception = std::move(ex); promise.setValue(std::move(response)); } + +template void handleBatchException( std::vector& contexts, const std::string& msg) { for (auto& context : contexts) { - handleRequestException(context.promise, msg); + handleRequestException(context.promise, msg); } } diff --git a/torchrec/inference/include/torchrec/inference/GPUExecutor.h b/torchrec/inference/include/torchrec/inference/GPUExecutor.h index 02eb7fd80..d2d289670 100644 --- a/torchrec/inference/include/torchrec/inference/GPUExecutor.h +++ b/torchrec/inference/include/torchrec/inference/GPUExecutor.h @@ -9,6 +9,8 @@ #pragma once #include +#include +#include #include #include @@ -30,12 +32,22 @@ #include "torchrec/inference/BatchingQueue.h" #include "torchrec/inference/Observer.h" #include "torchrec/inference/ResultSplit.h" -#include "torchrec/inference/include/torchrec/inference/Observer.h" +#include "torchrec/inference/include/torchrec/inference/Observer.h" // @manual namespace torchrec { class GPUExecutor { public: + // Used to interface with python's garbage collector + struct GCConfig { + bool optimizationEnabled = false; + size_t collectionFreq = 1000; + size_t statReportingFreq = 10000; + std::unique_ptr observer = + std::make_unique(); + std::map threadIdToNumForwards = std::map(); + }; + GPUExecutor( std::shared_ptr manager, torch::deploy::ReplicatedObj model, @@ -46,7 +58,8 @@ class GPUExecutor { std::shared_ptr observer, // shared_ptr because used in completion executor callback 
std::function warmupFn = {}, - c10::optional numThreadsPerGPU = c10::nullopt); + std::optional numThreadsPerGPU = std::nullopt, + std::unique_ptr gcConfig = std::make_unique()); GPUExecutor(GPUExecutor&& executor) noexcept = default; GPUExecutor& operator=(GPUExecutor&& executor) noexcept = default; ~GPUExecutor(); @@ -71,7 +84,16 @@ class GPUExecutor { std::shared_ptr observer_; std::function warmupFn_; + std::mutex warmUpMutex_; + std::mutex warmUpAcquireSessionMutex_; + std::condition_variable warmUpCV_; + int warmUpCounter_{0}; + size_t numThreadsPerGPU_; + + std::unique_ptr gcConfig_; + + void reportGCStats(c10::IValue stats); }; } // namespace torchrec diff --git a/torchrec/inference/include/torchrec/inference/Observer.h b/torchrec/inference/include/torchrec/inference/Observer.h index b16f855ec..14ac3bceb 100644 --- a/torchrec/inference/include/torchrec/inference/Observer.h +++ b/torchrec/inference/include/torchrec/inference/Observer.h @@ -13,20 +13,34 @@ namespace torchrec { +// Record generic timeseries stat with a key +class IDynamicTimeseriesObserver { + public: + virtual void addCount(uint32_t value, std::string key) = 0; + + virtual ~IDynamicTimeseriesObserver() {} +}; + +// Can be used for testing or for opt-ing out of observation. +class EmptyDynamicTimeseriesObserver : public IDynamicTimeseriesObserver { + public: + void addCount(uint32_t /* value */, std::string /* key */) override {} +}; + class IBatchingQueueObserver { public: // Record the amount of time an entry of PredictionRequests // in the batching queue waits before they are read and allocated // onto a GPU device. virtual void recordBatchingQueueLatency( - double value, + uint32_t value, std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now()) = 0; // Record the amount of time it takes for a batching function // to execute. virtual void recordBatchingFuncLatency( - double value, + uint32_t value, std::string batchingFuncName, std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now()) = 0; @@ -34,30 +48,30 @@ class IBatchingQueueObserver { // Record the amount of time it takes to create a batch of // requests. virtual void recordBatchCreationLatency( - double value, + uint32_t value, std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now()) = 0; // Increment the number of batching queue timeouts experienced. - virtual void addBatchingQueueTimeoutCount(double value) = 0; + virtual void addBatchingQueueTimeoutCount(uint32_t value) = 0; // Increment the number of times a GPU could not be chosen // for allocation. - virtual void addGPUBusyCount(double value) = 0; + virtual void addGPUBusyCount(uint32_t value) = 0; // Increment the number of requests entering the batching queue. - virtual void addRequestsCount(double value) = 0; + virtual void addRequestsCount(uint32_t value) = 0; // Increment the number of bytes of tensors moved to cuda. - virtual void addBytesMovedToGPUCount(double value) = 0; + virtual void addBytesMovedToGPUCount(uint32_t value) = 0; // Increment the number of batches processed by the batching // queue (moved onto the GPU executor). - virtual void addBatchesProcessedCount(double value) = 0; + virtual void addBatchesProcessedCount(uint32_t value) = 0; // Increment the number of requests processed by the batching // queue (moved onto the GPU executor). - virtual void addRequestsProcessedCount(double value) = 0; + virtual void addRequestsProcessedCount(uint32_t value) = 0; // The obervations that should be made when a batch is completed. 
virtual void observeBatchCompletion( @@ -75,29 +89,29 @@ class IBatchingQueueObserver { class EmptyBatchingQueueObserver : public IBatchingQueueObserver { public: void recordBatchingQueueLatency( - double /* value */, + uint32_t /* value */, std::chrono::steady_clock::time_point /* now */) override {} void recordBatchingFuncLatency( - double /* value */, + uint32_t /* value */, std::string /* batchingFuncName */, std::chrono::steady_clock::time_point /* now */) override {} void recordBatchCreationLatency( - double /* value */, + uint32_t /* value */, std::chrono::steady_clock::time_point /* now */) override {} - void addBatchingQueueTimeoutCount(double /* value */) override {} + void addBatchingQueueTimeoutCount(uint32_t /* value */) override {} - void addGPUBusyCount(double /* value */) override {} + void addGPUBusyCount(uint32_t /* value */) override {} - void addRequestsCount(double /* value */) override {} + void addRequestsCount(uint32_t /* value */) override {} - void addBytesMovedToGPUCount(double /* value */) override {} + void addBytesMovedToGPUCount(uint32_t /* value */) override {} - void addBatchesProcessedCount(double /* value */) override {} + void addBatchesProcessedCount(uint32_t /* value */) override {} - void addRequestsProcessedCount(double /* value */) override {} + void addRequestsProcessedCount(uint32_t /* value */) override {} }; class IGPUExecutorObserver { @@ -105,60 +119,60 @@ class IGPUExecutorObserver { // Record the amount of time a batch spends in the GPU Executor // queue. virtual void recordQueueLatency( - double value, + uint32_t value, std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now()) = 0; // Record the latency of prediction (forward call, H2D). virtual void recordPredictionLatency( - double value, + uint32_t value, std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now()) = 0; // Record the latency of device to host transfer facilitated // by result split function. virtual void recordDeviceToHostLatency( - double value, + uint32_t value, std::string resultSplitFuncName, std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now()) = 0; // Record the latency of splitting the result. virtual void recordResultSplitLatency( - double value, + uint32_t value, std::string resultSplitFuncName, std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now()) = 0; // Record the latency from enqueue to completion. virtual void recordTotalLatency( - double value, + uint32_t value, std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now()) = 0; // Increment the number of GPUExecutor queue timeouts. - virtual void addQueueTimeoutCount(double value) = 0; + virtual void addQueueTimeoutCount(uint32_t value) = 0; // Increment the number of predict exceptions. - virtual void addPredictionExceptionCount(double value) = 0; + virtual void addPredictionExceptionCount(uint32_t value) = 0; // Increment the number of batches successfully processed. 
- virtual void addBatchesProcessedCount(double value) = 0; + virtual void addBatchesProcessedCount(uint32_t value) = 0; virtual ~IGPUExecutorObserver() {} }; class ISingleGPUExecutorObserver { public: - virtual void addRequestsCount(double value) = 0; - virtual void addRequestProcessingExceptionCount(double value) = 0; + virtual void addRequestsCount(uint32_t value) = 0; + virtual void addRequestProcessingExceptionCount(uint32_t value) = 0; virtual void recordQueueLatency( - double value, + uint32_t value, std::chrono::steady_clock::time_point = std::chrono::steady_clock::now()) = 0; virtual void recordRequestProcessingLatency( - double value, + uint32_t value, std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now()) = 0; @@ -166,15 +180,15 @@ class ISingleGPUExecutorObserver { }; class EmptySingleGPUExecutorObserver : public ISingleGPUExecutorObserver { - void addRequestsCount(double) override {} - void addRequestProcessingExceptionCount(double) override {} + void addRequestsCount(uint32_t) override {} + void addRequestProcessingExceptionCount(uint32_t) override {} void recordQueueLatency( - double, + uint32_t, std::chrono::steady_clock::time_point = std::chrono::steady_clock::now()) override {} void recordRequestProcessingLatency( - double, + uint32_t, std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now()) override {} }; @@ -183,32 +197,76 @@ class EmptySingleGPUExecutorObserver : public ISingleGPUExecutorObserver { class EmptyGPUExecutorObserver : public IGPUExecutorObserver { public: void recordQueueLatency( - double /* value */, + uint32_t /* value */, std::chrono::steady_clock::time_point /* now */) override {} void recordPredictionLatency( - double /* value */, + uint32_t /* value */, std::chrono::steady_clock::time_point /* now */) override {} void recordDeviceToHostLatency( - double /* value */, + uint32_t /* value */, std::string /* resultSplitFuncName */, std::chrono::steady_clock::time_point /* now */) override {} void recordResultSplitLatency( - double /* value */, + uint32_t /* value */, std::string /* resultSplitFuncName */, std::chrono::steady_clock::time_point /* now */) override {} void recordTotalLatency( - double /* value */, + uint32_t /* value */, std::chrono::steady_clock::time_point /* now */) override {} - void addQueueTimeoutCount(double /* value */) override {} + void addQueueTimeoutCount(uint32_t /* value */) override {} + + void addPredictionExceptionCount(uint32_t /* value */) override {} + + void addBatchesProcessedCount(uint32_t /* value */) override {} +}; + +class IResourceManagerObserver { + public: + // Add the number of requests in flight for a gpu + virtual void addOutstandingRequestsCount(uint32_t value, int gpuIdx) = 0; - void addPredictionExceptionCount(double /* value */) override {} + // Add the most in flight requests on a gpu ever + virtual void addAllTimeHighOutstandingCount(uint32_t value, int gpuIdx) = 0; - void addBatchesProcessedCount(double /* value */) override {} + // Record the latency for finding a device + virtual void addWaitingForDeviceLatency( + uint32_t value, + int gpuIdx, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Recording all stats related to resource manager at once. 
+ virtual void recordAllStats( + uint32_t outstandingRequests, + uint32_t allTimeHighOutstanding, + uint32_t waitedForMs, + int gpuIdx) { + addOutstandingRequestsCount(outstandingRequests, gpuIdx); + addAllTimeHighOutstandingCount(allTimeHighOutstanding, gpuIdx); + addWaitingForDeviceLatency(waitedForMs, gpuIdx); + } + + virtual ~IResourceManagerObserver() {} +}; + +// Can be used for testing or for opt-ing out of observation. +class EmptyResourceManagerObserver : public IResourceManagerObserver { + public: + void addOutstandingRequestsCount(uint32_t /* value */, int /* gpuIdx */) + override {} + + void addAllTimeHighOutstandingCount(uint32_t /* value */, int /* gpuIdx */) + override {} + + void addWaitingForDeviceLatency( + uint32_t /* value */, + int /* gpuIdx */, + std::chrono::steady_clock::time_point /* now */) override {} }; // Helper for determining how much time has elapsed in milliseconds since a diff --git a/torchrec/inference/include/torchrec/inference/ResourceManager.h b/torchrec/inference/include/torchrec/inference/ResourceManager.h index b9bd3eb99..d3dd1ea18 100644 --- a/torchrec/inference/include/torchrec/inference/ResourceManager.h +++ b/torchrec/inference/include/torchrec/inference/ResourceManager.h @@ -18,6 +18,8 @@ #include #include +#include "torchrec/inference/Observer.h" + namespace torchrec { /** @@ -29,7 +31,9 @@ class ResourceManager { ResourceManager( int worldSize, size_t maxOutstandingBatches, - int logFrequency = 100); + int logFrequency = 100, + std::unique_ptr observer = + std::make_unique()); // Returns whether batches can be allocated onto a device based on // slack provided (ms) and maxOutstandingBatches_). @@ -45,6 +49,7 @@ class ResourceManager { const int logFrequency_; // Align as 64B to avoid false sharing alignas(64) std::mutex mu_; + std::unique_ptr observer_; }; class ResourceManagerGuard { diff --git a/torchrec/inference/include/torchrec/inference/ResultSplit.h b/torchrec/inference/include/torchrec/inference/ResultSplit.h index af262214d..2c3ef2463 100644 --- a/torchrec/inference/include/torchrec/inference/ResultSplit.h +++ b/torchrec/inference/include/torchrec/inference/ResultSplit.h @@ -52,4 +52,17 @@ c10::IValue splitDictOfTensors( c10::IValue splitDictWithMaskTensor(c10::IValue result, size_t nOffset, size_t nLength); +class DictWithMaskTensorResultSplitFunc : public torchrec::ResultSplitFunc { + public: + virtual std::string name() override; + + virtual c10::IValue splitResult( + c10::IValue result, + size_t offset, + size_t length, + size_t /* nTotalLength */) override; + + c10::IValue moveToHost(c10::IValue result) override; +}; + } // namespace torchrec diff --git a/torchrec/inference/include/torchrec/inference/SingleGPUExecutor.h b/torchrec/inference/include/torchrec/inference/SingleGPUExecutor.h index cef83bc8c..9da63d7c2 100644 --- a/torchrec/inference/include/torchrec/inference/SingleGPUExecutor.h +++ b/torchrec/inference/include/torchrec/inference/SingleGPUExecutor.h @@ -33,7 +33,9 @@ class SingleGPUExecutor { size_t numGpu, std::shared_ptr observer = std::make_shared(), - c10::Device resultDevice = c10::kCPU); + c10::Device resultDevice = c10::kCPU, + size_t numProcessThreads = 1u, + bool useHighPriCudaStream = false); // Moveable only SingleGPUExecutor(SingleGPUExecutor&& executor) noexcept = default; @@ -48,12 +50,14 @@ class SingleGPUExecutor { std::shared_ptr manager_; const ExecInfos execInfos_; const size_t numGpu_; + const size_t numProcessThreads_; + const bool useHighPriCudaStream_; const c10::Device resultDevice_; 
std::shared_ptr observer_; folly::MPMCQueue> requests_; + std::unique_ptr processExecutor_; std::unique_ptr completionExecutor_; std::atomic roundRobinExecInfoNextIdx_; - std::thread processThread_; }; } // namespace torchrec diff --git a/torchrec/inference/include/torchrec/inference/TestUtils.h b/torchrec/inference/include/torchrec/inference/TestUtils.h index 51449ac9d..1150e2c23 100644 --- a/torchrec/inference/include/torchrec/inference/TestUtils.h +++ b/torchrec/inference/include/torchrec/inference/TestUtils.h @@ -28,6 +28,9 @@ createRequest(size_t batchSize, size_t numFeatures, at::Tensor embedding); JaggedTensor createJaggedTensor(const std::vector>& input); +c10::List createIValueList( + const std::vector>& input); + at::Tensor createEmbeddingTensor( const std::vector>& input); diff --git a/torchrec/inference/include/torchrec/inference/Types.h b/torchrec/inference/include/torchrec/inference/Types.h index c4cf0ecb6..0a09d1f35 100644 --- a/torchrec/inference/include/torchrec/inference/Types.h +++ b/torchrec/inference/include/torchrec/inference/Types.h @@ -62,6 +62,8 @@ struct PredictionResponse { struct RequestContext { uint32_t batchSize; folly::Promise> promise; + // folly request context for request tracking in crochet + std::shared_ptr follyRequestContext; }; using PredictionException = std::runtime_error; @@ -85,7 +87,7 @@ struct PredictionBatch : public boost::noncopyable { size_t batchSize; - c10::Dict forwardArgs; + c10::impl::GenericDict forwardArgs; std::vector contexts; @@ -100,7 +102,7 @@ struct PredictionBatch : public boost::noncopyable { // noncopyable struct and not trigger copy-constructor. PredictionBatch( size_t bs, - c10::Dict fa, + c10::impl::GenericDict fa, std::vector ctxs, std::unique_ptr rmg = nullptr) : batchSize(bs), @@ -112,14 +114,29 @@ struct PredictionBatch : public boost::noncopyable { std::string methodNameArg, std::vector argsArg, folly::Promise> promise) - : methodName(std::move(methodNameArg)), args(std::move(argsArg)) { + : methodName(std::move(methodNameArg)), + args(std::move(argsArg)), + forwardArgs( + c10::impl::GenericDict(at::StringType::get(), at::AnyType::get())) { contexts.push_back(RequestContext{1u, std::move(promise)}); } + size_t sizeOfIValue(const c10::IValue& val) const { + size_t size = 0; + if (val.isTensor()) { + size += val.toTensor().storage().nbytes(); + } else if (val.isList()) { + for (const auto& v : val.toListRef()) { + size += sizeOfIValue(v); + } + } + return size; + } + inline size_t size() const { size_t size = 0; for (const auto& iter : forwardArgs) { - size += iter.value().storage().nbytes(); + size += sizeOfIValue(iter.value()); } return size; } diff --git a/torchrec/inference/include/torchrec/inference/Validation.h b/torchrec/inference/include/torchrec/inference/Validation.h index 96ad54caf..74a2f20ff 100644 --- a/torchrec/inference/include/torchrec/inference/Validation.h +++ b/torchrec/inference/include/torchrec/inference/Validation.h @@ -20,7 +20,7 @@ namespace torchrec { bool validateSparseFeatures( at::Tensor& values, at::Tensor& lengths, - c10::optional maybeWeights = c10::nullopt); + std::optional maybeWeights = std::nullopt); // Returns whether dense features are valid. // Currently validates: diff --git a/torchrec/inference/inference_legacy/CMakeLists.txt b/torchrec/inference/inference_legacy/CMakeLists.txt new file mode 100644 index 000000000..794bb1490 --- /dev/null +++ b/torchrec/inference/inference_legacy/CMakeLists.txt @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.13 FATAL_ERROR) +project(inference) + +# This step is crucial to ensure that the +# _REFLECTION, _GRPC_GRPCPP and _PROTOBUF_LIBPROTOBUF variables are set. +# e.g. ~/gprc/examples/cpp/cmake/common.cmake +include(${GRPC_COMMON_CMAKE_PATH}/common.cmake) + + +# abi and other flags + +if(DEFINED GLIBCXX_USE_CXX11_ABI) + if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") + set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=1") + endif() +endif() + +# keep it static for now since folly-shared version is broken +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") +# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") + +# dependencies +find_package(Boost REQUIRED) +find_package(Torch REQUIRED) +find_package(folly REQUIRED) +find_package(gflags REQUIRED) + +include_directories(${Torch_INCLUDE_DIRS}) +include_directories(${folly_INCLUDE_DIRS}) +include_directories(${PYTORCH_FMT_INCLUDE_PATH}) + +set(CMAKE_CXX_STANDARD 17) + +# torch deploy library +add_library(torch_deploy_internal STATIC + ${DEPLOY_INTERPRETER_PATH}/libtorch_deployinterpreter.o + ${DEPLOY_SRC_PATH}/deploy.cpp + ${DEPLOY_SRC_PATH}/loader.cpp + ${DEPLOY_SRC_PATH}/path_environment.cpp + ${DEPLOY_SRC_PATH}/elf_file.cpp) + +# For python builtins. caffe2_interface_library properly +# makes use of the --whole-archive option. +target_link_libraries(torch_deploy_internal PRIVATE + crypt pthread dl util m z ffi lzma readline nsl ncursesw panelw +) +target_link_libraries(torch_deploy_internal + PUBLIC shm torch ${PYTORCH_LIB_FMT} +) +caffe2_interface_library(torch_deploy_internal torch_deploy) + +# inference library + +# for our own header files +include_directories(include/) +include_directories(gen/) + +# define our library target +add_library(inference STATIC + src/Batching.cpp + src/BatchingQueue.cpp + src/GPUExecutor.cpp + src/ResultSplit.cpp + src/Exception.cpp + src/ResourceManager.cpp +) + +# -rdynamic is needed to link against the static library +target_link_libraries(inference "-Wl,--no-as-needed -rdynamic" + dl torch_deploy "${TORCH_LIBRARIES}" ${FBGEMM_LIB} ${FOLLY_LIBRARIES} +) + +# for generated protobuf + +# grpc headers. e.g. 
~/.local/include +include_directories(${GRPC_HEADER_INCLUDE_PATH}) + +set(pred_grpc_srcs "gen/torchrec/inference/predictor.grpc.pb.cc") +set(pred_grpc_hdrs "gen/torchrec/inference/predictor.grpc.pb.h") +set(pred_proto_srcs "gen/torchrec/inference/predictor.pb.cc") +set(pred_proto_hdrs "gen/torchrec/inference/predictor.pb.h") + +add_library(pred_grpc_proto STATIC + ${pred_grpc_srcs} + ${pred_grpc_hdrs} + ${pred_proto_srcs} + ${pred_proto_hdrs}) + +target_link_libraries(pred_grpc_proto + ${_REFLECTION} + ${_GRPC_GRPCPP} + ${_PROTOBUF_LIBPROTOBUF}) + +# server +add_executable(server server.cpp) +target_link_libraries(server + inference + torch_deploy + pred_grpc_proto + "${TORCH_LIBRARIES}" + ${FOLLY_LIBRARIES} + ${PYTORCH_LIB_FMT} + ${FBGEMM_LIB} + ${_REFLECTION} + ${_GRPC_GRPCPP} + ${_PROTOBUF_LIBPROTOBUF}) diff --git a/torchrec/inference/inference_legacy/README.md b/torchrec/inference/inference_legacy/README.md new file mode 100644 index 000000000..fc3d8afcb --- /dev/null +++ b/torchrec/inference/inference_legacy/README.md @@ -0,0 +1,239 @@ +# TorchRec Inference Library (**Experimental** Release) + +## Overview +TorchRec Inference is a C++ library that supports **multi-gpu inference**. The Torchrec library is used to shard models written and packaged in Python via [torch.package](https://pytorch.org/docs/stable/package.html) (an alternative to TorchScript). The [torch.deploy](https://pytorch.org/docs/stable/deploy.html) library is used to serve inference from C++ by launching multiple Python interpreters carrying the packaged model, thus subverting the GIL. + +Follow the instructions below to package a DLRM model in Python, run a C++ inference server with the model on a GPU and send requests to said server via a python client. + +## Example + +C++ 17 is a requirement. + +
+ +### **1. Install Dependencies** + +Follow the instructions at: https://github.com/pytorch/pytorch/blob/master/docs/source/deploy.rst to ensure torch::deploy +is working in your environment. Use the Dockerfile in the docker directory to install all dependencies. Run it via: + +``` +sudo nvidia-docker build -t torchrec . +sudo nvidia-docker run -it torchrec:latest +``` + +### **2. Set variables** + +Replace these variables with the relevant paths in your system. Check `CMakeLists.txt` and `server.cpp` to see how they're used throughout the build and runtime. + +``` +# provide the cmake prefix path of pytorch, folly, and fmt. +# fmt and boost are pulled from folly's installation in this example. +export FOLLY_CMAKE_DIR="~/folly-build/installed/folly/lib/cmake/folly" +export FMT_CMAKE_DIR="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/cmake/fmt" +export BOOST_CMAKE_DIR="~/folly-build/installed/boost-4M2ZnvEM4UWTqpsEJRQTB4oejmX3LmgYC9pcBiuVlmA/lib/cmake/Boost-1.78.0" + +# provide fmt from pytorch for torch deploy +export PYTORCH_FMT_INCLUDE_PATH="~/pytorch/third_party/fmt/include/" +export PYTORCH_LIB_FMT="~/pytorch/build/lib/libfmt.a" + +# provide necessary info to link to torch deploy +export DEPLOY_INTERPRETER_PATH="/pytorch/build/torch/csrc/deploy" +export DEPLOY_SRC_PATH="~/pytorch/torch/csrc/deploy" + +# provide common.cmake from grpc/examples, makes linking to grpc easier +export GRPC_COMMON_CMAKE_PATH="~/grpc/examples/cpp/cmake" +export GRPC_HEADER_INCLUDE_PATH="~/.local/include/" + +# provide libfbgemm_gpu_py.so to enable fbgemm_gpu c++ operators +export FBGEMM_LIB="~/anaconda3/envs/inference/lib/python3.8/site-packages/fbgemm_gpu-0.1.0-py3.8-linux-x86_64.egg/fbgemm_gpu/libfbgemm_gpu_py.so" + +# provide path to python packages for torch deploy runtime +export PYTHON_PACKAGES_PATH="~/anaconda3/envs/inference/lib/python3.8/site-packages/" +``` + +Update `$LD_LIBRARY_PATH` and `$LIBRARY_PATH` to enable linker to locate libraries. + +``` +# double-conversion, fmt and gflags are pulled from folly's installation in this example +export DOUBLE_CONVERSION_LIB_PATH="~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib" +export FMT_LIB_PATH="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/" +export GFLAGS_LIB_PATH="~/folly-build/installed/gflags-KheHQBqQ3_iL3yJBFwWe5M5f8Syd-LKAX352cxkhQMc/lib" +export PYTORCH_LIB_PATH="~/pytorch/build/lib/" + +export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$DOUBLE_CONVERSION_LIB_PATH:$FMT_LIB_PATH:$GFLAGS_LIB_PATH:$PYTORCH_LIB_PATH" +export LIBRARY_PATH="$PYTORCH_LIB_PATH" +``` + +### **3. Package DLRM model** + +The `PredictFactoryPackager` class in `model_packager.py` can be used to implement your own packager class. Implement +`set_extern_modules` to specify the dependencies of your predict module that should be accessed from the system and +implement `set_mocked_modules` to specify dependencies that should be mocked (necessary to import but not use). Read +more about extern and mock modules in the `torch.package` documentation: https://pytorch.org/docs/stable/package.html. + +`/torchrec/examples/inference_legacy/dlrm_package.py` provides an example of packaging a module for inference (`/torchrec/examples/inference_legacy/dlrm_predict.py`). +`DLRMPredictModule` is packaged for inference in the following example. 
+ +``` +git clone https://github.com/pytorch/torchrec.git + +cd ~/torchrec/examples/inference_legacy/ +python dlrm_packager.py --output_path /tmp/model_package.zip +``` + + + +### **4. Build inference library and example server** + +Generate protobuf C++ and Python code from protobuf + +``` +cd ~/torchrec/inference/ +mkdir -p gen/torchrec/inference + +# C++ (server) +protoc -I protos/ --grpc_out=gen/torchrec/inference --plugin=protoc-gen-grpc=/home/shabab/.local/bin/grpc_cpp_plugin protos/predictor.proto + +protoc -I protos/ --cpp_out=gen/torchrec/inference protos/predictor.proto + + +# Python (client) +python -m grpc_tools.protoc -I protos --python_out=gen/torchrec/inference --grpc_python_out=gen/torchrec/inference protos/predictor.proto +``` + + +Build inference library and example server +``` +cmake -S . -B build/ -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)');$FOLLY_CMAKE_DIR;$BOOST_CMAKE_DIR;$BOOST_CMAKE_DIR;" +-DPYTORCH_FMT_INCLUDE_PATH="$PYTORCH_FMT_INCLUDE_PATH" \ +-DPYTORCH_LIB_FMT="$PYTORCH_LIB_FMT" \ +-DDEPLOY_INTERPRETER_PATH="$DEPLOY_INTERPRETER_PATH" \ +-DDEPLOY_SRC_PATH="$DEPLOY_SRC_PATH" \ +-DGRPC_COMMON_CMAKE_PATH="$GRPC_COMMON_CMAKE_PATH" \ -DGRPC_HEADER_INCLUDE_PATH="$GRPC_HEADER_INCLUDE_PATH" \ +-DFBGEMM_LIB="$FBGEMM_LIB" + +cd build +make -j +``` + + +### **5. Run server and client** + +Run server. Update `CUDA_VISABLE_DEVICES` depending on the world size. +``` +CUDA_VISABLE_DEVICES="0" ./server --package_path="/tmp/model_package.zip" --python_packages_path $PYTHON_PACKAGES_PATH +``` + +**output** + +In the logs, a plan should be outputted by the Torchrec planner: + +``` +INFO:.torchrec.distributed.planner.stats:# --- Planner Statistics --- # +INFO:.torchrec.distributed.planner.stats:# --- Evalulated 1 proposal(s), found 1 possible plan(s) --- # +INFO:.torchrec.distributed.planner.stats:# ----------------------------------------------------------------------------------------------- # +INFO:.torchrec.distributed.planner.stats:# Rank HBM (GB) DDR (GB) Perf (ms) Input (MB) Output (MB) Shards # +INFO:.torchrec.distributed.planner.stats:# ------ ---------- ---------- ----------- ------------ ------------- -------- # +INFO:.torchrec.distributed.planner.stats:# 0 0.2 (1%) 0.0 (0%) 0.08 0.1 1.02 TW: 26 # +INFO:.torchrec.distributed.planner.stats:# # +INFO:.torchrec.distributed.planner.stats:# Input: MB/iteration, Output: MB/iteration, Shards: number of tables # +INFO:.torchrec.distributed.planner.stats:# HBM: est. 
peak memory usage for shards - parameter, comms, optimizer, and gradients # +INFO:.torchrec.distributed.planner.stats:# # +INFO:.torchrec.distributed.planner.stats:# Compute Kernels: # +INFO:.torchrec.distributed.planner.stats:# quant: 26 # +```` + +`nvidia-smi` output should also show allocation of the model onto the gpu: + +``` ++-----------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=============================================================================| +| 0 N/A N/A 86668 C ./example 1357MiB | ++-----------------------------------------------------------------------------+ +``` + +Make a request to the server via the client: + +``` +python client.py +``` + +**output** + +``` +Response: [0.13199582695960999, -0.1048036441206932, -0.06022112816572189, -0.08765199035406113, -0.12735335528850555, -0.1004377081990242, 0.05509107559919357, -0.10504599660634995, 0.1350800096988678, -0.09468207508325577, 0.24013587832450867, -0.09682435542345047, 0.0025023818016052246, -0.09786031395196915, -0.26396819949150085, -0.09670191258192062, 0.2691854238510132, -0.10246685892343521, -0.2019493579864502, -0.09904996305704117, 0.3894067406654358, ...] +``` + +
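+The extern/mock split described in step 3 maps directly onto `torch.package`'s
+`extern` and `mock` rules. The sketch below is illustrative only: it uses the raw
+`torch.package.PackageExporter` API rather than `PredictFactoryPackager`, and the
+module lists and output path are placeholders to adapt for your own predict module.
+
+```
+from torch.package import PackageExporter
+
+with PackageExporter("/tmp/my_package.zip") as pe:
+    # extern: resolved from the serving environment's site-packages at load time
+    pe.extern(["numpy", "fbgemm_gpu", "fbgemm_gpu.**"])
+    # mock: importable but inert stubs for modules that are never executed at inference time
+    pe.mock(["matplotlib", "matplotlib.**"])
+    # record a placeholder resource; a real packager would save_pickle its PredictFactory here
+    pe.save_text("config", "notes.txt", "extern/mock rules recorded for the predict module")
+```
+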
+ +## Planned work + +- Provide benchmarks for torch.deploy vs. TorchScript across CPU, single-GPU, and multi-GPU inference +- In-code documentation +- Simplify the installation process + +
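+Before digging into the potential issues listed below, it can save time to confirm
+that the archive produced in step 3 loads in plain Python. The snippet is a sketch:
+the `"predict_factory"` / `"predict_factory.pkl"` names are assumptions, so use
+whatever package and resource names your packager passed to `save_pickle` when it
+wrote the archive.
+
+```
+from torch.package import PackageImporter
+
+importer = PackageImporter("/tmp/model_package.zip")
+# Package/resource names below are assumptions; match them to your packager.
+factory = importer.load_pickle("predict_factory", "predict_factory.pkl")
+print(type(factory))
+```
+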
+ +## Potential issues and solutions + +Skip this section if you had no issues with installation or running the example. + +**Missing header files during pytorch installation** + +If your environment is missing a speicfic set of header files such as `nvml.h` and `cuda_profiler_api.h`, the pytorch installation will fail with error messages similar to the code snippet below: + +``` +~/nvml_lib.h:13:10: fatal error: nvml.h: No such file or directory + #include + ^~~~~~~~ +compilation terminated. +[80/2643] Building CXX object third_party/ideep/mkl-dnn/third_party/oneDNN/src/cpu/CMakeFiles/dnnl_cpu.dir/cpu_convolution_list.cpp.o +ninja: build stopped: subcommand failed. +``` + +To get these header files, install `cudatoolkit-dev`: +``` +conda install -c conda-forge cudatoolkit-dev +``` + +Re-run the installation after this. + +**libdouble-conversion missing** +``` +~/torchrec/torchrec/inference/build$ ./example +./example: error while loading shared libraries: libdouble-conversion.so.3: cannot open shared object file: No such file or directory +``` + +If this issue persists even after adding double-conversion's path to $LD_LIBRARY_PATH (step 2) then solve by creating a symlink to `libdouble-conversion.so.3` with folly's installation of double-conversion: + +``` +sudo ln -s ~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib/libdouble-conversion.so.3.1.4 \ +libdouble-conversion.so.3 +``` + +**Two installations of glog** +``` +~/torchrec/torchrec/inference/build$ ./example +ERROR: flag 'logtostderr' was defined more than once (in files '/home/shabab/glog/src/logging.cc' and +'/home/shabab/folly-build/extracted/glog-v0.4.0.tar.gz/glog-0.4.0/src/logging.cc'). +``` +The above issue, along with a host of others during building, can potentially occur if libinference is pointing to two different versions of glog (if one was +previously installed in your system). You can find this out by running `ldd` on your libinference shared object within the build path. The issue can be solved by using the glog version provided by folly. + +To use the glog version provided by folly, add the glog install path (in your folly-build directory) to your LD_LIBRARY_PATH much like step 2. + +**Undefined symbols with std::string or cxx11** + +If you get undefined symbol errors and the errors mention `std::string` or `cxx11`, it's likely +that your dependencies were compiled with different ABI values. Re-compile your dependencies +and ensure they all have the same value for `_GLIBCXX_USE_CXX11_ABI` in their build. + +The ABI value of pytorch can be checked via: + +``` +import torch +torch._C._GLIBCXX_USE_CXX11_ABI +``` diff --git a/torchrec/inference/inference_legacy/__init__.py b/torchrec/inference/inference_legacy/__init__.py new file mode 100644 index 000000000..0f433eb47 --- /dev/null +++ b/torchrec/inference/inference_legacy/__init__.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-ignore-all-errors[0, 21] + +"""Torchrec Inference + +Torchrec inference provides a Torch.Deploy based library for GPU inference. + +These includes: + - Model packaging in Python + - `PredictModule` and `PredictFactory` are the contracts between the Python model authoring and the C++ model serving. 
+ - `PredictFactoryPackager` can be used to package a PredictFactory class using torch.package. + - Model serving in C++ + - `BatchingQueue` is a generalized config-based request tensor batching implementation. + - `GPUExecutor` handles the forward call into the inference model inside Torch.Deploy. + +We implemented an example of how to use this library with the TorchRec DLRM model. + - `examples/dlrm/inference/dlrm_packager.py`: this demonstrates how to export the DLRM model as a torch.package. + - `examples/dlrm/inference/dlrm_predict.py`: this shows how to use `PredictModule` and `PredictFactory` based on an existing model. +""" + +from . import model_packager # noqa diff --git a/torchrec/inference/inference_legacy/client.py b/torchrec/inference/inference_legacy/client.py new file mode 100644 index 000000000..a3a9f2a83 --- /dev/null +++ b/torchrec/inference/inference_legacy/client.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# @nolint +# pyre-ignore-all-errors + + +import argparse +import logging + +import grpc +import torch +from gen.torchrec.inference import predictor_pb2, predictor_pb2_grpc +from torch.utils.data import DataLoader +from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES +from torchrec.datasets.random import RandomRecDataset +from torchrec.datasets.utils import Batch + + +def create_training_batch(args: argparse.Namespace) -> Batch: + return next( + iter( + DataLoader( + RandomRecDataset( + keys=DEFAULT_CAT_NAMES, + batch_size=args.batch_size, + hash_size=args.num_embedding_features, + ids_per_feature=1, + num_dense=len(DEFAULT_INT_NAMES), + ), + batch_sampler=None, + pin_memory=False, + num_workers=0, + ) + ) + ) + + +def create_request( + batch: Batch, args: argparse.Namespace +) -> predictor_pb2.PredictionRequest: + def to_bytes(tensor: torch.Tensor) -> bytes: + return tensor.cpu().numpy().tobytes() + + float_features = predictor_pb2.FloatFeatures( + num_features=args.num_float_features, + values=to_bytes(batch.dense_features), + ) + + id_list_features = predictor_pb2.SparseFeatures( + num_features=args.num_id_list_features, + values=to_bytes(batch.sparse_features.values()), + lengths=to_bytes(batch.sparse_features.lengths()), + ) + + id_score_list_features = predictor_pb2.SparseFeatures(num_features=0) + embedding_features = predictor_pb2.FloatFeatures(num_features=0) + unary_features = predictor_pb2.SparseFeatures(num_features=0) + + return predictor_pb2.PredictionRequest( + batch_size=args.batch_size, + float_features=float_features, + id_list_features=id_list_features, + id_score_list_features=id_score_list_features, + embedding_features=embedding_features, + unary_features=unary_features, + ) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--ip", + type=str, + default="0.0.0.0", + ) + parser.add_argument( + "--port", + type=int, + default=50051, + ) + parser.add_argument( + "--num_float_features", + type=int, + default=13, + ) + parser.add_argument( + "--num_id_list_features", + type=int, + default=26, + ) + parser.add_argument( + "--num_id_score_list_features", + type=int, + default=0, + ) + parser.add_argument( + "--num_embedding_features", + type=int, + default=100000, + ) + parser.add_argument( + "--embedding_feature_dim", + type=int, + 
default=100, + ) + parser.add_argument( + "--batch_size", + type=int, + default=100, + ) + + args: argparse.Namespace = parser.parse_args() + + training_batch: Batch = create_training_batch(args) + request: predictor_pb2.PredictionRequest = create_request(training_batch, args) + + with grpc.insecure_channel(f"{args.ip}:{args.port}") as channel: + stub = predictor_pb2_grpc.PredictorStub(channel) + response = stub.Predict(request) + print("Response: ", response.predictions["default"].data) + +if __name__ == "__main__": + logging.basicConfig() diff --git a/torchrec/inference/docker/Dockerfile b/torchrec/inference/inference_legacy/docker/Dockerfile similarity index 100% rename from torchrec/inference/docker/Dockerfile rename to torchrec/inference/inference_legacy/docker/Dockerfile diff --git a/torchrec/inference/docs/inference.rst b/torchrec/inference/inference_legacy/docs/inference.rst similarity index 100% rename from torchrec/inference/docs/inference.rst rename to torchrec/inference/inference_legacy/docs/inference.rst diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Assert.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Assert.h new file mode 100644 index 000000000..26e7987f9 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Assert.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#define TORCHREC_INTERNAL_ASSERT_WITH_MESSAGE(condition, message) \ + if (!(condition)) { \ + throw std::runtime_error( \ + "Internal Assertion failed: (" + std::string(#condition) + "), " + \ + "function " + __FUNCTION__ + ", file " + __FILE__ + ", line " + \ + std::to_string(__LINE__) + ".\n" + \ + "Please report bug to TorchRec.\n" + message + "\n"); \ + } + +#define TORCHREC_INTERNAL_ASSERT_NO_MESSAGE(condition) \ + TORCHREC_INTERNAL_ASSERT_WITH_MESSAGE(#condition, "") + +#define TORCHREC_INTERNAL_ASSERT_(x, condition, message, FUNC, ...) FUNC + +#define TORCHREC_INTERNAL_ASSERT(...) \ + TORCHREC_INTERNAL_ASSERT_( \ + , \ + ##__VA_ARGS__, \ + TORCHREC_INTERNAL_ASSERT_WITH_MESSAGE(__VA_ARGS__), \ + TORCHREC_INTERNAL_ASSERT_NO_MESSAGE(__VA_ARGS__)); + +#define TORCHREC_CHECK_WITH_MESSAGE(condition, message) \ + if (!(condition)) { \ + throw std::runtime_error( \ + "Check failed: (" + std::string(#condition) + "), " + "function " + \ + __FUNCTION__ + ", file " + __FILE__ + ", line " + \ + std::to_string(__LINE__) + ".\n" + message + "\n"); \ + } + +#define TORCHREC_CHECK_NO_MESSAGE(condition) \ + TORCHREC_CHECK_WITH_MESSAGE(#condition, "") + +#define TORCHREC_CHECK_(x, condition, message, FUNC, ...) FUNC + +#define TORCHREC_CHECK(...) \ + TORCHREC_CHECK_( \ + , \ + ##__VA_ARGS__, \ + TORCHREC_CHECK_WITH_MESSAGE(__VA_ARGS__), \ + TORCHREC_CHECK_NO_MESSAGE(__VA_ARGS__)); diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Batching.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Batching.h new file mode 100644 index 000000000..2c5e8837c --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Batching.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +#include "torchrec/inference/JaggedTensor.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { + +using LazyTensorRef = folly::detail::Lazy>&; + +// BatchingFunc should be responsible to move the output tensor to desired +// location using the device input. +class BatchingFunc { + public: + virtual ~BatchingFunc() = default; + + virtual std::unordered_map batch( + const std::string& /* featureName */, + const std::vector>& /* requests */, + const int64_t& /* totalNumBatch */, + LazyTensorRef /* batchOffsets */, + const c10::Device& /* device */, + LazyTensorRef /* batchItems */) = 0; +}; + +/** + * TorchRecBatchingFuncRegistry is used to register custom batching functions. + */ +C10_DECLARE_REGISTRY(TorchRecBatchingFuncRegistry, BatchingFunc); + +#define REGISTER_TORCHREC_BATCHING_FUNC_WITH_PIORITY(name, priority, ...) \ + C10_REGISTER_CLASS_WITH_PRIORITY( \ + TorchRecBatchingFuncRegistry, name, priority, __VA_ARGS__); + +#define REGISTER_TORCHREC_BATCHING_FUNC(name, ...) \ + REGISTER_TORCHREC_BATCHING_FUNC_WITH_PIORITY( \ + name, c10::REGISTRY_DEFAULT, __VA_ARGS__); + +std::unordered_map combineFloat( + const std::string& featureName, + const std::vector>& requests); + +std::unordered_map combineSparse( + const std::string& featureName, + const std::vector>& requests, + bool isWeighted); + +std::unordered_map combineEmbedding( + const std::string& featureName, + const std::vector>& requests); + +void moveIValueToDevice(c10::IValue& val, const c10::Device& device); + +std::unordered_map moveToDevice( + std::unordered_map combined, + const c10::Device& device); + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/BatchingQueue.h b/torchrec/inference/inference_legacy/include/torchrec/inference/BatchingQueue.h new file mode 100644 index 000000000..f012572b2 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/BatchingQueue.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include // @manual +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "torchrec/inference/Batching.h" +#include "torchrec/inference/Observer.h" +#include "torchrec/inference/ResourceManager.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { + +using BatchQueueCb = std::function)>; + +class BatchingQueue { + public: + struct Config { + std::chrono::milliseconds batchingInterval = std::chrono::milliseconds(10); + std::chrono::milliseconds queueTimeout = std::chrono::milliseconds(500); + int numExceptionThreads = 4; + int numMemPinnerThreads = 4; + int maxBatchSize = 2000; + // For feature name to BatchingFunc name. 
+ const std::unordered_map batchingMetadata; + std::function eventCreationFn; + std::function warmupFn; + }; + + BatchingQueue(const BatchingQueue&) = delete; + BatchingQueue& operator=(const BatchingQueue&) = delete; + + BatchingQueue( + std::vector cbs, + const Config& config, + int worldSize, + std::unique_ptr observer, + std::shared_ptr resourceManager = nullptr); + ~BatchingQueue(); + + void add( + std::shared_ptr request, + folly::Promise> promise); + + void stop(); + + private: + struct QueryQueueEntry { + std::shared_ptr request; + RequestContext context; + std::chrono::time_point addedTime; + }; + + struct BatchingQueueEntry { + std::vector> requests; + std::vector contexts; + std::chrono::time_point addedTime; + }; + + void createBatch(); + + void pinMemory(int gpuIdx); + + void observeBatchCompletion(size_t batchSizeBytes, size_t numRequests); + + const Config config_; + + // Batching func name to batching func instance. + std::unordered_map> batchingFuncs_; + std::vector cbs_; + std::thread batchingThread_; + std::vector memPinnerThreads_; + std::unique_ptr rejectionExecutor_; + folly::Synchronized> requestQueue_; + std::vector>> + batchingQueues_; + std::atomic stopping_; + int worldSize_; + std::unique_ptr observer_; + std::shared_ptr resourceManager_; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Exception.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Exception.h new file mode 100644 index 000000000..5667d1ad2 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Exception.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include + +namespace torchrec { + +// We have different error code defined for different kinds of exceptions in +// fblearner/sigrid predictor. (Code pointer: +// fblearner/predictor/if/prediction_service.thrift.) We define different +// exception type here so that in fblearner/sigrid predictor we can detect the +// exception type and return the corresponding error code to reflect the right +// info. 
+class TorchrecException : public std::runtime_error { + public: + explicit TorchrecException(const std::string& error) + : std::runtime_error(error) {} +}; + +// GPUOverloadException maps to +// PredictionExceptionCode::GPU_BATCHING_QUEUE_TIMEOUT +class GPUOverloadException : public TorchrecException { + public: + explicit GPUOverloadException(const std::string& error) + : TorchrecException(error) {} +}; + +// GPUExecutorOverloadException maps to +// PredictionExceptionCode::GPU_EXECUTOR_QUEUE_TIMEOUT +class GPUExecutorOverloadException : public TorchrecException { + public: + explicit GPUExecutorOverloadException(const std::string& error) + : TorchrecException(error) {} +}; + +// TorchDeployException maps to +// PredictorUserErrorCode::TORCH_DEPLOY_ERROR +class TorchDeployException : public TorchrecException { + public: + explicit TorchDeployException(const std::string& error) + : TorchrecException(error) {} +}; +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h b/torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h new file mode 100644 index 000000000..491acf48f --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include + +#include "torchrec/inference/Exception.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { +template +void handleRequestException( + folly::Promise>& promise, + const std::string& msg) { + auto ex = folly::make_exception_wrapper(msg); + auto response = std::make_unique(); + response->exception = std::move(ex); + promise.setValue(std::move(response)); +} + +template +void handleBatchException( + std::vector& contexts, + const std::string& msg) { + for (auto& context : contexts) { + handleRequestException(context.promise, msg); + } +} + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/GPUExecutor.h b/torchrec/inference/inference_legacy/include/torchrec/inference/GPUExecutor.h new file mode 100644 index 000000000..00c93668b --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/GPUExecutor.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// remove this after we switch over to multipy externally for torchrec +#ifdef FBCODE_CAFFE2 +#include // @manual +#else +#include // @manual +#endif + +#include "torchrec/inference/BatchingQueue.h" +#include "torchrec/inference/Observer.h" +#include "torchrec/inference/ResultSplit.h" +#include "torchrec/inference/include/torchrec/inference/Observer.h" + +namespace torchrec { + +class GPUExecutor { + public: + // Used to interface with python's garbage collector + struct GCConfig { + bool optimizationEnabled = false; + size_t collectionFreq = 1000; + size_t statReportingFreq = 10000; + std::unique_ptr observer = + std::make_unique(); + std::map threadIdToNumForwards = std::map(); + }; + + GPUExecutor( + std::shared_ptr manager, + torch::deploy::ReplicatedObj model, + size_t rank, + size_t worldSize, + std::shared_ptr func, + std::chrono::milliseconds queueTimeout, + std::shared_ptr + observer, // shared_ptr because used in completion executor callback + std::function warmupFn = {}, + std::optional numThreadsPerGPU = std::nullopt, + std::unique_ptr gcConfig = std::make_unique()); + GPUExecutor(GPUExecutor&& executor) noexcept = default; + GPUExecutor& operator=(GPUExecutor&& executor) noexcept = default; + ~GPUExecutor(); + + void callback(std::shared_ptr batch); + + void process(int idx); + + private: + // torch deploy + std::shared_ptr manager_; + torch::deploy::ReplicatedObj model_; + const size_t rank_; + const size_t worldSize_; + + folly::MPMCQueue> batches_; + std::vector processThreads_; + std::unique_ptr rejectionExecutor_; + std::unique_ptr completionExecutor_; + std::shared_ptr resultSplitFunc_; + const std::chrono::milliseconds queueTimeout_; + std::shared_ptr observer_; + std::function warmupFn_; + + std::mutex warmUpMutex_; + std::mutex warmUpAcquireSessionMutex_; + std::condition_variable warmUpCV_; + int warmUpCounter_{0}; + + size_t numThreadsPerGPU_; + + std::unique_ptr gcConfig_; + + void reportGCStats(c10::IValue stats); +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/JaggedTensor.h b/torchrec/inference/inference_legacy/include/torchrec/inference/JaggedTensor.h new file mode 100644 index 000000000..de00aebac --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/JaggedTensor.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include + +namespace torchrec { + +struct JaggedTensor { + at::Tensor lengths; + at::Tensor values; + at::Tensor weights; +}; + +struct KeyedJaggedTensor { + std::vector keys; + at::Tensor lengths; + at::Tensor values; + at::Tensor weights; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Observer.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Observer.h new file mode 100644 index 000000000..14ac3bceb --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Observer.h @@ -0,0 +1,280 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace torchrec { + +// Record generic timeseries stat with a key +class IDynamicTimeseriesObserver { + public: + virtual void addCount(uint32_t value, std::string key) = 0; + + virtual ~IDynamicTimeseriesObserver() {} +}; + +// Can be used for testing or for opting out of observation. +class EmptyDynamicTimeseriesObserver : public IDynamicTimeseriesObserver { + public: + void addCount(uint32_t /* value */, std::string /* key */) override {} +}; + +class IBatchingQueueObserver { + public: + // Record the amount of time an entry of PredictionRequests + // in the batching queue waits before they are read and allocated + // onto a GPU device. + virtual void recordBatchingQueueLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the amount of time it takes for a batching function + // to execute. + virtual void recordBatchingFuncLatency( + uint32_t value, + std::string batchingFuncName, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the amount of time it takes to create a batch of + // requests. + virtual void recordBatchCreationLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Increment the number of batching queue timeouts experienced. + virtual void addBatchingQueueTimeoutCount(uint32_t value) = 0; + + // Increment the number of times a GPU could not be chosen + // for allocation. + virtual void addGPUBusyCount(uint32_t value) = 0; + + // Increment the number of requests entering the batching queue. + virtual void addRequestsCount(uint32_t value) = 0; + + // Increment the number of bytes of tensors moved to cuda. + virtual void addBytesMovedToGPUCount(uint32_t value) = 0; + + // Increment the number of batches processed by the batching + // queue (moved onto the GPU executor). + virtual void addBatchesProcessedCount(uint32_t value) = 0; + + // Increment the number of requests processed by the batching + // queue (moved onto the GPU executor). + virtual void addRequestsProcessedCount(uint32_t value) = 0; + + // The observations that should be made when a batch is completed. + virtual void observeBatchCompletion( + size_t batchSizeBytes, + size_t numRequests) { + addBytesMovedToGPUCount(batchSizeBytes); + addBatchesProcessedCount(1); + addRequestsProcessedCount(numRequests); + } + + virtual ~IBatchingQueueObserver() {} +}; + +// Can be used for testing or for opting out of observation. 
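// The no-op EmptyBatchingQueueObserver below is the opt-out implementation.
// Whichever implementation is supplied, BatchingQueue drives it as sketched
// here (both calls appear verbatim in BatchingQueue.cpp later in this diff;
// observer_ and batch are BatchingQueue's own names, not part of this header):
//
//   observer_->addBatchingQueueTimeoutCount(1);   // an entry waited too long
//   observer_->observeBatchCompletion(batch->size(), batch->batchSize);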
+class EmptyBatchingQueueObserver : public IBatchingQueueObserver { + public: + void recordBatchingQueueLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordBatchingFuncLatency( + uint32_t /* value */, + std::string /* batchingFuncName */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordBatchCreationLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void addBatchingQueueTimeoutCount(uint32_t /* value */) override {} + + void addGPUBusyCount(uint32_t /* value */) override {} + + void addRequestsCount(uint32_t /* value */) override {} + + void addBytesMovedToGPUCount(uint32_t /* value */) override {} + + void addBatchesProcessedCount(uint32_t /* value */) override {} + + void addRequestsProcessedCount(uint32_t /* value */) override {} +}; + +class IGPUExecutorObserver { + public: + // Record the amount of time a batch spends in the GPU Executor + // queue. + virtual void recordQueueLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the latency of prediction (forward call, H2D). + virtual void recordPredictionLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the latency of device to host transfer facilitated + // by result split function. + virtual void recordDeviceToHostLatency( + uint32_t value, + std::string resultSplitFuncName, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the latency of splitting the result. + virtual void recordResultSplitLatency( + uint32_t value, + std::string resultSplitFuncName, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the latency from enqueue to completion. + virtual void recordTotalLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Increment the number of GPUExecutor queue timeouts. + virtual void addQueueTimeoutCount(uint32_t value) = 0; + + // Increment the number of predict exceptions. + virtual void addPredictionExceptionCount(uint32_t value) = 0; + + // Increment the number of batches successfully processed. + virtual void addBatchesProcessedCount(uint32_t value) = 0; + + virtual ~IGPUExecutorObserver() {} +}; + +class ISingleGPUExecutorObserver { + public: + virtual void addRequestsCount(uint32_t value) = 0; + virtual void addRequestProcessingExceptionCount(uint32_t value) = 0; + virtual void recordQueueLatency( + uint32_t value, + std::chrono::steady_clock::time_point = + std::chrono::steady_clock::now()) = 0; + + virtual void recordRequestProcessingLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + virtual ~ISingleGPUExecutorObserver() = default; +}; + +class EmptySingleGPUExecutorObserver : public ISingleGPUExecutorObserver { + void addRequestsCount(uint32_t) override {} + void addRequestProcessingExceptionCount(uint32_t) override {} + void recordQueueLatency( + uint32_t, + std::chrono::steady_clock::time_point = + std::chrono::steady_clock::now()) override {} + + void recordRequestProcessingLatency( + uint32_t, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) override {} +}; + +// Can be used for testing or for opt-ing out of observation. 
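// (EmptyGPUExecutorObserver below plays the same no-op role for
// IGPUExecutorObserver.) When real metrics are wanted, a concrete observer
// can forward each hook to the IDynamicTimeseriesObserver interface declared
// at the top of this header. A minimal sketch for the single-GPU interface
// (illustrative only, not part of this header; assumes <memory> and
// <utility> for std::shared_ptr and std::move):
//
//   class TimeseriesSingleGPUExecutorObserver
//       : public ISingleGPUExecutorObserver {
//    public:
//     explicit TimeseriesSingleGPUExecutorObserver(
//         std::shared_ptr<IDynamicTimeseriesObserver> sink)
//         : sink_(std::move(sink)) {}
//     void addRequestsCount(uint32_t value) override {
//       sink_->addCount(value, "requests");
//     }
//     void addRequestProcessingExceptionCount(uint32_t value) override {
//       sink_->addCount(value, "request_processing_exceptions");
//     }
//     void recordQueueLatency(
//         uint32_t value, std::chrono::steady_clock::time_point) override {
//       sink_->addCount(value, "queue_latency_ms");
//     }
//     void recordRequestProcessingLatency(
//         uint32_t value, std::chrono::steady_clock::time_point) override {
//       sink_->addCount(value, "request_processing_latency_ms");
//     }
//
//    private:
//     std::shared_ptr<IDynamicTimeseriesObserver> sink_;
//   };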
+class EmptyGPUExecutorObserver : public IGPUExecutorObserver { + public: + void recordQueueLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordPredictionLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordDeviceToHostLatency( + uint32_t /* value */, + std::string /* resultSplitFuncName */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordResultSplitLatency( + uint32_t /* value */, + std::string /* resultSplitFuncName */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordTotalLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void addQueueTimeoutCount(uint32_t /* value */) override {} + + void addPredictionExceptionCount(uint32_t /* value */) override {} + + void addBatchesProcessedCount(uint32_t /* value */) override {} +}; + +class IResourceManagerObserver { + public: + // Add the number of requests in flight for a gpu + virtual void addOutstandingRequestsCount(uint32_t value, int gpuIdx) = 0; + + // Add the most in flight requests on a gpu ever + virtual void addAllTimeHighOutstandingCount(uint32_t value, int gpuIdx) = 0; + + // Record the latency for finding a device + virtual void addWaitingForDeviceLatency( + uint32_t value, + int gpuIdx, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Recording all stats related to resource manager at once. + virtual void recordAllStats( + uint32_t outstandingRequests, + uint32_t allTimeHighOutstanding, + uint32_t waitedForMs, + int gpuIdx) { + addOutstandingRequestsCount(outstandingRequests, gpuIdx); + addAllTimeHighOutstandingCount(allTimeHighOutstanding, gpuIdx); + addWaitingForDeviceLatency(waitedForMs, gpuIdx); + } + + virtual ~IResourceManagerObserver() {} +}; + +// Can be used for testing or for opt-ing out of observation. +class EmptyResourceManagerObserver : public IResourceManagerObserver { + public: + void addOutstandingRequestsCount(uint32_t /* value */, int /* gpuIdx */) + override {} + + void addAllTimeHighOutstandingCount(uint32_t /* value */, int /* gpuIdx */) + override {} + + void addWaitingForDeviceLatency( + uint32_t /* value */, + int /* gpuIdx */, + std::chrono::steady_clock::time_point /* now */) override {} +}; + +// Helper for determining how much time has elapsed in milliseconds since a +// given time point. +inline std::chrono::milliseconds getTimeElapsedMS( + std::chrono::steady_clock::time_point startTime) { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - startTime); +} + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/ResourceManager.h b/torchrec/inference/inference_legacy/include/torchrec/inference/ResourceManager.h new file mode 100644 index 000000000..d3dd1ea18 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/ResourceManager.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "torchrec/inference/Observer.h" + +namespace torchrec { + +/** + * ResourceManager can be used to limit in-flight batches + * allocated onto GPUs to prevent OOMing. 
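 *
 * Typical usage, as an illustrative sketch (the slack is however much of the
 * request deadline the caller is still willing to spend waiting for a GPU):
 *
 *   if (resourceManager->occupyDevice(gpuIdx, std::chrono::milliseconds(50))) {
 *     auto guard =
 *         std::make_unique<ResourceManagerGuard>(resourceManager, gpuIdx);
 *     // ... run the batch; ~ResourceManagerGuard() calls release(gpuIdx) ...
 *   } else {
 *     // slack exhausted, so the batch should be timed out
 *   }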
+ */ +class ResourceManager { + public: + ResourceManager( + int worldSize, + size_t maxOutstandingBatches, + int logFrequency = 100, + std::unique_ptr observer = + std::make_unique()); + + // Returns whether batches can be allocated onto a device based on + // slack provided (ms) and maxOutstandingBatches_). + bool occupyDevice(int gpuIdx, std::chrono::milliseconds slack); + + void release(int gpuIdx); + + private: + folly::small_vector gpuToOutstandingBatches_; + // Helpful for tuning + folly::small_vector allTimeHigh_; + const size_t maxOutstandingBatches_; + const int logFrequency_; + // Align as 64B to avoid false sharing + alignas(64) std::mutex mu_; + std::unique_ptr observer_; +}; + +class ResourceManagerGuard { + public: + ResourceManagerGuard( + std::weak_ptr resourceManager, + int gpuIdx) + : resourceManager_(std::move(resourceManager)), gpuIdx_(gpuIdx) {} + + ~ResourceManagerGuard() { + std::shared_ptr rm = resourceManager_.lock(); + if (rm != nullptr) { + rm->release(gpuIdx_); + } + } + + private: + std::weak_ptr resourceManager_; + const int gpuIdx_; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/ResultSplit.h b/torchrec/inference/inference_legacy/include/torchrec/inference/ResultSplit.h new file mode 100644 index 000000000..2c3ef2463 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/ResultSplit.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace torchrec { + +class ResultSplitFunc { + public: + virtual ~ResultSplitFunc() = default; + + virtual std::string name() = 0; + + virtual c10::IValue splitResult( + c10::IValue /* result */, + size_t /* nOffset */, + size_t /* nLength */, + size_t /* nTotalLength */) = 0; + + virtual c10::IValue moveToHost(c10::IValue /* result */) = 0; +}; + +/** + * TorchRecResultSplitFuncRegistry is used to register custom result split + * functions. + */ +C10_DECLARE_REGISTRY(TorchRecResultSplitFuncRegistry, ResultSplitFunc); + +#define REGISTER_TORCHREC_RESULTSPLIT_FUNC(name, ...) \ + C10_REGISTER_CLASS(TorchRecResultSplitFuncRegistry, name, __VA_ARGS__); + +c10::IValue splitDictOfTensor( + c10::IValue result, + size_t nOffset, + size_t nLength, + size_t nTotalLength); + +c10::IValue splitDictOfTensors( + c10::IValue result, + size_t nOffset, + size_t nLength, + size_t nTotalLength); + +c10::IValue +splitDictWithMaskTensor(c10::IValue result, size_t nOffset, size_t nLength); + +class DictWithMaskTensorResultSplitFunc : public torchrec::ResultSplitFunc { + public: + virtual std::string name() override; + + virtual c10::IValue splitResult( + c10::IValue result, + size_t offset, + size_t length, + size_t /* nTotalLength */) override; + + c10::IValue moveToHost(c10::IValue result) override; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/ShardedTensor.h b/torchrec/inference/inference_legacy/include/torchrec/inference/ShardedTensor.h new file mode 100644 index 000000000..478773768 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/ShardedTensor.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace torchrec { + +struct ShardMetadata { + std::vector shard_offsets; + std::vector shard_lengths; + + bool operator==(const ShardMetadata& other) const { + return shard_offsets == other.shard_offsets && + shard_lengths == other.shard_lengths; + } +}; + +struct Shard { + ShardMetadata metadata; + at::Tensor tensor; +}; + +struct ShardedTensorMetadata { + std::vector shards_metadata; +}; + +struct ShardedTensor { + std::vector sizes; + std::vector local_shards; + ShardedTensorMetadata metadata; +}; + +struct ReplicatedTensor { + ShardedTensor local_replica; + int64_t local_replica_id; + int64_t replica_count; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/SingleGPUExecutor.h b/torchrec/inference/inference_legacy/include/torchrec/inference/SingleGPUExecutor.h new file mode 100644 index 000000000..9da63d7c2 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/SingleGPUExecutor.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include "torchrec/inference/Observer.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { + +class SingleGPUExecutor { + constexpr static const size_t kQUEUE_CAPACITY = 10000; + + public: + struct ExecInfo { + size_t gpuIdx; + size_t interpIdx; + torch::deploy::ReplicatedObj model; + }; + using ExecInfos = std::vector; + + SingleGPUExecutor( + std::shared_ptr manager, + ExecInfos execInfos, + size_t numGpu, + std::shared_ptr observer = + std::make_shared(), + c10::Device resultDevice = c10::kCPU, + size_t numProcessThreads = 1u, + bool useHighPriCudaStream = false); + + // Moveable only + SingleGPUExecutor(SingleGPUExecutor&& executor) noexcept = default; + SingleGPUExecutor& operator=(SingleGPUExecutor&& executor) noexcept = default; + ~SingleGPUExecutor(); + + void schedule(std::shared_ptr request); + + private: + void process(); + + std::shared_ptr manager_; + const ExecInfos execInfos_; + const size_t numGpu_; + const size_t numProcessThreads_; + const bool useHighPriCudaStream_; + const c10::Device resultDevice_; + std::shared_ptr observer_; + folly::MPMCQueue> requests_; + + std::unique_ptr processExecutor_; + std::unique_ptr completionExecutor_; + std::atomic roundRobinExecInfoNextIdx_; +}; +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/TestUtils.h b/torchrec/inference/inference_legacy/include/torchrec/inference/TestUtils.h new file mode 100644 index 000000000..1150e2c23 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/TestUtils.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +#include + +#include "torchrec/inference/JaggedTensor.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { + +std::shared_ptr createRequest(at::Tensor denseTensor); + +std::shared_ptr +createRequest(size_t batchSize, size_t numFeatures, const JaggedTensor& jagged); + +std::shared_ptr +createRequest(size_t batchSize, size_t numFeatures, at::Tensor embedding); + +JaggedTensor createJaggedTensor(const std::vector>& input); + +c10::List createIValueList( + const std::vector>& input); + +at::Tensor createEmbeddingTensor( + const std::vector>& input); + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Types.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Types.h new file mode 100644 index 000000000..0a09d1f35 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Types.h @@ -0,0 +1,145 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "torchrec/inference/ResourceManager.h" + +namespace torchrec { + +struct SparseFeatures { + uint32_t num_features; + // int32: T x B + folly::IOBuf lengths; + // T x B x L (jagged) + folly::IOBuf values; + // float16 + folly::IOBuf weights; +}; + +struct FloatFeatures { + uint32_t num_features; + // shape: {B} + folly::IOBuf values; +}; + +// TODO: Change the input format to torch::IValue. +// Currently only dense batching function support IValue. +using Feature = std::variant; + +struct PredictionRequest { + uint32_t batch_size; + std::unordered_map features; +}; + +struct PredictionResponse { + uint32_t batchSize; + c10::IValue predictions; + // If set, the result is an exception. + std::optional exception; +}; + +struct RequestContext { + uint32_t batchSize; + folly::Promise> promise; + // folly request context for request tracking in crochet + std::shared_ptr follyRequestContext; +}; + +using PredictionException = std::runtime_error; + +using Event = std:: + unique_ptr>; + +struct BatchingMetadata { + std::string type; + std::string device; + folly::F14FastSet pinned; +}; + +// noncopyable because we only want to move PredictionBatch around +// as it holds a reference to ResourceManagerGuard. We wouldn't want +// to inadvertently increase the reference count to ResourceManagerGuard +// with copies of this struct. +struct PredictionBatch : public boost::noncopyable { + std::string methodName; + std::vector args; + + size_t batchSize; + + c10::impl::GenericDict forwardArgs; + + std::vector contexts; + + std::unique_ptr resourceManagerGuard = nullptr; + + std::chrono::time_point enqueueTime = + std::chrono::steady_clock::now(); + + Event event; + + // Need a constructor to use make_shared/unique with + // noncopyable struct and not trigger copy-constructor. 
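// For example, the batching code builds batches roughly like this
// (illustrative sketch; combinedBatchSize, forwardArgs and contexts are the
// queue's locals, and the trailing ResourceManagerGuard argument can be
// omitted because it defaults to nullptr):
//
//   auto batch = std::make_shared<PredictionBatch>(
//       combinedBatchSize, std::move(forwardArgs), std::move(contexts));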
+ PredictionBatch( + size_t bs, + c10::impl::GenericDict fa, + std::vector ctxs, + std::unique_ptr rmg = nullptr) + : batchSize(bs), + forwardArgs(std::move(fa)), + contexts(std::move(ctxs)), + resourceManagerGuard(std::move(rmg)) {} + + PredictionBatch( + std::string methodNameArg, + std::vector argsArg, + folly::Promise> promise) + : methodName(std::move(methodNameArg)), + args(std::move(argsArg)), + forwardArgs( + c10::impl::GenericDict(at::StringType::get(), at::AnyType::get())) { + contexts.push_back(RequestContext{1u, std::move(promise)}); + } + + size_t sizeOfIValue(const c10::IValue& val) const { + size_t size = 0; + if (val.isTensor()) { + size += val.toTensor().storage().nbytes(); + } else if (val.isList()) { + for (const auto& v : val.toListRef()) { + size += sizeOfIValue(v); + } + } + return size; + } + + inline size_t size() const { + size_t size = 0; + for (const auto& iter : forwardArgs) { + size += sizeOfIValue(iter.value()); + } + return size; + } +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Validation.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Validation.h new file mode 100644 index 000000000..74a2f20ff --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Validation.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include "torchrec/inference/Types.h" + +namespace torchrec { + +// Returns whether sparse features (KeyedJaggedTensor) are valid. +// Currently validates: +// 1. Whether sum(lengths) == size(values) +// 2. Whether there are negative values in lengths +// 3. If weights is present, whether sum(lengths) == size(weights) +bool validateSparseFeatures( + at::Tensor& values, + at::Tensor& lengths, + std::optional maybeWeights = std::nullopt); + +// Returns whether dense features are valid. +// Currently validates: +// 1. Whether the size of values is divisable by batch size (request level) +bool validateDenseFeatures(at::Tensor& values, size_t batchSize); + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/model_packager.py b/torchrec/inference/inference_legacy/model_packager.py new file mode 100644 index 000000000..9957b3373 --- /dev/null +++ b/torchrec/inference/inference_legacy/model_packager.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc +from pathlib import Path +from typing import Any, BinaryIO, Dict, List, Type, TypeVar, Union + +import torch +from torch.package import PackageExporter +from torchrec.inference.modules import PredictFactory + +LOADER_MODULE = "__module_loader" +LOADER_FACTORY = "MODULE_FACTORY" +LOADER_CODE = f""" +import %PACKAGE% + +{LOADER_FACTORY}=%PACKAGE%.%CLASS% +""" +CONFIG_MODULE = "__configs" + +T = TypeVar("T") + + +try: + # pyre-fixme[21]: Could not find module `torch_package_importer`. 
+ import torch_package_importer # @manual +except ImportError: + pass + + +def load_config_text(name: str) -> str: + return torch_package_importer.load_text("__configs", name) + + +def load_pickle_config(name: str, clazz: Type[T]) -> T: + loaded_obj = torch_package_importer.load_pickle("__configs", name) + assert isinstance( + loaded_obj, clazz + ), f"The loaded config {type(loaded_obj)} is not of type {clazz}" + return loaded_obj + + +class PredictFactoryPackager: + @classmethod + @abc.abstractclassmethod + def set_extern_modules(cls, pe: PackageExporter) -> None: + pass + + @classmethod + @abc.abstractclassmethod + def set_mocked_modules(cls, pe: PackageExporter) -> None: + pass + + @classmethod + def save_predict_factory( + cls, + predict_factory: Type[PredictFactory], + configs: Dict[str, Any], + output: Union[str, Path, BinaryIO], + extra_files: Dict[str, Union[str, bytes]], + loader_code: str = LOADER_CODE, + package_importer: Union[ + torch.package.Importer, List[torch.package.Importer] + ] = torch.package.sys_importer, + ) -> None: + with PackageExporter(output, importer=package_importer) as pe: + # pyre-fixme[29]: `BoundMethod[abc.abstractclassmethod[None], + # Type[PredictFactoryPackager]]` is not a function. + cls.set_extern_modules(pe) + # pyre-fixme[29]: `BoundMethod[abc.abstractclassmethod[None], + # Type[PredictFactoryPackager]]` is not a function. + cls.set_mocked_modules(pe) + pe.extern(["sys"]) + pe.intern("**") + for k, v in extra_files.items(): + if isinstance(v, str): + pe.save_text("extra_files", k, v) + elif isinstance(v, bytes): + pe.save_binary("extra_files", k, v) + else: + raise ValueError(f"Unsupported type {type(v)}") + cls._save_predict_factory( + pe, predict_factory, configs, loader_code=loader_code + ) + + @classmethod + def _save_predict_factory( + cls, + pe: PackageExporter, + predict_factory: Type[PredictFactory], + configs: Dict[str, Any], + loader_code: str = LOADER_CODE, + ) -> None: + # If predict_factory is coming from a torch package, + # __module__ would have prefix. + # To save such predict factory, we need to remove + # the prefix. + package_name = predict_factory.__module__ + if package_name.startswith(" predictions = 1; +} + +// The predictor service definition. Synchronous for now. +service Predictor { + rpc Predict(PredictionRequest) returns (PredictionResponse) {} +} diff --git a/torchrec/inference/inference_legacy/server.cpp b/torchrec/inference/inference_legacy/server.cpp new file mode 100644 index 000000000..f7695cdcf --- /dev/null +++ b/torchrec/inference/inference_legacy/server.cpp @@ -0,0 +1,336 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +// remove this after we switch over to multipy externally for torchrec +#ifdef FBCODE_CAFFE2 +#include // @manual +#include +#else +#include +#include +#endif + +#include + +#include "torchrec/inference/GPUExecutor.h" +#include "torchrec/inference/predictor.grpc.pb.h" +#include "torchrec/inference/predictor.pb.h" + +using grpc::Channel; +using grpc::ClientContext; +using grpc::Status; +using predictor::FloatVec; +using predictor::PredictionRequest; +using predictor::PredictionResponse; +using predictor::Predictor; + +DEFINE_int32(n_interp_per_gpu, 1, ""); +DEFINE_int32(n_gpu, 1, ""); +DEFINE_string(package_path, "", ""); + +DEFINE_int32(batching_interval, 10, ""); +DEFINE_int32(queue_timeout, 500, ""); + +DEFINE_int32(num_exception_threads, 4, ""); +DEFINE_int32(num_mem_pinner_threads, 4, ""); +DEFINE_int32(max_batch_size, 2048, ""); +DEFINE_int32(gpu_executor_queue_timeout, 50, ""); + +DEFINE_string(server_address, "0.0.0.0", ""); +DEFINE_string(server_port, "50051", ""); + +DEFINE_string( + python_packages_path, + "", + "Used to load the packages that you 'extern' with torch.package"); + +namespace { + +std::unique_ptr toTorchRecRequest( + const PredictionRequest* request) { + auto torchRecRequest = std::make_unique(); + torchRecRequest->batch_size = request->batch_size(); + + // Client sends a request with serialized tensor to bytes. + // Byte string is converted to folly::iobuf for torchrec request. + + { + torchrec::FloatFeatures floatFeature; + + auto feature = request->float_features(); + auto encoded_values = feature.values(); + + floatFeature.num_features = feature.num_features(); + floatFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + + torchRecRequest->features["float_features"] = std::move(floatFeature); + } + + { + torchrec::SparseFeatures sparseFeature; + + auto feature = request->id_list_features(); + auto encoded_values = feature.values(); + auto encoded_lengths = feature.lengths(); + + sparseFeature.num_features = feature.num_features(); + sparseFeature.lengths = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_lengths.data(), + encoded_lengths.size()}; + sparseFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + + torchRecRequest->features["id_list_features"] = std::move(sparseFeature); + } + + { + torchrec::SparseFeatures sparseFeature; + + auto feature = request->id_score_list_features(); + auto encoded_values = feature.values(); + auto encoded_lengths = feature.lengths(); + auto encoded_weights = feature.weights(); + + sparseFeature.num_features = feature.num_features(); + sparseFeature.lengths = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_lengths.data(), + encoded_lengths.size()}; + sparseFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + sparseFeature.weights = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_weights.data(), + encoded_weights.size()}; + + torchRecRequest->features["id_score_list_features"] = + std::move(sparseFeature); + } + + { + torchrec::FloatFeatures floatFeature; + + auto feature = request->embedding_features(); + auto encoded_values = feature.values(); + + floatFeature.num_features = feature.num_features(); + floatFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + 
+ torchRecRequest->features["embedding_features"] = std::move(floatFeature); + } + + { + torchrec::SparseFeatures sparseFeature; + + auto feature = request->unary_features(); + auto encoded_lengths = feature.lengths(); + auto encoded_values = feature.values(); + + sparseFeature.num_features = feature.num_features(); + sparseFeature.lengths = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_lengths.data(), + encoded_lengths.size()}; + sparseFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + + torchRecRequest->features["unary_features"] = std::move(sparseFeature); + } + + return torchRecRequest; +} + +// Logic behind the server's behavior. +class PredictorServiceHandler final : public Predictor::Service { + public: + explicit PredictorServiceHandler(torchrec::BatchingQueue& queue) + : queue_(queue) {} + + Status Predict( + grpc::ServerContext* context, + const PredictionRequest* request, + PredictionResponse* reply) override { + folly::Promise> promise; + auto future = promise.getSemiFuture(); + queue_.add(toTorchRecRequest(request), std::move(promise)); + auto torchRecResponse = + std::move(future).get(); // blocking, TODO: Write async server + auto predictions = reply->mutable_predictions(); + + // Convert ivalue to map, TODO: find out if protobuf + // can support custom types (folly::iobuf), so we can avoid this overhead. + for (const auto& item : torchRecResponse->predictions.toGenericDict()) { + auto tensor = item.value().toTensor(); + FloatVec fv; + fv.mutable_data()->Add( + tensor.data_ptr(), tensor.data_ptr() + tensor.numel()); + (*predictions)[item.key().toStringRef()] = fv; + } + + return Status::OK; + } + + private: + torchrec::BatchingQueue& queue_; +}; + +} // namespace + +int main(int argc, char* argv[]) { + google::InitGoogleLogging(argv[0]); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + LOG(INFO) << "Creating GPU executors"; + + // store the executors and interpreter managers + std::vector> executors; + std::vector models; + std::vector batchQueueCbs; + std::unordered_map batchingMetadataMap; + + std::shared_ptr env = + std::make_shared( + FLAGS_python_packages_path); + + auto manager = std::make_shared( + FLAGS_n_gpu * FLAGS_n_interp_per_gpu, env); + { + torch::deploy::Package package = manager->loadPackage(FLAGS_package_path); + auto I = package.acquireSession(); + auto imported = I.self.attr("import_module")({"__module_loader"}); + auto factoryType = imported.attr("MODULE_FACTORY"); + auto factory = factoryType.attr("__new__")({factoryType}); + factoryType.attr("__init__")({factory}); + + // Process forward metadata. + try { + auto batchingMetadataJsonStr = + factory.attr("batching_metadata_json")(at::ArrayRef()) + .toIValue() + .toString() + ->string(); + auto dynamic = folly::parseJson(batchingMetadataJsonStr); + CHECK(dynamic.isObject()); + for (auto it : dynamic.items()) { + torchrec::BatchingMetadata metadata; + metadata.type = it.second["type"].asString(); + metadata.device = it.second["device"].asString(); + batchingMetadataMap[it.first.asString()] = std::move(metadata); + } + } catch (...) { + auto batchingMetadata = + factory.attr("batching_metadata")(at::ArrayRef()) + .toIValue(); + for (const auto& iter : batchingMetadata.toGenericDict()) { + torchrec::BatchingMetadata metadata; + metadata.type = iter.value().toStringRef(); + metadata.device = "cuda"; + batchingMetadataMap[iter.key().toStringRef()] = std::move(metadata); + } + } + + // Process result metadata. 
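// result_metadata() is expected to return the name under which a
// ResultSplitFunc was registered, and that name is used for the registry
// lookup below. For example, "dict_of_tensor" resolves to
// DictOfTensorResultSplitFunc via the registration in ResultSplit.cpp:
//
//   REGISTER_TORCHREC_RESULTSPLIT_FUNC(dict_of_tensor, DictOfTensorResultSplitFunc);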
+ auto resultMetadata = + factory.attr("result_metadata")(at::ArrayRef()) + .toIValue() + .toStringRef(); + std::shared_ptr resultSplitFunc = + torchrec::TorchRecResultSplitFuncRegistry()->Create(resultMetadata); + + LOG(INFO) << "Creating Model Shard for " << FLAGS_n_gpu << " GPUs."; + auto dmp = factory.attr("create_predict_module") + .callKwargs({{"world_size", FLAGS_n_gpu}}); + + for (int rank = 0; rank < FLAGS_n_gpu; rank++) { + auto device = I.self.attr("import_module")({"torch"}).attr("device")( + {"cuda", rank}); + auto m = dmp.attr("copy")({device.toIValue()}); + models.push_back(I.createMovable(m)); + } + + for (int rank = 0; rank < FLAGS_n_gpu; rank++) { + auto executor = std::make_unique( + manager, + std::move(models[rank]), + rank, + FLAGS_n_gpu, + resultSplitFunc, + std::chrono::milliseconds(FLAGS_gpu_executor_queue_timeout)); + executors.push_back(std::move(executor)); + batchQueueCbs.push_back( + [&, rank](std::shared_ptr batch) { + executors[rank]->callback(std::move(batch)); + }); + } + } + + torchrec::BatchingQueue queue( + batchQueueCbs, + torchrec::BatchingQueue::Config{ + .batchingInterval = + std::chrono::milliseconds(FLAGS_batching_interval), + .queueTimeout = std::chrono::milliseconds(FLAGS_queue_timeout), + .numExceptionThreads = FLAGS_num_exception_threads, + .numMemPinnerThreads = FLAGS_num_mem_pinner_threads, + .maxBatchSize = FLAGS_max_batch_size, + .batchingMetadata = std::move(batchingMetadataMap), + }, + FLAGS_n_gpu); + + // create the server + std::string server_address(FLAGS_server_address + ":" + FLAGS_server_port); + auto service = PredictorServiceHandler(queue); + + grpc::EnableDefaultHealthCheckService(true); + grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + grpc::ServerBuilder builder; + + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + + // Register "service" as the instance through which we'll communicate with + // clients. In this case it corresponds to an *synchronous* service. + builder.RegisterService(&service); + + // Finally assemble the server. + std::unique_ptr server(builder.BuildAndStart()); + LOG(INFO) << "Server listening on " << server_address; + + // Wait for the server to shutdown. Note that some other thread must be + // responsible for shutting down the server for this call to ever return. + server->Wait(); + + LOG(INFO) << "Shutting down server"; + return 0; +} diff --git a/torchrec/inference/src/Batching.cpp b/torchrec/inference/inference_legacy/src/Batching.cpp similarity index 70% rename from torchrec/inference/src/Batching.cpp rename to torchrec/inference/inference_legacy/src/Batching.cpp index 6fe2d4888..667539080 100644 --- a/torchrec/inference/src/Batching.cpp +++ b/torchrec/inference/inference_legacy/src/Batching.cpp @@ -6,7 +6,7 @@ * LICENSE file in the root directory of this source tree. 
*/ -#include "torchrec/inference/Batching.h" +#include "torchrec/inference/Batching.h" // @manual #include #include @@ -14,24 +14,40 @@ #include #include "ATen/Functions.h" +#include "ATen/core/List.h" +#include "ATen/core/ivalue.h" #include "torchrec/inference/Types.h" namespace torchrec { -std::unordered_map moveToDevice( - std::unordered_map combined, +void moveIValueToDevice(c10::IValue& val, const c10::Device& device) { + if (val.isTensor()) { + if (val.toTensor().device() != device) { + val = val.toTensor().to(device, /* non_blocking */ true); + } + } else if (val.isList()) { + for (auto v : val.toListRef()) { + moveIValueToDevice(v, device); + } + } else { + LOG(WARNING) + << "moveIValueToDevice only supports types c10::List and at::Tensor but received type " + << val.type().get()->repr_str(); + } +} + +std::unordered_map moveToDevice( + std::unordered_map combined, const c10::Device& device) { for (auto& [k, v] : combined) { - if (v.device() != device) { - v = v.to(device, /* non_blocking */ true); - } + moveIValueToDevice(v, device); } return combined; } C10_DEFINE_REGISTRY(TorchRecBatchingFuncRegistry, BatchingFunc); -std::unordered_map combineFloat( +std::unordered_map combineFloat( const std::string& featureName, const std::vector>& requests) { // Compute combined batch size. @@ -107,7 +123,7 @@ std::unordered_map combineFloat( return {{featureName, std::move(combined)}}; } -std::unordered_map combineSparse( +std::unordered_map combineSparse( const std::string& featureName, const std::vector>& requests, bool isWeighted) { @@ -201,7 +217,7 @@ std::unordered_map combineSparse( } } - std::unordered_map ret = { + std::unordered_map ret = { {featureName + ".values", std::move(values)}, {featureName + ".lengths", std::move(lengths)}, }; @@ -211,68 +227,120 @@ std::unordered_map combineSparse( return ret; } -std::unordered_map combineEmbedding( +std::unordered_map combineEmbedding( const std::string& featureName, const std::vector>& requests) { // Compute combined batch size. 
long combinedBatchSize = 0; long numFeatures = 0; long dimension = 0; + + // If input is IValue then we expect a List[Tensor] of length numFeatures + // Each element of this list is a batch of features with size (batchSize x + // dimension) + auto* maybeIValuePtr = + std::get_if(&requests.front()->features[featureName]); + for (const auto& request : requests) { - const auto& features = - std::get(request->features[featureName]); - const auto nf = features.num_features; - const auto dataSize = features.values.computeChainDataLength(); - if (nf != 0 && dimension == 0) { - dimension = dataSize / request->batch_size / sizeof(float) / nf; - } - if (nf * request->batch_size * dimension * sizeof(float) != dataSize) { - throw std::invalid_argument("Invalid embedding features"); - } - if (nf > 0) { - combinedBatchSize += request->batch_size; - if (numFeatures > 0) { - if (numFeatures != nf) { - throw std::invalid_argument("Different number of embedding features"); - } + if (maybeIValuePtr != nullptr) { + auto ival = std::get(request->features[featureName]) + .toTensorVector(); + auto nf = ival.size(); + if (nf == 0) { + continue; + } + if (numFeatures > 0 && nf > 0 && numFeatures != nf) { + throw std::invalid_argument("Different number of embedding features"); } numFeatures = nf; + combinedBatchSize += ival.at(0).size(0); + } else { + const auto& features = + std::get(request->features[featureName]); + const auto nf = features.num_features; + const auto dataSize = features.values.computeChainDataLength(); + if (nf != 0 && dimension == 0) { + dimension = dataSize / request->batch_size / sizeof(float) / nf; + } + if (nf * request->batch_size * dimension * sizeof(float) != dataSize) { + throw std::invalid_argument("Invalid embedding features"); + } + if (nf > 0) { + combinedBatchSize += request->batch_size; + if (numFeatures > 0) { + if (numFeatures != nf) { + throw std::invalid_argument( + "Different number of embedding features"); + } + } + numFeatures = nf; + } } } if (numFeatures == 0) { - return {{featureName, at::empty(0)}}; + return {{featureName, c10::List()}}; + } + + std::vector cursors; + if (maybeIValuePtr != nullptr) { + std::vector> featureBatches(numFeatures); + for (const auto& request : requests) { + auto ival = std::get(request->features[featureName]) + .toTensorVector(); + if (ival.size() == 0) { + continue; + } + for (int i = 0; i < numFeatures; ++i) { + auto featureBatch = ival.at(i); + if (featureBatch.dim() == 1) { + featureBatch = featureBatch.unsqueeze(1); + } + featureBatches.at(i).push_back(featureBatch); + } + } + + c10::List retList; + for (const auto& fb : featureBatches) { + retList.push_back(at::cat(fb)); + } + return {{featureName, std::move(retList)}}; + } + + for (const auto& request : requests) { + const auto& features = + std::get(request->features[featureName]); + cursors.emplace_back(&features.values); } // Create output tensor. const auto options = at::TensorOptions(at::kCPU).dtype(at::kFloat).pinned_memory(true); auto combined = - at::empty({numFeatures, combinedBatchSize, dimension}, options); + at::empty({combinedBatchSize, numFeatures, dimension}, options); // Copy tensor data. 
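// In the updated copy loop below, each request contributes one contiguous
// chunk of batch_size * numFeatures * dimension floats, appended request
// after request into the [combinedBatchSize, numFeatures, dimension] buffer
// allocated above; the transpose(0, 1).split(1) that follows then re-slices
// the buffer into one [combinedBatchSize, dimension] tensor per feature.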
auto combinedRange = folly::MutableByteRange( reinterpret_cast(combined.data_ptr()), combined.storage().nbytes()); - std::vector cursors; - for (const auto& request : requests) { - const auto& features = - std::get(request->features[featureName]); - cursors.emplace_back(&features.values); + + for (const auto&& it : folly::enumerate(cursors)) { + auto len = requests[it.index]->batch_size * dimension * numFeatures * + sizeof(float); + it.element.pull(combinedRange.data(), len); + combinedRange.advance(len); } - for (int i = 0; i < numFeatures; ++i) { - for (const auto&& it : folly::enumerate(cursors)) { - auto len = requests[it.index]->batch_size * dimension * sizeof(float); - it.element.pull(combinedRange.data(), len); - combinedRange.advance(len); - } + + auto listFeatureBatches = c10::List(); + for (auto& tensor : combined.transpose(0, 1).split(1)) { + listFeatureBatches.push_back(tensor.squeeze(0)); } - return {{featureName, std::move(combined)}}; + return {{featureName, std::move(listFeatureBatches)}}; } class FloatBatchingFunc : public BatchingFunc { public: - std::unordered_map batch( + std::unordered_map batch( const std::string& featureName, const std::vector>& requests, const int64_t& /* totalNumBatch */, @@ -285,7 +353,7 @@ class FloatBatchingFunc : public BatchingFunc { class SparseBatchingFunc : public BatchingFunc { public: - std::unordered_map batch( + std::unordered_map batch( const std::string& featureName, const std::vector>& requests, const int64_t& /* totalNumBatch */, @@ -299,7 +367,7 @@ class SparseBatchingFunc : public BatchingFunc { class WeightedSparseBatchingFunc : public BatchingFunc { public: - std::unordered_map batch( + std::unordered_map batch( const std::string& featureName, const std::vector>& requests, const int64_t& /* totalNumBatch */, @@ -313,7 +381,7 @@ class WeightedSparseBatchingFunc : public BatchingFunc { class EmbeddingBatchingFunc : public BatchingFunc { public: - std::unordered_map batch( + std::unordered_map batch( const std::string& featureName, const std::vector>& requests, const int64_t& /* totalNumBatch */, diff --git a/torchrec/inference/src/BatchingQueue.cpp b/torchrec/inference/inference_legacy/src/BatchingQueue.cpp similarity index 89% rename from torchrec/inference/src/BatchingQueue.cpp rename to torchrec/inference/inference_legacy/src/BatchingQueue.cpp index 953ef49e1..a26dcadb3 100644 --- a/torchrec/inference/src/BatchingQueue.cpp +++ b/torchrec/inference/inference_legacy/src/BatchingQueue.cpp @@ -6,7 +6,7 @@ * LICENSE file in the root directory of this source tree. 
*/ -#include "torchrec/inference/BatchingQueue.h" +#include "torchrec/inference/BatchingQueue.h" // @manual #include #include @@ -34,7 +34,7 @@ #include #include -#include "torchrec/inference/Exception.h" +#include "torchrec/inference/ExceptionHandler.h" #include "torchrec/inference/Observer.h" #include "torchrec/inference/ResourceManager.h" #include "torchrec/inference/Types.h" @@ -101,7 +101,10 @@ void BatchingQueue::add( const auto batchSize = request->batch_size; queue.push(QueryQueueEntry{ std::move(request), - RequestContext{batchSize, std::move(promise)}, + RequestContext{ + batchSize, + std::move(promise), + folly::RequestContext::saveContext()}, addedTime}); }); } @@ -133,7 +136,8 @@ void BatchingQueue::createBatch() { observer_->addBatchingQueueTimeoutCount(1); rejectionExecutor_->add( [promise = std::move(front.context.promise)]() mutable { - handleRequestException(promise, "Batching queue timeout"); + handleRequestException( + promise, "Batching queue timeout"); }); queue.pop(); continue; @@ -150,6 +154,7 @@ void BatchingQueue::createBatch() { } auto& context = contexts.emplace_back(std::move(front.context)); + folly::RequestContext::setContext(context.follyRequestContext); requests.push_back(std::move(front.request)); batchSize += requests.back()->batch_size; queue.pop(); @@ -178,6 +183,8 @@ void BatchingQueue::createBatch() { contexts.clear(); } + folly::RequestContext::setContext(nullptr); + if (!full) { /* sleep override */ std::this_thread::sleep_for(std::chrono::milliseconds(1)); @@ -207,6 +214,9 @@ void BatchingQueue::pinMemory(int gpuIdx) { if (!requests.empty() || !contexts.empty()) { RECORD_USER_SCOPE("PinMemory"); + if (!contexts.empty()) { + folly::RequestContext::setContext(contexts[0].follyRequestContext); + } // Combine data. size_t combinedBatchSize = 0; for (auto i : c10::irange(requests.size())) { @@ -245,9 +255,10 @@ void BatchingQueue::pinMemory(int gpuIdx) { return batchItems; }); - c10::Dict forwardArgs; + c10::impl::GenericDict forwardArgs( + at::StringType::get(), at::AnyType::get()); auto combineForwardArgs = - [&](std::unordered_map map) { + [&](std::unordered_map map) { for (auto& [key, value] : map) { CHECK(!forwardArgs.contains(key)); forwardArgs.insert(key, std::move(value)); @@ -269,7 +280,7 @@ void BatchingQueue::pinMemory(int gpuIdx) { // A device could not be chosen in time. Time out. observer_->addGPUBusyCount(1); rejectionExecutor_->add([ctxs = std::move(contexts)]() mutable { - handleBatchException( + handleBatchException( ctxs, "All GPUs are busy. 
Batching queue timeout."); }); continue; @@ -323,9 +334,18 @@ void BatchingQueue::pinMemory(int gpuIdx) { observer_->observeBatchCompletion(batch->size(), batch->batchSize); cbs_[gpuIdx](batch); + + // unset request tracking + folly::RequestContext::setContext(nullptr); } } catch (const std::exception& ex) { - LOG(FATAL) << "Error batching requests, ex: " << folly::exceptionStr(ex); + LOG(ERROR) << "Error batching requests, ex: " << folly::exceptionStr(ex); + for (auto& ctx : contexts) { + rejectionExecutor_->add([promise = std::move(ctx.promise)]() mutable { + handleRequestException( + promise, "Error during batching requests"); + }); + } } } } diff --git a/torchrec/inference/inference_legacy/src/Executer2.cpp b/torchrec/inference/inference_legacy/src/Executer2.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/inference/src/GPUExecutor.cpp b/torchrec/inference/inference_legacy/src/GPUExecutor.cpp similarity index 60% rename from torchrec/inference/src/GPUExecutor.cpp rename to torchrec/inference/inference_legacy/src/GPUExecutor.cpp index e7e5531f6..8178ed3f0 100644 --- a/torchrec/inference/src/GPUExecutor.cpp +++ b/torchrec/inference/inference_legacy/src/GPUExecutor.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -20,10 +21,11 @@ #include #include #include +#include #include #include #include -#include +#include // @manual // remove this after we switch over to multipy externally for torchrec #ifdef FBCODE_CAFFE2 @@ -34,7 +36,7 @@ #include "ATen/cuda/CUDAEvent.h" #include "torchrec/inference/BatchingQueue.h" -#include "torchrec/inference/Exception.h" +#include "torchrec/inference/ExceptionHandler.h" #include "torchrec/inference/Observer.h" #include "torchrec/inference/Types.h" @@ -75,7 +77,8 @@ GPUExecutor::GPUExecutor( std::chrono::milliseconds queueTimeout, std::shared_ptr observer, std::function warmupFn, - c10::optional numThreadsPerGPU) + std::optional numThreadsPerGPU, + std::unique_ptr gcConfig) : manager_(manager), model_(std::move(model)), rank_(rank), @@ -88,24 +91,52 @@ GPUExecutor::GPUExecutor( numThreadsPerGPU_( numThreadsPerGPU.has_value() ? *numThreadsPerGPU - : manager_->allInstances().size() / worldSize_) { + : manager_->allInstances().size() / worldSize_), + gcConfig_(std::move(gcConfig)) { CHECK(observer_ != nullptr); + CHECK(gcConfig_ != nullptr); + at::cuda::CUDAGuard guard(rank_); rejectionExecutor_ = std::make_unique(2 * numThreadsPerGPU_); for (int i = 0; i < numThreadsPerGPU_; ++i) { + auto threadId = rank * numThreadsPerGPU_ + i; + LOG(INFO) << "Starting Thread " << i << " for Model Shard Rank " << rank_ - << ", as Global thread: " << rank * numThreadsPerGPU_ + i; - processThreads_.emplace_back( - [this, rank, threadsPerGPU = numThreadsPerGPU_, i] { - if (FLAGS_emit_nsys_nvtx) { - enable_nvtx_tracing(); - } - process(rank * threadsPerGPU + i); - }); + << ", as Global thread: " << threadId; + + if (gcConfig_->optimizationEnabled) { + gcConfig_->threadIdToNumForwards[threadId] = 0; + // Freeze all python objects in each interpreter + auto model = + model_.acquireSession(&manager_->allInstances().at(threadId)); + model.global("gc", "freeze")(at::ArrayRef()); + } + + processThreads_.emplace_back([this, threadId] { + if (FLAGS_emit_nsys_nvtx) { + enable_nvtx_tracing(); + } + process(threadId); + }); + } + + // Acquire sessionn in main thread for interpreter 0 to avoid deadlock in + // torch deploy. 
+ { + std::lock_guard lock(warmUpAcquireSessionMutex_); + LOG(INFO) + << " - pre-acquire deploy session of loading model: interpreter 0"; + auto start = std::chrono::steady_clock::now(); + model_.acquireSession(&manager_->allInstances().at(0)); + LOG(INFO) << " - finished pre-acquire deploy session, interpreter 0, by " + << getTimeElapsedMS(start).count() / 1000 << "s"; } + std::unique_lock lock(warmUpMutex_); + warmUpCV_.wait(lock, [&] { return warmUpCounter_ == numThreadsPerGPU_ - 1; }); + completionExecutor_ = std::make_unique(2 * numThreadsPerGPU_); } @@ -122,7 +153,8 @@ GPUExecutor::~GPUExecutor() { std::shared_ptr batch; while (batches_.readIfNotEmpty(batch)) { rejectionExecutor_->add([batch = std::move(batch)]() { - handleBatchException(batch->contexts, "Server shutdown"); + handleBatchException( + batch->contexts, "Server shutdown"); }); } } @@ -153,6 +185,21 @@ void GPUExecutor::process(int idx) { warmupFn_(); } + if (idx != 0) { + std::lock_guard lock(warmUpAcquireSessionMutex_); + LOG(INFO) << " - Pre-acquire deploy session of loading model, interpreter " + << idx; + auto start = std::chrono::steady_clock::now(); + model_.acquireSession(&manager_->allInstances().at(idx)); + { + std::lock_guard lock(warmUpMutex_); + warmUpCounter_++; + warmUpCV_.notify_one(); + } + LOG(INFO) << " - finished pre-acquire deploy session, interpreter " << idx + << ", by " << getTimeElapsedMS(start).count() / 1000 << "s"; + } + while (true) { std::shared_ptr batch; batches_.blockingRead(batch); @@ -165,13 +212,18 @@ void GPUExecutor::process(int idx) { continue; } + if (!batch->contexts.empty()) { + folly::RequestContext::setContext(batch->contexts[0].follyRequestContext); + } + auto timeInQueue = getTimeElapsedMS(batch->enqueueTime); observer_->recordQueueLatency(timeInQueue.count()); if (timeInQueue >= queueTimeout_) { observer_->addQueueTimeoutCount(1); rejectionExecutor_->add([batch = std::move(batch)]() { - handleBatchException(batch->contexts, "GPUExecutor queue timeout"); + handleBatchException( + batch->contexts, "GPUExecutor queue timeout"); }); continue; @@ -181,6 +233,11 @@ void GPUExecutor::process(int idx) { auto model = model_.acquireSession(&manager_->allInstances().at(idx)); at::IValue predictions; + LOG_EVERY_N(INFO, 10000) + << "GPU " << rank_ << " is running batch size " << batch->batchSize + << ", avg request size " << batch->batchSize / batch->contexts.size(); + + std::string exWhat = ""; try { RECORD_USER_SCOPE("Forward"); // Block current stream until H2D finishes. @@ -188,15 +245,39 @@ void GPUExecutor::process(int idx) { auto forwardStart = std::chrono::steady_clock::now(); + // Disable automatic garbage collection + if (gcConfig_->optimizationEnabled) { + model.global("gc", "disable")(at::ArrayRef()); + } + predictions = model.self.attr("__call__")({std::move(batch->forwardArgs)}) .toIValue(); + // Manually call Python's garbage collector + if (gcConfig_->optimizationEnabled) { + gcConfig_->threadIdToNumForwards[idx] += 1; + if (gcConfig_->threadIdToNumForwards[idx] % gcConfig_->collectionFreq == + 0) { + model.global("gc", "collect")(at::ArrayRef()); + } + // Report gc stats + if (gcConfig_->threadIdToNumForwards[idx] % + gcConfig_->statReportingFreq == + 0) { + reportGCStats( + model + .global("gc", "get_stats")(at::ArrayRef()) + .toIValue()); + } + } + observer_->recordPredictionLatency( getTimeElapsedMS(forwardStart).count()); } catch (const std::exception& ex) { - // The observer will record this in the completion executor. Don't observe - // twice. 
+ // The observer will record this in the completion executor. Don't + // observe twice. LOG_EVERY_N(ERROR, 100) << "Exception during predict, msg: " << ex.what(); + exWhat = ex.what(); } batch->event->record(); @@ -209,7 +290,8 @@ void GPUExecutor::process(int idx) { resultSplitFunc = resultSplitFunc_, rank = rank_, d2hStream = d2hStream, - observer = observer_.get()]() mutable { + observer = observer_.get(), + exWhat = exWhat]() mutable { RECORD_USER_SCOPE("CompletionStage"); c10::InferenceMode imGuard; @@ -230,13 +312,19 @@ void GPUExecutor::process(int idx) { getTimeElapsedMS(d2hStart).count(), resultSplitFunc->name()); } + constexpr std::string_view gpuExceptionContext = + "GPUExecutor prediction exception, "; if (predictions.isNone()) { observer->addPredictionExceptionCount(1); - rejectionExecutor_->add( - [contexts = std::move(batch->contexts)]() mutable { - handleBatchException( - contexts, "GPUExecutor prediction exception"); - }); + rejectionExecutor_->add([contexts = std::move(batch->contexts), + gpuExceptionContext = gpuExceptionContext, + exWhat = exWhat]() mutable { + handleBatchException( + contexts, + std::string( + gpuExceptionContext.begin(), gpuExceptionContext.end()) + + exWhat); + }); } else { size_t offset = 0; auto rsfStart = std::chrono::steady_clock::now(); @@ -258,6 +346,26 @@ void GPUExecutor::process(int idx) { observer->recordTotalLatency( getTimeElapsedMS(batch->enqueueTime).count()); }); + + // reset request tracking + folly::RequestContext::setContext(nullptr); + } +} + +void GPUExecutor::reportGCStats(c10::IValue stats) { + const auto generationsList = stats.toList(); + for (const auto generationId : c10::irange(generationsList.size())) { + const auto& collectionsDict = + generationsList.get(generationId).toGenericDict(); + for (auto& entry : collectionsDict) { + const auto& key = entry.key(); + const auto& value = entry.value(); + + auto stat_indicator = + "gc_gen_" + std::to_string(generationId) + "_" + key.toStringRef(); + LOG(INFO) << "GC stat indicator: " << stat_indicator; + gcConfig_->observer->addCount(value.toInt(), stat_indicator); + } } } diff --git a/torchrec/inference/src/ResourceManager.cpp b/torchrec/inference/inference_legacy/src/ResourceManager.cpp similarity index 81% rename from torchrec/inference/src/ResourceManager.cpp rename to torchrec/inference/inference_legacy/src/ResourceManager.cpp index 36d83322b..f95e419f2 100644 --- a/torchrec/inference/src/ResourceManager.cpp +++ b/torchrec/inference/inference_legacy/src/ResourceManager.cpp @@ -23,11 +23,15 @@ namespace torchrec { ResourceManager::ResourceManager( int worldSize, size_t maxOutstandingBatches, - int logFrequency) + int logFrequency, + std::unique_ptr observer) : gpuToOutstandingBatches_(worldSize), allTimeHigh_(worldSize), maxOutstandingBatches_(maxOutstandingBatches), - logFrequency_(logFrequency) {} + logFrequency_(logFrequency), + observer_(std::move(observer)) { + CHECK(observer_ != nullptr); +} bool ResourceManager::occupyDevice( int gpuIdx, @@ -41,8 +45,6 @@ bool ResourceManager::occupyDevice( // With lock, try to get device. std::lock_guard lock(mu_); if (gpuToOutstandingBatches_[gpuIdx] < maxOutstandingBatches_) { - // GPU has too many outstanding batches. Try again later. - // Pick GPU and update stats. LOG_EVERY_N(INFO, logFrequency_) << "Picked device " << gpuIdx << ", with load " @@ -50,11 +52,16 @@ bool ResourceManager::occupyDevice( << " -- gpuToOutstandingBatches_ list <" << folly::join(",", gpuToOutstandingBatches_) << ">. 
" << " -- all time highs: <" << folly::join(",", allTimeHigh_) - << ">. " - << "Waited: " << waitedFor.count() + << ">. " << "Waited: " << waitedFor.count() << " ms. Slack: " << slack.count() << " ms."; gpuToOutstandingBatches_[gpuIdx] += 1; + observer_->recordAllStats( + gpuToOutstandingBatches_[gpuIdx], + allTimeHigh_[gpuIdx], + waitedFor.count(), + gpuIdx); + if (gpuToOutstandingBatches_[gpuIdx] > allTimeHigh_[gpuIdx]) { allTimeHigh_[gpuIdx] = gpuToOutstandingBatches_[gpuIdx]; } @@ -78,6 +85,11 @@ bool ResourceManager::occupyDevice( waitedFor = std::chrono::duration_cast( std::chrono::steady_clock::now() - startTime); if (waitedFor >= slack) { + observer_->recordAllStats( + gpuToOutstandingBatches_[gpuIdx], + allTimeHigh_[gpuIdx], + waitedFor.count(), + gpuIdx); // We have used up all the slack -- requests should time out. LOG(WARNING) << "Timing out a batch of requests after slack of " << slack.count() << " ms was exceeded!"; @@ -89,6 +101,8 @@ bool ResourceManager::occupyDevice( void ResourceManager::release(int gpuIdx) { std::lock_guard lock(mu_); gpuToOutstandingBatches_[gpuIdx] -= 1; + observer_->addOutstandingRequestsCount( + gpuToOutstandingBatches_[gpuIdx], gpuIdx); } } // namespace torchrec diff --git a/torchrec/inference/src/ResultSplit.cpp b/torchrec/inference/inference_legacy/src/ResultSplit.cpp similarity index 81% rename from torchrec/inference/src/ResultSplit.cpp rename to torchrec/inference/inference_legacy/src/ResultSplit.cpp index 99e6cd5e7..d66d34111 100644 --- a/torchrec/inference/src/ResultSplit.cpp +++ b/torchrec/inference/inference_legacy/src/ResultSplit.cpp @@ -172,47 +172,6 @@ class DictOfTensorsResultSplitFunc : public ResultSplitFunc { } }; -class DictWithMaskTensorResultSplitFunc : public torchrec::ResultSplitFunc { - public: - std::string name() override { - return "dict_with_mask_tensor"; - } - - c10::IValue splitResult( - c10::IValue result, - size_t offset, - size_t length, - size_t /* nTotalLength */) override { - return splitDictWithMaskTensor(result, offset, length); - } - - c10::IValue moveToHost(c10::IValue result) { - const auto& dict = result.toGenericDict(); - c10::impl::GenericDict moved( - c10::StringType::get(), - c10::TupleType::create( - {c10::TensorType::get(), c10::TensorType::get()})); - moved.reserve(dict.size()); - - for (auto& entry : dict) { - const auto& key = entry.key(); - const auto& value = entry.value(); - TORCH_CHECK(value.isTuple()); - const auto tuple = value.toTuple(); - TORCH_CHECK(tuple->elements().size() == 2); - std::vector values; - values.reserve(2); - for (int i = 0; i < 2; ++i) { - const auto& tensor = tuple->elements()[i].toTensor(); - values.push_back( - tensor.to(at::Device(at::kCPU), /* non_blocking */ true)); - } - moved.insert(key, c10::ivalue::Tuple::create(std::move(values))); - } - return moved; - } -}; - REGISTER_TORCHREC_RESULTSPLIT_FUNC(dict_of_tensor, DictOfTensorResultSplitFunc); REGISTER_TORCHREC_RESULTSPLIT_FUNC( @@ -224,4 +183,42 @@ REGISTER_TORCHREC_RESULTSPLIT_FUNC( DictWithMaskTensorResultSplitFunc); } // namespace + +std::string DictWithMaskTensorResultSplitFunc::name() { + return "dict_with_mask_tensor"; +} + +c10::IValue DictWithMaskTensorResultSplitFunc::splitResult( + c10::IValue result, + size_t offset, + size_t length, + size_t /* nTotalLength */) { + return splitDictWithMaskTensor(result, offset, length); +} + +c10::IValue DictWithMaskTensorResultSplitFunc::moveToHost(c10::IValue result) { + const auto& dict = result.toGenericDict(); + c10::impl::GenericDict moved( + c10::StringType::get(), 
+ c10::TupleType::create({c10::TensorType::get(), c10::TensorType::get()})); + moved.reserve(dict.size()); + + for (auto& entry : dict) { + const auto& key = entry.key(); + const auto& value = entry.value(); + TORCH_CHECK(value.isTuple()); + const auto tuple = value.toTuple(); + TORCH_CHECK(tuple->elements().size() == 2); + std::vector values; + values.reserve(2); + for (int i = 0; i < 2; ++i) { + const auto& tensor = tuple->elements()[i].toTensor(); + values.push_back( + tensor.to(at::Device(at::kCPU), /* non_blocking */ true)); + } + moved.insert(key, c10::ivalue::Tuple::create(std::move(values))); + } + return moved; +} + } // namespace torchrec diff --git a/torchrec/inference/src/SingleGPUExecutor.cpp b/torchrec/inference/inference_legacy/src/SingleGPUExecutor.cpp similarity index 86% rename from torchrec/inference/src/SingleGPUExecutor.cpp rename to torchrec/inference/inference_legacy/src/SingleGPUExecutor.cpp index 1e14a0439..828d88b1a 100644 --- a/torchrec/inference/src/SingleGPUExecutor.cpp +++ b/torchrec/inference/inference_legacy/src/SingleGPUExecutor.cpp @@ -19,26 +19,36 @@ SingleGPUExecutor::SingleGPUExecutor( ExecInfos execInfos, size_t numGpu, std::shared_ptr observer, - c10::Device resultDevice) + c10::Device resultDevice, + size_t numProcessThreads, + bool useHighPriCudaStream) : manager_(manager), execInfos_(std::move(execInfos)), numGpu_(numGpu), + numProcessThreads_(numProcessThreads), + useHighPriCudaStream_(useHighPriCudaStream), resultDevice_(resultDevice), observer_(observer), requests_(kQUEUE_CAPACITY), + processExecutor_( + std::make_unique(numProcessThreads)), completionExecutor_( std::make_unique(execInfos_.size())), - roundRobinExecInfoNextIdx_(0u), - processThread_([&]() { process(); }) { - for (const auto& exec_info : execInfos_) { + roundRobinExecInfoNextIdx_(0u) { + for (size_t i = 0; i < numProcessThreads_; ++i) { + processExecutor_->add([&]() { process(); }); + } + for ([[maybe_unused]] const auto& exec_info : execInfos_) { TORCHREC_CHECK(exec_info.interpIdx < manager_->allInstances().size()); } TORCHREC_CHECK(observer_); } SingleGPUExecutor::~SingleGPUExecutor() { - requests_.blockingWrite(nullptr); - processThread_.join(); + for (size_t i = 0; i < numProcessThreads_; ++i) { + requests_.blockingWrite(nullptr); + } + processExecutor_->join(); completionExecutor_->join(); } @@ -96,7 +106,8 @@ void SingleGPUExecutor::process() { c10::InferenceMode inferenceModeGuard; std::vector streams; for (size_t i = 0; i < numGpu_; ++i) { - streams.push_back(at::cuda::getStreamFromPool(i)); + streams.push_back( + at::cuda::getStreamFromPool(useHighPriCudaStream_, i /* device */)); } at::cuda::CUDAMultiStreamGuard streamGuard(streams); @@ -165,6 +176,7 @@ void SingleGPUExecutor::process() { request->event->synchronize(); for (auto& context : request->contexts) { auto response = std::make_unique(); + response->batchSize = context.batchSize; response->predictions = result; context.promise.setValue(std::move(response)); } diff --git a/torchrec/inference/src/TestUtils.cpp b/torchrec/inference/inference_legacy/src/TestUtils.cpp similarity index 87% rename from torchrec/inference/src/TestUtils.cpp rename to torchrec/inference/inference_legacy/src/TestUtils.cpp index 6e574a46f..e57138622 100644 --- a/torchrec/inference/src/TestUtils.cpp +++ b/torchrec/inference/inference_legacy/src/TestUtils.cpp @@ -8,6 +8,7 @@ #include "torchrec/inference/TestUtils.h" +#include #include #include @@ -132,4 +133,19 @@ at::Tensor createEmbeddingTensor( return tensor; } +c10::List createIValueList( 
+ const std::vector>& input) { + // Input is batch x num_features + std::vector rows; + for (const auto& vec : input) { + rows.push_back(at::tensor(vec, at::TensorOptions().dtype(c10::kFloat))); + } + auto combined = at::stack(rows).transpose(0, 1); + c10::List retList; + for (auto& tensor : combined.split(1)) { + retList.push_back(tensor.squeeze(0)); + } + return retList; +} + } // namespace torchrec diff --git a/torchrec/inference/src/Validation.cpp b/torchrec/inference/inference_legacy/src/Validation.cpp similarity index 96% rename from torchrec/inference/src/Validation.cpp rename to torchrec/inference/inference_legacy/src/Validation.cpp index 2ad65726e..4ad3391e1 100644 --- a/torchrec/inference/src/Validation.cpp +++ b/torchrec/inference/inference_legacy/src/Validation.cpp @@ -14,7 +14,7 @@ namespace torchrec { bool validateSparseFeatures( at::Tensor& values, at::Tensor& lengths, - c10::optional maybeWeights) { + std::optional maybeWeights) { auto flatLengths = lengths.view(-1); // validate sum of lengths equals number of values/weights diff --git a/torchrec/inference/inference_legacy/state_dict_transform.py b/torchrec/inference/inference_legacy/state_dict_transform.py new file mode 100644 index 000000000..0379b1b80 --- /dev/null +++ b/torchrec/inference/inference_legacy/state_dict_transform.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Union + +import torch +from torch import distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor + + +def state_dict_gather( + src: Dict[str, Union[torch.Tensor, ShardedTensor]], + dst: Dict[str, torch.Tensor], +) -> None: + """ + Gathers the values of the src state_dict of the keys present in the dst state_dict. Can handle ShardedTensors in the src state_dict. + + Args: + src (Dict[str, Union[torch.Tensor, ShardedTensor]]): source's state_dict for this rank + dst (Dict[str, torch.Tensor]): destination's state_dict + """ + for key, dst_tensor in dst.items(): + src_tensor = src[key] + if isinstance(src_tensor, ShardedTensor): + src_tensor.gather(out=dst_tensor if (dist.get_rank() == 0) else None) + elif isinstance(src_tensor, torch.Tensor): + dst_tensor.copy_(src_tensor) + else: + raise ValueError(f"Unsupported tensor {key} type {type(src_tensor)}") + + +def state_dict_all_gather_keys( + state_dict: Dict[str, Union[torch.Tensor, ShardedTensor]], + pg: ProcessGroup, +) -> List[str]: + """ + Gathers all the keys of the state_dict from all ranks. Can handle ShardedTensors in the state_dict. + + Args: + state_dict (Dict[str, Union[torch.Tensor, ShardedTensor]]): keys of this state_dict will be gathered + pg (ProcessGroup): Process Group used for comms + """ + names = list(state_dict.keys()) + all_names = [None] * dist.get_world_size(pg) + dist.all_gather_object(all_names, names, pg) + deduped_names = set() + for local_names in all_names: + # pyre-ignore[16] + for name in local_names: + deduped_names.add(name) + return sorted(deduped_names) + + +def state_dict_to_device( + state_dict: Dict[str, Union[torch.Tensor, ShardedTensor]], + pg: ProcessGroup, + device: torch.device, +) -> Dict[str, Union[torch.Tensor, ShardedTensor]]: + """ + Moves a state_dict to a device with a process group. Can handle ShardedTensors in the state_dict. 
+ + Args: + state_dict (Dict[str, Union[torch.Tensor, ShardedTensor]]): state_dict to move + pg (ProcessGroup): Process Group used for comms + device (torch.device): device to put state_dict on + """ + ret = {} + all_keys = state_dict_all_gather_keys(state_dict, pg) + for key in all_keys: + if key in state_dict: + tensor = state_dict[key] + if isinstance(tensor, ShardedTensor): + copied_shards = [ + Shard.from_tensor_and_offsets( + tensor=shard.tensor.to(device), + shard_offsets=shard.metadata.shard_offsets, + rank=dist.get_rank(pg), + ) + for shard in tensor.local_shards() + ] + ret[key] = ShardedTensor._init_from_local_shards( + copied_shards, + tensor.metadata().size, + process_group=pg, + ) + elif isinstance(tensor, torch.Tensor): + ret[key] = tensor.to(device) + else: + raise ValueError(f"Unsupported tensor {key} type {type(tensor)}") + else: + # No state_dict entries for table-wise sharding, + # but need to follow full-sync. + ret[key] = ShardedTensor._init_from_local_shards( + [], + [], + process_group=pg, + ) + return ret diff --git a/torchrec/inference/tests/BatchingQueueTest.cpp b/torchrec/inference/inference_legacy/tests/BatchingQueueTest.cpp similarity index 84% rename from torchrec/inference/tests/BatchingQueueTest.cpp rename to torchrec/inference/inference_legacy/tests/BatchingQueueTest.cpp index aaa8e8fa0..86ff3bde1 100644 --- a/torchrec/inference/tests/BatchingQueueTest.cpp +++ b/torchrec/inference/inference_legacy/tests/BatchingQueueTest.cpp @@ -86,13 +86,18 @@ TEST(BatchingQueueTest, Basic) { std::this_thread::sleep_for(std::chrono::seconds(1)); } - ASSERT_EQ(2 * (2 + 4), value->forwardArgs.at("cuda_features").numel()); - ASSERT_EQ(value->forwardArgs.at("cuda_features").device().type(), at::kCUDA); - ASSERT_EQ(2 * (2 + 4), value->forwardArgs.at("cpu_features").numel()); - ASSERT_EQ(value->forwardArgs.at("cpu_features").device(), at::kCPU); + ASSERT_EQ( + 2 * (2 + 4), value->forwardArgs.at("cuda_features").toTensor().numel()); + ASSERT_EQ( + value->forwardArgs.at("cuda_features").toTensor().device().type(), + at::kCUDA); + ASSERT_EQ( + 2 * (2 + 4), value->forwardArgs.at("cpu_features").toTensor().numel()); + ASSERT_EQ( + value->forwardArgs.at("cpu_features").toTensor().device(), at::kCPU); ASSERT_TRUE(at::allclose( - value->forwardArgs.at("cuda_features").cpu(), - value->forwardArgs.at("cpu_features"))); + value->forwardArgs.at("cuda_features").toTensor().cpu(), + value->forwardArgs.at("cpu_features").toTensor())); } TEST(BatchingQueueTest, MaxBatchSize) { @@ -132,8 +137,9 @@ TEST(BatchingQueueTest, MaxBatchSize) { std::this_thread::sleep_for(std::chrono::seconds(1)); } - ASSERT_EQ(2 * 2, value->forwardArgs.at("cpu_features").numel()); - ASSERT_EQ(value->forwardArgs.at("cpu_features").device(), at::kCPU); + ASSERT_EQ(2 * 2, value->forwardArgs.at("cpu_features").toTensor().numel()); + ASSERT_EQ( + value->forwardArgs.at("cpu_features").toTensor().device(), at::kCPU); } } // namespace torchrec diff --git a/torchrec/inference/tests/BatchingTest.cpp b/torchrec/inference/inference_legacy/tests/BatchingTest.cpp similarity index 55% rename from torchrec/inference/tests/BatchingTest.cpp rename to torchrec/inference/inference_legacy/tests/BatchingTest.cpp index 701b6bf81..4402da977 100644 --- a/torchrec/inference/tests/BatchingTest.cpp +++ b/torchrec/inference/inference_legacy/tests/BatchingTest.cpp @@ -17,6 +17,8 @@ #include #include +#include "ATen/ops/tensor.h" +#include "torch/library.h" #include "torchrec/inference/TestUtils.h" #include "torchrec/inference/Types.h" @@ -44,24 
+46,45 @@ TEST(BatchingTest, SparseCombineTest) { auto batched = combineSparse("id_score_list_features", {request0, request1}, true); - checkTensor(batched["id_score_list_features.lengths"], {2, 0, 1, 1}); - checkTensor(batched["id_score_list_features.values"], {0, 1, 2, 3}); + checkTensor( + batched["id_score_list_features.lengths"].toTensor(), {2, 0, 1, 1}); + checkTensor( + batched["id_score_list_features.values"].toTensor(), {0, 1, 2, 3}); checkTensor( - batched["id_score_list_features.weights"], {1.0f, 1.0f, 1.0f, 1.0f}); + batched["id_score_list_features.weights"].toTensor(), + {1.0f, 1.0f, 1.0f, 1.0f}); } TEST(BatchingTest, EmbeddingCombineTest) { - const auto embedding0 = createEmbeddingTensor({{0, 1}, {2, 3}}); - const auto embedding1 = createEmbeddingTensor({{4, 5}}); + std::vector> raw_emb0 = {{0, 1}, {2, 3}}; + std::vector> raw_emb1 = {{4, 5}}; + + const auto embedding0 = createEmbeddingTensor(raw_emb0); + const auto embedding1 = createEmbeddingTensor(raw_emb1); auto request0 = createRequest(2, 2, embedding0); auto request1 = createRequest(1, 2, embedding1); + request0->features["ivalue_embedding_features"] = createIValueList(raw_emb0); + request1->features["ivalue_embedding_features"] = createIValueList(raw_emb1); + + auto batched_dict = + combineEmbedding("embedding_features", {request0, request1}); + auto batchedIValue = + combineEmbedding("ivalue_embedding_features", {request0, request1}); - auto batched = combineEmbedding("embedding_features", {request0, request1}); // num features, num batches, feature dimision - EXPECT_EQ(batched["embedding_features"].sizes(), at::ArrayRef({2L, 3L, 1L})); - auto flatten = batched["embedding_features"].flatten(); - checkTensor(flatten, {0, 1, 4, 2, 3, 5}); + auto batched = at::stack(batched_dict["embedding_features"].toTensorVector()); + auto ivalue_batched = + at::stack(batchedIValue["ivalue_embedding_features"].toTensorVector()); + + std::vector expectShape{2L, 3L, 1L}; + std::vector expectResult{0, 2, 4, 1, 3, 5}; + EXPECT_EQ(batched.sizes(), expectShape); + EXPECT_EQ(ivalue_batched.sizes(), expectShape); + auto flatten = batched.flatten(); + checkTensor(flatten, expectResult); + flatten = ivalue_batched.flatten(); + checkTensor(flatten, expectResult); } TEST(BatchingTest, DenseCombineTest) { @@ -80,11 +103,11 @@ TEST(BatchingTest, DenseCombineTest) { // num features, num batches, feature dimension std::vector expectShape{3L, 2L}; std::vector expectResult{1.1, 2.0, 0.3, 1.2, 0.9, 2.3}; - EXPECT_EQ(batchedIOBuf["io_buf"].sizes(), expectShape); - EXPECT_EQ(batchedIValue["ivalue"].sizes(), expectShape); - auto flatten = batchedIOBuf["io_buf"].flatten(); + EXPECT_EQ(batchedIOBuf["io_buf"].toTensor().sizes(), expectShape); + EXPECT_EQ(batchedIValue["ivalue"].toTensor().sizes(), expectShape); + auto flatten = batchedIOBuf["io_buf"].toTensor().flatten(); checkTensor(flatten, expectResult); - flatten = batchedIValue["ivalue"].flatten(); + flatten = batchedIValue["ivalue"].toTensor().flatten(); checkTensor(flatten, expectResult); } diff --git a/torchrec/inference/tests/ResultSplitTest.cpp b/torchrec/inference/inference_legacy/tests/ResultSplitTest.cpp similarity index 100% rename from torchrec/inference/tests/ResultSplitTest.cpp rename to torchrec/inference/inference_legacy/tests/ResultSplitTest.cpp diff --git a/torchrec/inference/inference_legacy/tests/SingleGPUExecutorMultiGPUTest.cpp b/torchrec/inference/inference_legacy/tests/SingleGPUExecutorMultiGPUTest.cpp new file mode 100644 index 000000000..7c0823061 --- /dev/null +++ 
b/torchrec/inference/inference_legacy/tests/SingleGPUExecutorMultiGPUTest.cpp @@ -0,0 +1,140 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include // @manual +#include +#include // @manual +#include "torchrec/inference/Observer.h" +#include "torchrec/inference/SingleGPUExecutor.h" +#include "torchrec/inference/Types.h" + +int main(int argc, char* argv[]) { + ::testing::InitGoogleTest(&argc, argv); + int rc = RUN_ALL_TESTS(); + return rc; +} + +const char* path(const char* envname, const char* path) { + const char* e = std::getenv(envname); + return e ? e : path; +} + +std::vector get_input_example( + torch::deploy::InterpreterSession& model_interpreter_session) { + auto eg = model_interpreter_session.self + .attr("load_pickle")({"model", "example.pkl"}) + .toIValue(); + return eg.toTupleRef().elements(); +} + +void assert_tensors_eq(const at::Tensor& expected, const at::Tensor& got) { + ASSERT_TRUE(expected.allclose(got, 1e-03, 1e-05)); +} + +c10::IValue execute( + torchrec::SingleGPUExecutor& executor, + const std::string& methodName, + std::vector args) { + folly::Promise> promise; + auto future = promise.getSemiFuture(); + + executor.schedule(std::make_shared( + methodName, std::move(args), std::move(promise))); + return std::move(future).get()->predictions; +} + +TEST(TorchDeployGPUTest, SimpleModel_multiGPU) { + if (!torch::cuda::is_available()) { + GTEST_SKIP() << "Test is skipped as it requires CUDA."; + } + + const size_t numGpu = torch::cuda::device_count(); + if (numGpu <= 1) { + GTEST_SKIP() << "Test is skipped as it requires > 1 CUDA devices, found:" + << numGpu; + } + + const char* model_filename = path("TORCH_PACKAGE_SIMPLE", "/tmp/simple"); + + auto device = c10::Device(c10::kCUDA, 0); + + auto manager = + std::make_shared(2 * numGpu); + torch::deploy::Package package = manager->loadPackage(model_filename); + + std::vector models; + torch::deploy::ReplicatedObj model_control; + const size_t gpu_rank_control = 0; + + { + auto I = package.acquireSession(); + + auto pyModel = I.fromMovable(package.loadPickle("model", "model.pkl")); + + for (size_t i = 0; i < numGpu; i++) { + auto model = I.createMovable( + pyModel.attr("to")(c10::IValue(c10::Device(c10::kCUDA, i)))); + if (i == gpu_rank_control) { + model_control = model; + } + models.push_back(std::move(model)); + } + } + + std::vector> workExecutors; + for (size_t i = 0; i < numGpu; i++) { + const std::vector interp_idxs = {static_cast(i)}; + workExecutors.push_back(std::make_unique( + manager, + torchrec::SingleGPUExecutor::ExecInfos{{i, numGpu + i, models[i]}}, + numGpu)); + } + + std::vector execInfos; + for (size_t i = 0; i < numGpu; i++) { + execInfos.push_back({i, numGpu + i, models[i]}); + } + + auto controlExecutor = + std::make_unique(manager, execInfos, numGpu); + + std::vector example_inputs; + { + auto I = package.acquireSession(); + example_inputs = get_input_example(I); + } + auto example_input0 = example_inputs[0].toTensor(); + auto expected_forward0 = example_input0 + at::ones(example_input0.sizes()); + + for (size_t i = 0; i < numGpu; i++) { + auto result = + execute(*workExecutors[i], "forward", example_inputs).toTensor(); + assert_tensors_eq(expected_forward0, result); + } + + execute(*controlExecutor, "set_weight", {at::zeros(example_input0.sizes())}); + + auto checkFn = [&](size_t 
set_weight_count) { + for (size_t i = 0; i < numGpu; i++) { + auto result = + execute(*workExecutors[i], "forward", example_inputs).toTensor(); + if (i < set_weight_count) { + assert_tensors_eq(example_input0, result); + } else { + assert_tensors_eq(expected_forward0, result); + } + } + }; + checkFn(1u); + + execute(*controlExecutor, "set_weight", {at::zeros(example_input0.sizes())}); + checkFn(2u); +} diff --git a/torchrec/inference/tests/SingleGPUExecutorTest.cpp b/torchrec/inference/inference_legacy/tests/SingleGPUExecutorTest.cpp similarity index 69% rename from torchrec/inference/tests/SingleGPUExecutorTest.cpp rename to torchrec/inference/inference_legacy/tests/SingleGPUExecutorTest.cpp index db7771cfd..a4213647a 100644 --- a/torchrec/inference/tests/SingleGPUExecutorTest.cpp +++ b/torchrec/inference/inference_legacy/tests/SingleGPUExecutorTest.cpp @@ -41,7 +41,7 @@ void assert_tensors_eq(const at::Tensor& expected, const at::Tensor& got) { TEST(TorchDeployGPUTest, SimpleModelSingleGPU) { if (!torch::cuda::is_available()) { - GTEST_SKIP(); + GTEST_SKIP() << "Test is skipped as it requires CUDA."; } const char* model_filename = path("TORCH_PACKAGE_SIMPLE", "/tmp/simple"); @@ -121,96 +121,9 @@ c10::IValue execute( return std::move(future).get()->predictions; } -TEST(TorchDeployGPUTest, SimpleModel_multiGPU) { - if (!torch::cuda::is_available()) { - GTEST_SKIP(); - } - - const size_t numGpu = torch::cuda::device_count(); - if (numGpu <= 1) { - GTEST_SKIP(); - } - - const char* model_filename = path("TORCH_PACKAGE_SIMPLE", "/tmp/simple"); - - auto device = c10::Device(c10::kCUDA, 0); - - auto manager = - std::make_shared(2 * numGpu); - torch::deploy::Package package = manager->loadPackage(model_filename); - - std::vector models; - torch::deploy::ReplicatedObj model_control; - const size_t gpu_rank_control = 0; - - { - auto I = package.acquireSession(); - - auto pyModel = I.fromMovable(package.loadPickle("model", "model.pkl")); - - for (size_t i = 0; i < numGpu; i++) { - auto model = I.createMovable( - pyModel.attr("to")(c10::IValue(c10::Device(c10::kCUDA, i)))); - if (i == gpu_rank_control) { - model_control = model; - } - models.push_back(std::move(model)); - } - } - - std::vector> workExecutors; - for (size_t i = 0; i < numGpu; i++) { - const std::vector interp_idxs = {static_cast(i)}; - workExecutors.push_back(std::make_unique( - manager, - torchrec::SingleGPUExecutor::ExecInfos{{i, numGpu + i, models[i]}}, - numGpu)); - } - - std::vector execInfos; - for (size_t i = 0; i < numGpu; i++) { - execInfos.push_back({i, numGpu + i, models[i]}); - } - - auto controlExecutor = - std::make_unique(manager, execInfos, numGpu); - - std::vector example_inputs; - { - auto I = package.acquireSession(); - example_inputs = get_input_example(I); - } - auto example_input0 = example_inputs[0].toTensor(); - auto expected_forward0 = example_input0 + at::ones(example_input0.sizes()); - - for (size_t i = 0; i < numGpu; i++) { - auto result = - execute(*workExecutors[i], "forward", example_inputs).toTensor(); - assert_tensors_eq(expected_forward0, result); - } - - execute(*controlExecutor, "set_weight", {at::zeros(example_input0.sizes())}); - - auto checkFn = [&](size_t set_weight_count) { - for (size_t i = 0; i < numGpu; i++) { - auto result = - execute(*workExecutors[i], "forward", example_inputs).toTensor(); - if (i < set_weight_count) { - assert_tensors_eq(example_input0, result); - } else { - assert_tensors_eq(expected_forward0, result); - } - } - }; - checkFn(1u); - - execute(*controlExecutor, 
"set_weight", {at::zeros(example_input0.sizes())}); - checkFn(2u); -} - TEST(TorchDeployGPUTest, NestedModelSingleGPU) { if (!torch::cuda::is_available()) { - GTEST_SKIP(); + GTEST_SKIP() << "Test is skipped as it requires CUDA."; } const char* model_filename = path("TORCH_PACKAGE_NESTED", "/tmp/nested"); @@ -248,14 +161,14 @@ class TestSingleGPUExecutorObserver : public torchrec::EmptySingleGPUExecutorObserver { public: double requestCount = 0.f; - void addRequestsCount(double value) override { + void addRequestsCount(uint32_t value) override { requestCount += value; } }; TEST(TorchDeployGPUTest, SimpleModelSingleGPUObserver) { if (!torch::cuda::is_available()) { - GTEST_SKIP(); + GTEST_SKIP() << "Test is skipped as it requires CUDA."; } const char* model_filename = path("TORCH_PACKAGE_NESTED", "/tmp/simple"); diff --git a/torchrec/inference/tests/ValidationTest.cpp b/torchrec/inference/inference_legacy/tests/ValidationTest.cpp similarity index 100% rename from torchrec/inference/tests/ValidationTest.cpp rename to torchrec/inference/inference_legacy/tests/ValidationTest.cpp diff --git a/torchrec/inference/tests/generate_test_packages.py b/torchrec/inference/inference_legacy/tests/generate_test_packages.py similarity index 61% rename from torchrec/inference/tests/generate_test_packages.py rename to torchrec/inference/inference_legacy/tests/generate_test_packages.py index 468a75dc7..986ac3d28 100644 --- a/torchrec/inference/tests/generate_test_packages.py +++ b/torchrec/inference/inference_legacy/tests/generate_test_packages.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# pyre-strict + # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # @@ -16,15 +18,33 @@ import torch from torch.package import PackageExporter -try: - from .test_modules import Nested, Simple -except ImportError: - from test_modules import Nested, Simple # pyre-ignore + +class Simple(torch.nn.Module): + def __init__(self, N: int, M: int) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(N, M)) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = self.weight + input + return output + + def set_weight(self, weight: torch.Tensor) -> None: + self.weight[:] = torch.nn.Parameter(weight) + + +class Nested(torch.nn.Module): + def __init__(self, N: int, M: int) -> None: + super().__init__() + self.simple = Simple(N, M) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return self.simple(input) def save( name: str, model: torch.nn.Module, eg: Optional[Tuple] = None # pyre-ignore ) -> None: + # pyre-fixme[10]: Name `p` is used but not defined. 
with PackageExporter(str(p / name)) as e: e.mock("iopath.**") e.intern("**") @@ -41,10 +61,12 @@ def post_process(model: torch.nn.Module) -> None: parser = argparse.ArgumentParser(description="Generate Examples") parser.add_argument("--install_dir", help="Root directory for all output files") -if __name__ == "__main__": - args = parser.parse_args() # pyre-ignore + +def main() -> None: + global p + args = parser.parse_args() if args.install_dir is None: - p = Path(__file__).parent / "generated" # pyre-ignore + p = Path(__file__).parent / "generated" p.mkdir(exist_ok=True) else: p = Path(args.install_dir) @@ -57,3 +79,7 @@ def post_process(model: torch.nn.Module) -> None: save("simple", simple, (torch.rand(10, 20),)) save("nested", nested, (torch.rand(10, 20),)) + + +if __name__ == "__main__": + main() # pragma: no cover diff --git a/torchrec/inference/tests/model_packager_tests.py b/torchrec/inference/inference_legacy/tests/model_packager_tests.py similarity index 99% rename from torchrec/inference/tests/model_packager_tests.py rename to torchrec/inference/inference_legacy/tests/model_packager_tests.py index d007c0f00..41533e34a 100644 --- a/torchrec/inference/tests/model_packager_tests.py +++ b/torchrec/inference/inference_legacy/tests/model_packager_tests.py @@ -5,6 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import io import tempfile import unittest diff --git a/torchrec/inference/tests/predict_module_tests.py b/torchrec/inference/inference_legacy/tests/predict_module_tests.py similarity index 100% rename from torchrec/inference/tests/predict_module_tests.py rename to torchrec/inference/inference_legacy/tests/predict_module_tests.py diff --git a/torchrec/inference/inference_legacy/tests/test_modules.py b/torchrec/inference/inference_legacy/tests/test_modules.py new file mode 100644 index 000000000..2b4e97869 --- /dev/null +++ b/torchrec/inference/inference_legacy/tests/test_modules.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 + +# pyre-strict + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python3 +# @nolint + +import unittest + +from torchrec.distributed.test_utils.infer_utils import TorchTypesModelInputWrapper +from torchrec.distributed.test_utils.test_model import TestSparseNN +from torchrec.inference.modules import quantize_inference_model, shard_quant_model +from torchrec.modules.embedding_configs import EmbeddingBagConfig + + +class EagerModelProcessingTests(unittest.TestCase): + def test_quantize_shard_cuda(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=10, + embedding_dim=4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(10) + ] + + model = TorchTypesModelInputWrapper( + TestSparseNN( + tables=tables, + ) + ) + + quantized_model = quantize_inference_model(model) + sharded_model, _ = shard_quant_model(quantized_model) + + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `sparse`. 
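+        # All tables share the same int8 quantized config here, so sharding is
+        # expected to group them into a single TBE (hence the length-1 assertion).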
+ sharded_qebc = sharded_model._module.sparse.ebc + self.assertEqual(len(sharded_qebc.tbes), 1) diff --git a/torchrec/inference/model_packager.py b/torchrec/inference/model_packager.py index af8004bd1..3c18383cf 100644 --- a/torchrec/inference/model_packager.py +++ b/torchrec/inference/model_packager.py @@ -5,10 +5,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import abc from pathlib import Path -from typing import Any, BinaryIO, Dict, Type, TypeVar, Union +from typing import Any, BinaryIO, Dict, List, Type, TypeVar, Union +import torch from torch.package import PackageExporter from torchrec.inference.modules import PredictFactory @@ -62,14 +65,18 @@ def save_predict_factory( output: Union[str, Path, BinaryIO], extra_files: Dict[str, Union[str, bytes]], loader_code: str = LOADER_CODE, + package_importer: Union[ + torch.package.Importer, List[torch.package.Importer] + ] = torch.package.sys_importer, ) -> None: - with PackageExporter(output) as pe: + with PackageExporter(output, importer=package_importer) as pe: # pyre-fixme[29]: `BoundMethod[abc.abstractclassmethod[None], # Type[PredictFactoryPackager]]` is not a function. cls.set_extern_modules(pe) # pyre-fixme[29]: `BoundMethod[abc.abstractclassmethod[None], # Type[PredictFactoryPackager]]` is not a function. cls.set_mocked_modules(pe) + pe.extern(["sys"]) pe.intern("**") for k, v in extra_files.items(): if isinstance(v, str): @@ -90,8 +97,15 @@ def _save_predict_factory( configs: Dict[str, Any], loader_code: str = LOADER_CODE, ) -> None: + # If predict_factory is coming from a torch package, + # __module__ would have prefix. + # To save such predict factory, we need to remove + # the prefix. + package_name = predict_factory.__module__ + if package_name.startswith(" str: + if typename.startswith(" prefix. 
+ typename = ".".join(typename.split(".")[1:]) + return typename + + +DEFAULT_FUSED_PARAMS: Dict[str, Any] = { + FUSED_PARAM_REGISTER_TBE_BOOL: True, + FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: True, + FUSED_PARAM_BOUNDS_CHECK_MODE: BoundsCheckMode.NONE, + FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: False, +} + +DEFAULT_SHARDERS: List[ModuleSharder[torch.nn.Module]] = [ + cast( + ModuleSharder[torch.nn.Module], + QuantEmbeddingBagCollectionSharder(fused_params=DEFAULT_FUSED_PARAMS), + ), + cast( + ModuleSharder[torch.nn.Module], + QuantEmbeddingCollectionSharder(fused_params=DEFAULT_FUSED_PARAMS), + ), + cast( + ModuleSharder[torch.nn.Module], + QuantFeatureProcessedEmbeddingBagCollectionSharder( + fused_params=DEFAULT_FUSED_PARAMS + ), + ), +] + +DEFAULT_QUANT_MAPPING: Dict[str, Type[torch.nn.Module]] = { + trim_torch_package_prefix_from_typename( + torch.typename(EmbeddingBagCollection) + ): QuantEmbeddingBagCollection, + trim_torch_package_prefix_from_typename( + torch.typename(EmbeddingCollection) + ): QuantEmbeddingCollection, +} + +DEFAULT_QUANTIZATION_DTYPE: torch.dtype = torch.int8 + +FEATURE_PROCESSED_EBC_TYPE: str = trim_torch_package_prefix_from_typename( + torch.typename(FeatureProcessedEmbeddingBagCollection) +) def quantize_feature( @@ -27,10 +125,12 @@ def quantize_feature( ) -> Tuple[torch.Tensor, ...]: return tuple( [ - input.half() - if isinstance(input, torch.Tensor) - and input.dtype in [torch.float32, torch.float64] - else input + ( + input.half() + if isinstance(input, torch.Tensor) + and input.dtype in [torch.float32, torch.float64] + else input + ) for input in inputs ] ) @@ -43,12 +143,14 @@ def quantize_embeddings( additional_qconfig_spec_keys: Optional[List[Type[nn.Module]]] = None, additional_mapping: Optional[Dict[Type[nn.Module], Type[nn.Module]]] = None, output_dtype: torch.dtype = torch.float, + per_table_weight_dtype: Optional[Dict[str, torch.dtype]] = None, ) -> nn.Module: - qconfig = quant.QConfig( + qconfig = QuantConfig( activation=quant.PlaceholderObserver.with_args(dtype=output_dtype), weight=quant.PlaceholderObserver.with_args(dtype=dtype), + per_table_weight_dtype=per_table_weight_dtype, ) - qconfig_spec: Dict[Type[nn.Module], quant.QConfig] = { + qconfig_spec: Dict[Type[nn.Module], QuantConfig] = { trec.EmbeddingBagCollection: qconfig, } mapping: Dict[Type[nn.Module], Type[nn.Module]] = { @@ -67,18 +169,6 @@ def quantize_embeddings( ) -class CopyableMixin(nn.Module): - def copy( - self, - device: torch.device, - ) -> nn.Module: - return copy_to_device( - self, - current_device=torch.device("cpu"), - to_device=device, - ) - - @dataclass class QualNameMetadata: need_preproc: bool @@ -160,6 +250,14 @@ def qualname_metadata(self) -> Dict[str, QualNameMetadata]: """ return {} + def qualname_metadata_json(self) -> str: + """ + Serialize the qualname metadata to JSON, for ease of parsing with torch::deploy environments. + """ + return json.dumps( + {key: asdict(value) for key, value in self.qualname_metadata().items()} + ) + def model_inputs_data(self) -> Dict[str, Any]: """ Returns a dict of various data for benchmarking input generation. 
@@ -190,11 +288,12 @@ class PredictModule(nn.Module): def __init__( self, module: nn.Module, + device: Optional[str] = None, ) -> None: super().__init__() self._module: nn.Module = module # lazy init device from thread inited device guard - self._device: Optional[torch.device] = None + self._device: Optional[torch.device] = torch.device(device) if device else None self._module.eval() @property @@ -212,7 +311,7 @@ def predict_forward(self, batch: Dict[str, torch.Tensor]) -> Any: def forward(self, batch: Dict[str, torch.Tensor]) -> Any: if self._device is None: self._device = torch.device("cuda", torch.cuda.current_device()) - with torch.cuda.device(self._device), torch.inference_mode(): + with torch.inference_mode(): return self.predict_forward(batch) # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently. @@ -244,7 +343,6 @@ def quantize_dense( ): if dtype == torch.half: new_mod = mod.half() - # pyre-ignore [6] new_mod.register_forward_pre_hook(quantize_feature) reassign[name] = new_mod else: @@ -254,3 +352,289 @@ def quantize_dense( for key, value in reassign.items(): module._modules[key] = value return predict_module + + +def set_pruning_data( + model: torch.nn.Module, + tables_to_rows_post_pruning: Dict[str, int], + module_types: Optional[List[Type[nn.Module]]] = None, +) -> torch.nn.Module: + if module_types is None: + module_types = [EmbeddingBagCollection, FeatureProcessedEmbeddingBagCollection] + + for _, module in model.named_modules(): + if type(module) in module_types: + setattr( + module, + MODULE_ATTR_EMB_CONFIG_NAME_TO_NUM_ROWS_POST_PRUNING_DICT, + tables_to_rows_post_pruning, + ) + + return model + + +def quantize_inference_model( + model: torch.nn.Module, + quantization_mapping: Optional[Dict[str, Type[torch.nn.Module]]] = None, + per_table_weight_dtype: Optional[Dict[str, torch.dtype]] = None, + fp_weight_dtype: torch.dtype = DEFAULT_QUANTIZATION_DTYPE, + quantization_dtype: torch.dtype = DEFAULT_QUANTIZATION_DTYPE, + output_dtype: torch.dtype = torch.float, +) -> torch.nn.Module: + """ + Quantize the model, module swapping TorchRec train modules with its + quantized counterpart, (e.g. EmbeddingBagCollection -> QuantEmbeddingBagCollection). + + Args: + model (torch.nn.Module): the model to be quantized + quantization_mapping (Optional[Dict[str, Type[torch.nn.Module]]]): a mapping from + the original module type to the quantized module type. If not provided, the default mapping will be used: + (EmbeddingBagCollection -> QuantEmbeddingBagCollection, EmbeddingCollection -> QuantEmbeddingCollection). + per_table_weight_dtype (Optional[Dict[str, torch.dtype]]): a mapping from table name to weight dtype. + If not provided, the default quantization dtype will be used (int8). + fp_weight_dtype (torch.dtype): the desired quantized dtype for feature processor weights in + FeatureProcessedEmbeddingBagCollection if used. Default is int8. 
+ + Returns: + torch.nn.Module: the quantized model + + Example:: + + ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta")) + + module = DLRMPredictModule( + embedding_bag_collection=ebc, + dense_in_features=self.model_config.dense_in_features, + dense_arch_layer_sizes=self.model_config.dense_arch_layer_sizes, + over_arch_layer_sizes=self.model_config.over_arch_layer_sizes, + id_list_features_keys=self.model_config.id_list_features_keys, + dense_device=device, + ) + + quant_model = quantize_inference_model(module) + """ + + if quantization_mapping is None: + quantization_mapping = DEFAULT_QUANT_MAPPING + + def _quantize_fp_module( + model: torch.nn.Module, + fp_module: FeatureProcessedEmbeddingBagCollection, + fp_module_fqn: str, + weight_dtype: torch.dtype = DEFAULT_QUANTIZATION_DTYPE, + per_fp_table_weight_dtype: Optional[Dict[str, torch.dtype]] = None, + ) -> None: + """ + If FeatureProcessedEmbeddingBagCollection is found, quantize via direct module swap. + """ + + quant_prep_enable_register_tbes(model, [FeatureProcessedEmbeddingBagCollection]) + quant_prep_enable_cache_features_order( + model, [FeatureProcessedEmbeddingBagCollection] + ) + # pyre-fixme[16]: `FeatureProcessedEmbeddingBagCollection` has no attribute + # `qconfig`. + fp_module.qconfig = QuantConfig( + activation=quant.PlaceholderObserver.with_args(dtype=output_dtype), + weight=quant.PlaceholderObserver.with_args(dtype=weight_dtype), + per_table_weight_dtype=per_fp_table_weight_dtype, + ) + + # ie. "root.submodule.feature_processed_mod" -> "root.submodule", "feature_processed_mod" + fp_ebc_parent_fqn, fp_ebc_name = fp_module_fqn.rsplit(".", 1) + fp_ebc_parent = getattr_recursive(model, fp_ebc_parent_fqn) + fp_ebc_parent.register_module( + fp_ebc_name, + QuantFeatureProcessedEmbeddingBagCollection.from_float(fp_module), + ) + + additional_qconfig_spec_keys = [] + additional_mapping = {} + + for n, m in model.named_modules(): + typename = trim_torch_package_prefix_from_typename(torch.typename(m)) + + if typename in quantization_mapping: + additional_qconfig_spec_keys.append(type(m)) + additional_mapping[type(m)] = quantization_mapping[typename] + elif typename == FEATURE_PROCESSED_EBC_TYPE: + # handle the fp ebc separately + _quantize_fp_module( + model, + m, + n, + weight_dtype=fp_weight_dtype, + # Pass in per_fp_table_weight_dtype if it is provided, perhaps + # fpebc parameters are also in here + per_fp_table_weight_dtype=per_table_weight_dtype, + ) + + quant_prep_enable_register_tbes(model, list(additional_mapping.keys())) + quant_prep_enable_cache_features_order( + model, [EmbeddingBagCollection, EmbeddingCollection] + ) + quantize_embeddings( + model, + dtype=quantization_dtype, + additional_qconfig_spec_keys=additional_qconfig_spec_keys, + additional_mapping=additional_mapping, + inplace=True, + per_table_weight_dtype=per_table_weight_dtype, + output_dtype=output_dtype, + ) + + logger.info( + f"Default quantization dtype is {quantization_dtype}, {per_table_weight_dtype=}." 
+ ) + + return model + + +def shard_quant_model( + model: torch.nn.Module, + world_size: int = 1, + compute_device: str = "cuda", + sharding_device: str = "meta", + sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None, + device_memory_size: Optional[int] = None, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + ddr_cap: Optional[int] = None, +) -> Tuple[torch.nn.Module, ShardingPlan]: + """ + Shard a quantized TorchRec model, used for generating the most optimal model for inference and + necessary for distributed inference. + + Args: + model (torch.nn.Module): the quantized model to be sharded + world_size (int): the number of devices to shard the model, default to 1 + compute_device (str): the device to run the model, default to "cuda" + sharding_device (str): the device to run the sharding, default to "meta" + sharders (Optional[List[ModuleSharder[torch.nn.Module]]]): sharders to use for sharding + quantized model, default to QuantEmbeddingBagCollectionSharder, QuantEmbeddingCollectionSharder, + QuantFeatureProcessedEmbeddingBagCollectionSharder. + device_memory_size (Optional[int]): the memory limit for cuda devices, default to None + constraints (Optional[Dict[str, ParameterConstraints]]): constraints to use for sharding, default to None + which will then implement default constraints with QuantEmbeddingBagCollection being sharded TableWise + + Returns: + Tuple[torch.nn.Module, ShardingPlan]: the sharded model and the sharding plan + + Example:: + ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta")) + + module = DLRMPredictModule( + embedding_bag_collection=ebc, + dense_in_features=self.model_config.dense_in_features, + dense_arch_layer_sizes=self.model_config.dense_arch_layer_sizes, + over_arch_layer_sizes=self.model_config.over_arch_layer_sizes, + id_list_features_keys=self.model_config.id_list_features_keys, + dense_device=device, + ) + + quant_model = quantize_inference_model(module) + sharded_model, _ = shard_quant_model(quant_model) + """ + + if constraints is None: + table_fqns = [] + sharders = sharders if sharders else DEFAULT_SHARDERS + module_types = [sharder.module_type for sharder in sharders] + for module in model.modules(): + if type(module) in module_types: + # TODO: handle other cases/reduce hardcoding + if hasattr(module, "embedding_bags"): + # pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` + # is not a function. 
+ for table in module.embedding_bags: + table_fqns.append(table) + + # Default table wise constraints + constraints = {} + for name in table_fqns: + constraints[name] = ParameterConstraints( + sharding_types=[ShardingType.TABLE_WISE.value], + compute_kernels=[EmbeddingComputeKernel.QUANT.value], + ) + + if device_memory_size is not None: + hbm_cap = device_memory_size + elif torch.cuda.is_available() and compute_device == "cuda": + hbm_cap = torch.cuda.get_device_properties( + f"cuda:{torch.cuda.current_device()}" + ).total_memory + else: + hbm_cap = None + + topology = trec_dist.planner.Topology( + world_size=world_size, + compute_device=compute_device, + local_world_size=world_size, + hbm_cap=hbm_cap, + ddr_cap=ddr_cap, + ) + batch_size = 1 + model_plan = trec_dist.planner.EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + constraints=constraints, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + constraints=constraints, + estimator=[ + EmbeddingPerfEstimator( + topology=topology, constraints=constraints, is_inference=True + ), + EmbeddingStorageEstimator(topology=topology, constraints=constraints), + ], + ), + storage_reservation=FixedPercentageStorageReservation( + percentage=0.0, + ), + ).plan( + model, + sharders if sharders else DEFAULT_SHARDERS, + ) + + model = _shard_modules( + module=model, + device=torch.device(sharding_device), + plan=model_plan, + env=trec_dist.ShardingEnv.from_local( + world_size, + 0, + ), + sharders=sharders if sharders else DEFAULT_SHARDERS, + ) + + return model, model_plan + + +def get_table_to_weights_from_tbe( + model: torch.nn.Module, +) -> Dict[str, List[Tuple[torch.Tensor, Optional[torch.Tensor]]]]: + table_to_weight = {} + + for module in model.modules(): + if isinstance(module, IntNBitTableBatchedEmbeddingBagsCodegen): + weights = module.split_embedding_weights() + for i, spec in enumerate(module.embedding_specs): + table_to_weight[spec[0]] = weights[i] + + return table_to_weight + + +def assign_weights_to_tbe( + model: torch.nn.Module, + table_to_weight: Dict[str, List[Tuple[torch.Tensor, Optional[torch.Tensor]]]], +) -> None: + for module in model.modules(): + if isinstance(module, IntNBitTableBatchedEmbeddingBagsCodegen): + q_weights = [] + for spec in module.embedding_specs: + assert spec[0] in table_to_weight, f"{spec[0]} not in table_to_weight" + q_weights.append(table_to_weight[spec[0]]) + + module.assign_embedding_weights(q_weights) + + return diff --git a/torchrec/inference/server.cpp b/torchrec/inference/server.cpp index fda92326b..81582a7ee 100644 --- a/torchrec/inference/server.cpp +++ b/torchrec/inference/server.cpp @@ -10,327 +10,170 @@ #include #include -#include -#include -#include -#include -#include -#include +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/strings/str_format.h" + #include #include +#include + +#include +#include -// remove this after we switch over to multipy externally for torchrec -#ifdef FBCODE_CAFFE2 -#include // @manual -#include +#ifdef BAZEL_BUILD +#include "examples/protos/predictor.grpc.pb.h" #else -#include -#include +#include "predictor.grpc.pb.h" #endif -#include +#define NUM_BYTES_FLOAT_FEATURES 4 +#define NUM_BYTES_SPARSE_FEATURES 4 -#include "torchrec/inference/GPUExecutor.h" -#include "torchrec/inference/predictor.grpc.pb.h" -#include "torchrec/inference/predictor.pb.h" - -using grpc::Channel; -using grpc::ClientContext; +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; using grpc::Status; 
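+// The server below loads a serialized TorchScript module via torch::jit::load()
+// and answers Predict() calls synchronously on it, replacing the previous
+// torch::deploy / BatchingQueue based pipeline that is removed further down.
+// Request payloads arrive as raw bytes of float32 / int32 values, so the
+// NUM_BYTES_* constants above convert byte counts into element counts.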
+ +ABSL_FLAG(uint16_t, port, 50051, "Server port for the service"); + using predictor::FloatVec; using predictor::PredictionRequest; using predictor::PredictionResponse; using predictor::Predictor; -DEFINE_int32(n_interp_per_gpu, 1, ""); -DEFINE_int32(n_gpu, 1, ""); -DEFINE_string(package_path, "", ""); - -DEFINE_int32(batching_interval, 10, ""); -DEFINE_int32(queue_timeout, 500, ""); - -DEFINE_int32(num_exception_threads, 4, ""); -DEFINE_int32(num_mem_pinner_threads, 4, ""); -DEFINE_int32(max_batch_size, 2048, ""); -DEFINE_int32(gpu_executor_queue_timeout, 50, ""); - -DEFINE_string(server_address, "0.0.0.0", ""); -DEFINE_string(server_port, "50051", ""); - -DEFINE_string( - python_packages_path, - "", - "Used to load the packages that you 'extern' with torch.package"); - -namespace { - -std::unique_ptr toTorchRecRequest( - const PredictionRequest* request) { - auto torchRecRequest = std::make_unique(); - torchRecRequest->batch_size = request->batch_size(); - - // Client sends a request with serialized tensor to bytes. - // Byte string is converted to folly::iobuf for torchrec request. - - { - torchrec::FloatFeatures floatFeature; - - auto feature = request->float_features(); - auto encoded_values = feature.values(); - - floatFeature.num_features = feature.num_features(); - floatFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - - torchRecRequest->features["float_features"] = std::move(floatFeature); - } - - { - torchrec::SparseFeatures sparseFeature; - - auto feature = request->id_list_features(); - auto encoded_values = feature.values(); - auto encoded_lengths = feature.lengths(); - - sparseFeature.num_features = feature.num_features(); - sparseFeature.lengths = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_lengths.data(), - encoded_lengths.size()}; - sparseFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - - torchRecRequest->features["id_list_features"] = std::move(sparseFeature); - } - - { - torchrec::SparseFeatures sparseFeature; - - auto feature = request->id_score_list_features(); - auto encoded_values = feature.values(); - auto encoded_lengths = feature.lengths(); - auto encoded_weights = feature.weights(); - - sparseFeature.num_features = feature.num_features(); - sparseFeature.lengths = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_lengths.data(), - encoded_lengths.size()}; - sparseFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - sparseFeature.weights = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_weights.data(), - encoded_weights.size()}; - - torchRecRequest->features["id_score_list_features"] = - std::move(sparseFeature); - } - - { - torchrec::FloatFeatures floatFeature; - - auto feature = request->embedding_features(); - auto encoded_values = feature.values(); - - floatFeature.num_features = feature.num_features(); - floatFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - - torchRecRequest->features["embedding_features"] = std::move(floatFeature); - } - - { - torchrec::SparseFeatures sparseFeature; - - auto feature = request->unary_features(); - auto encoded_lengths = feature.lengths(); - auto encoded_values = feature.values(); - - sparseFeature.num_features = feature.num_features(); - sparseFeature.lengths = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_lengths.data(), - encoded_lengths.size()}; - 
sparseFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - - torchRecRequest->features["unary_features"] = std::move(sparseFeature); - } - - return torchRecRequest; -} - -// Logic behind the server's behavior. class PredictorServiceHandler final : public Predictor::Service { public: - explicit PredictorServiceHandler(torchrec::BatchingQueue& queue) - : queue_(queue) {} + PredictorServiceHandler(torch::jit::script::Module& module) + : module_(module) {} Status Predict( grpc::ServerContext* context, const PredictionRequest* request, PredictionResponse* reply) override { - folly::Promise> promise; - auto future = promise.getSemiFuture(); - queue_.add(toTorchRecRequest(request), std::move(promise)); - auto torchRecResponse = - std::move(future).get(); // blocking, TODO: Write async server - auto predictions = reply->mutable_predictions(); - - // Convert ivalue to map, TODO: find out if protobuf - // can support custom types (folly::iobuf), so we can avoid this overhead. - for (const auto& item : torchRecResponse->predictions.toGenericDict()) { - auto tensor = item.value().toTensor(); - FloatVec fv; - fv.mutable_data()->Add( - tensor.data_ptr(), tensor.data_ptr() + tensor.numel()); - (*predictions)[item.key().toStringRef()] = fv; - } + std::cout << "Predict Called!" << std::endl; + c10::Dict dict; + + auto floatFeature = request->float_features(); + auto floatFeatureBlob = floatFeature.values(); + auto numFloatFeatures = floatFeature.num_features(); + auto batchSize = + floatFeatureBlob.size() / (NUM_BYTES_FLOAT_FEATURES * numFloatFeatures); + + std::cout << "Size: " << floatFeatureBlob.size() + << " Num Features: " << numFloatFeatures << std::endl; + auto floatFeatureTensor = torch::from_blob( + floatFeatureBlob.data(), + {batchSize, numFloatFeatures}, + torch::kFloat32); + + auto idListFeature = request->id_list_features(); + auto numIdListFeatures = idListFeature.num_features(); + auto lengthsBlob = idListFeature.lengths(); + auto valuesBlob = idListFeature.values(); + + std::cout << "Lengths Size: " << lengthsBlob.size() + << " Num Features: " << numIdListFeatures << std::endl; + assert( + batchSize == + (lengthsBlob.size() / (NUM_BYTES_SPARSE_FEATURES * numIdListFeatures))); + + auto lengthsTensor = torch::from_blob( + lengthsBlob.data(), + {lengthsBlob.size() / NUM_BYTES_SPARSE_FEATURES}, + torch::kInt32); + auto valuesTensor = torch::from_blob( + valuesBlob.data(), + {valuesBlob.size() / NUM_BYTES_SPARSE_FEATURES}, + torch::kInt32); + + dict.insert("float_features", floatFeatureTensor.to(torch::kCUDA)); + dict.insert("id_list_features.lengths", lengthsTensor.to(torch::kCUDA)); + dict.insert("id_list_features.values", valuesTensor.to(torch::kCUDA)); + + std::vector input; + input.push_back(c10::IValue(dict)); + + torch::Tensor output = + this->module_.forward(input).toGenericDict().at("default").toTensor(); + auto predictions = reply->mutable_predictions(); + FloatVec fv; + fv.mutable_data()->Add( + output.data_ptr(), output.data_ptr() + output.numel()); + (*predictions)["default"] = fv; return Status::OK; } private: - torchrec::BatchingQueue& queue_; + torch::jit::script::Module& module_; }; -} // namespace - -int main(int argc, char* argv[]) { - google::InitGoogleLogging(argv[0]); - gflags::ParseCommandLineFlags(&argc, &argv, true); - - LOG(INFO) << "Creating GPU executors"; - - // store the executors and interpreter managers - std::vector> executors; - std::vector models; - std::vector batchQueueCbs; - std::unordered_map 
batchingMetadataMap; - - std::shared_ptr env = - std::make_shared( - FLAGS_python_packages_path); - - auto manager = std::make_shared( - FLAGS_n_gpu * FLAGS_n_interp_per_gpu, env); - { - torch::deploy::Package package = manager->loadPackage(FLAGS_package_path); - auto I = package.acquireSession(); - auto imported = I.self.attr("import_module")({"__module_loader"}); - auto factoryType = imported.attr("MODULE_FACTORY"); - auto factory = factoryType.attr("__new__")({factoryType}); - factoryType.attr("__init__")({factory}); - - // Process forward metadata. - try { - auto batchingMetadataJsonStr = - factory.attr("batching_metadata_json")(at::ArrayRef()) - .toIValue() - .toString() - ->string(); - auto dynamic = folly::parseJson(batchingMetadataJsonStr); - CHECK(dynamic.isObject()); - for (auto it : dynamic.items()) { - torchrec::BatchingMetadata metadata; - metadata.type = it.second["type"].asString(); - metadata.device = it.second["device"].asString(); - batchingMetadataMap[it.first.asString()] = std::move(metadata); - } - } catch (...) { - auto batchingMetadata = - factory.attr("batching_metadata")(at::ArrayRef()) - .toIValue(); - for (const auto& iter : batchingMetadata.toGenericDict()) { - torchrec::BatchingMetadata metadata; - metadata.type = iter.value().toStringRef(); - metadata.device = "cuda"; - batchingMetadataMap[iter.key().toStringRef()] = std::move(metadata); - } - } - - // Process result metadata. - auto resultMetadata = - factory.attr("result_metadata")(at::ArrayRef()) - .toIValue() - .toStringRef(); - std::shared_ptr resultSplitFunc = - torchrec::TorchRecResultSplitFuncRegistry()->Create(resultMetadata); - - LOG(INFO) << "Creating Model Shard for " << FLAGS_n_gpu << " GPUs."; - auto dmp = factory.attr("create_predict_module") - .callKwargs({{"world_size", FLAGS_n_gpu}}); - - for (int rank = 0; rank < FLAGS_n_gpu; rank++) { - auto device = I.self.attr("import_module")({"torch"}).attr("device")( - {"cuda", rank}); - auto m = dmp.attr("copy")({device.toIValue()}); - models.push_back(I.createMovable(m)); - } - - for (int rank = 0; rank < FLAGS_n_gpu; rank++) { - auto executor = std::make_unique( - manager, - std::move(models[rank]), - rank, - FLAGS_n_gpu, - resultSplitFunc, - std::chrono::milliseconds(FLAGS_gpu_executor_queue_timeout)); - executors.push_back(std::move(executor)); - batchQueueCbs.push_back( - [&, rank](std::shared_ptr batch) { - executors[rank]->callback(std::move(batch)); - }); - } - } - - torchrec::BatchingQueue queue( - batchQueueCbs, - torchrec::BatchingQueue::Config{ - .batchingInterval = - std::chrono::milliseconds(FLAGS_batching_interval), - .queueTimeout = std::chrono::milliseconds(FLAGS_queue_timeout), - .numExceptionThreads = FLAGS_num_exception_threads, - .numMemPinnerThreads = FLAGS_num_mem_pinner_threads, - .maxBatchSize = FLAGS_max_batch_size, - .batchingMetadata = std::move(batchingMetadataMap), - }, - FLAGS_n_gpu); - - // create the server - std::string server_address(FLAGS_server_address + ":" + FLAGS_server_port); - auto service = PredictorServiceHandler(queue); +void RunServer(uint16_t port, torch::jit::script::Module& module) { + std::string server_address = absl::StrFormat("0.0.0.0:%d", port); + PredictorServiceHandler service(module); grpc::EnableDefaultHealthCheckService(true); grpc::reflection::InitProtoReflectionServerBuilderPlugin(); - grpc::ServerBuilder builder; - + ServerBuilder builder; // Listen on the given address without any authentication mechanism. 
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); - // Register "service" as the instance through which we'll communicate with // clients. In this case it corresponds to an *synchronous* service. builder.RegisterService(&service); - // Finally assemble the server. - std::unique_ptr server(builder.BuildAndStart()); - LOG(INFO) << "Server listening on " << server_address; + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; // Wait for the server to shutdown. Note that some other thread must be // responsible for shutting down the server for this call to ever return. server->Wait(); +} + +int main(int argc, char** argv) { + // absl::ParseCommandLine(argc, argv); + + if (argc != 2) { + std::cerr << "usage: ts-infer \n"; + return -1; + } + + std::cout << "Loading model...\n"; + + // deserialize ScriptModule + torch::jit::script::Module module; + try { + module = torch::jit::load(argv[1]); + } catch (const c10::Error&) { + std::cerr << "Error loading model\n"; + return -1; + } + + torch::NoGradGuard no_grad; // ensures that autograd is off + module.eval(); // turn off dropout and other training-time layers/functions + + std::cout << "Sanity Check with dummy inputs" << std::endl; + c10::Dict dict; + dict.insert( + "float_features", + torch::ones( + {1, 13}, torch::dtype(torch::kFloat32).device(torch::kCUDA, 0))); + dict.insert( + "id_list_features.lengths", + torch::ones({26}, torch::dtype(torch::kLong).device(torch::kCUDA, 0))); + dict.insert( + "id_list_features.values", + torch::ones({26}, torch::dtype(torch::kLong).device(torch::kCUDA, 0))); + + std::vector input; + input.push_back(c10::IValue(dict)); + + // Execute the model and turn its output into a tensor. + auto output = module.forward(input).toGenericDict().at("default").toTensor(); + std::cout << " Model Forward Completed, Output: " << output.item() + << std::endl; + + RunServer(absl::GetFlag(FLAGS_port), module); - LOG(INFO) << "Shutting down server"; return 0; } diff --git a/torchrec/inference/state_dict_transform.py b/torchrec/inference/state_dict_transform.py index 0379b1b80..5dcccc421 100644 --- a/torchrec/inference/state_dict_transform.py +++ b/torchrec/inference/state_dict_transform.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Dict, List, Union import torch @@ -27,7 +29,10 @@ def state_dict_gather( for key, dst_tensor in dst.items(): src_tensor = src[key] if isinstance(src_tensor, ShardedTensor): - src_tensor.gather(out=dst_tensor if (dist.get_rank() == 0) else None) + src_tensor.gather( + out=dst_tensor if (dist.get_rank() == 0) else None, + dtype=dst_tensor.dtype, + ) elif isinstance(src_tensor, torch.Tensor): dst_tensor.copy_(src_tensor) else: diff --git a/torchrec/inference/tests/test_inference.py b/torchrec/inference/tests/test_inference.py new file mode 100644 index 000000000..c8f7f0522 --- /dev/null +++ b/torchrec/inference/tests/test_inference.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
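+
+# The tests in this module exercise the inference helpers roughly in this order
+# (see test_regroup_module_inference below for the executed version):
+#
+#   quantized_model = quantize_inference_model(model)
+#   table_to_weight = get_table_to_weights_from_tbe(quantized_model)
+#   sharded_model, plan = shard_quant_model(
+#       quantized_model, world_size=2, compute_device="cpu", sharding_device="cpu"
+#   )
+#   assign_weights_to_tbe(quantized_model, table_to_weight)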
+ +# @nolint +# pyre-ignore-all-errors + +import unittest +from argparse import Namespace +from typing import Any, cast, Dict, List + +import torch +from fbgemm_gpu.split_embedding_configs import SparseType +from torchrec import PoolingType +from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES +from torchrec.distributed.fused_params import ( + FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP, + FUSED_PARAM_REGISTER_TBE_BOOL, +) +from torchrec.distributed.global_settings import set_propogate_device +from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder +from torchrec.distributed.test_utils.test_model import ( + ModelInput, + TestOverArchRegroupModule, + TestSparseNN, +) +from torchrec.distributed.types import ModuleSharder + +from torchrec.inference.dlrm_predict import ( + create_training_batch, + DLRMModelConfig, + DLRMPredictFactory, +) +from torchrec.inference.modules import ( + assign_weights_to_tbe, + DEFAULT_FUSED_PARAMS, + DEFAULT_SHARDERS, + get_table_to_weights_from_tbe, + quantize_inference_model, + set_pruning_data, + shard_quant_model, +) +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection + + +class InferenceTest(unittest.TestCase): + def setUp(self) -> None: + num_features = 4 + num_weighted_features = 2 + + self.tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + self.weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(num_weighted_features) + ] + + def test_dlrm_inference_package(self) -> None: + args = Namespace() + args.batch_size = 10 + args.num_embedding_features = 26 + args.num_dense_features = len(DEFAULT_INT_NAMES) + args.dense_arch_layer_sizes = "512,256,64" + args.over_arch_layer_sizes = "512,512,256,1" + args.sparse_feature_names = ",".join(DEFAULT_CAT_NAMES) + args.num_embeddings = 100_000 + args.num_embeddings_per_feature = ",".join( + [str(args.num_embeddings)] * args.num_embedding_features + ) + + batch = create_training_batch(args) + + model_config = DLRMModelConfig( + dense_arch_layer_sizes=list( + map(int, args.dense_arch_layer_sizes.split(",")) + ), + dense_in_features=args.num_dense_features, + embedding_dim=64, + id_list_features_keys=args.sparse_feature_names.split(","), + num_embeddings_per_feature=list( + map(int, args.num_embeddings_per_feature.split(",")) + ), + num_embeddings=args.num_embeddings, + over_arch_layer_sizes=list(map(int, args.over_arch_layer_sizes.split(","))), + sample_input=batch, + ) + + # Create torchscript model for inference + DLRMPredictFactory(model_config).create_predict_module( + world_size=1, device="cpu" + ) + + def test_regroup_module_inference(self) -> None: + set_propogate_device(True) + model = TestSparseNN( + tables=self.tables, + weighted_tables=self.weighted_tables, + num_float_features=10, + dense_device=torch.device("cpu"), + sparse_device=torch.device("cpu"), + over_arch_clazz=TestOverArchRegroupModule, + ) + + model.eval() + _, local_batch = ModelInput.generate( + batch_size=16, + world_size=1, + num_float_features=10, + tables=self.tables, + weighted_tables=self.weighted_tables, + ) + + with torch.inference_mode(): + output = model(local_batch[0]) + + # Quantize the model and collect 
quantized weights + quantized_model = quantize_inference_model(model) + quantized_output = quantized_model(local_batch[0]) + table_to_weight = get_table_to_weights_from_tbe(quantized_model) + + # Shard the model, all weights are initialized back to 0, so have to reassign weights + sharded_quant_model, _ = shard_quant_model( + quantized_model, + world_size=2, + compute_device="cpu", + sharding_device="cpu", + ) + assign_weights_to_tbe(quantized_model, table_to_weight) + + sharded_quant_output = sharded_quant_model(local_batch[0]) + + self.assertTrue(torch.allclose(output, quantized_output, atol=1e-4)) + self.assertTrue(torch.allclose(output, sharded_quant_output, atol=1e-4)) + + def test_set_pruning_data(self) -> None: + model = TestSparseNN( + tables=self.tables, + weighted_tables=self.weighted_tables, + num_float_features=10, + dense_device=torch.device("cpu"), + sparse_device=torch.device("cpu"), + over_arch_clazz=TestOverArchRegroupModule, + ) + + pruning_dict = {} + + for table in self.tables: + pruning_dict[table.name] = table.num_embeddings - 1 + + set_pruning_data(model, pruning_dict) + quantized_model = quantize_inference_model(model) + + # Check EBC configs and TBE for correct shapes + for module in quantized_model.modules(): + if isinstance(module, EmbeddingBagCollection): + for config in module.embedding_bag_configs(): + if config.name in pruning_dict: + self.assertEqual( + config.num_embeddings_post_pruning, + pruning_dict[config.name], + ) + elif module.__class__.__name__ == "IntNBitTableBatchedEmbeddingBagsCodegen": + for i, spec in enumerate(module.embedding_specs): + if spec[0] in pruning_dict: + self.assertEqual( + module.split_embedding_weights()[i][0].size(0), + pruning_dict[spec[0]], + ) + self.assertEqual( + spec[1], + pruning_dict[spec[0]], + ) + + def test_quantize_per_table_dtype(self) -> None: + max_feature_lengths = {} + + # First two tables as FPEBC + max_feature_lengths[self.tables[0].name] = 100 + max_feature_lengths[self.tables[1].name] = 100 + + model = TestSparseNN( + tables=self.tables, + weighted_tables=self.weighted_tables, + num_float_features=10, + dense_device=torch.device("cpu"), + sparse_device=torch.device("cpu"), + over_arch_clazz=TestOverArchRegroupModule, + max_feature_lengths=max_feature_lengths, + ) + + per_table_dtype = {} + + for table in self.tables + self.weighted_tables: + # quint4x2 different than int8, which is default + per_table_dtype[table.name] = torch.quint4x2 + + quantized_model = quantize_inference_model( + model, per_table_weight_dtype=per_table_dtype + ) + + num_tbes = 0 + # Check EBC configs and TBE for correct shapes + for module in quantized_model.modules(): + if module.__class__.__name__ == "IntNBitTableBatchedEmbeddingBagsCodegen": + num_tbes += 1 + for i, spec in enumerate(module.embedding_specs): + self.assertEqual(spec[3], SparseType.INT4) + + # 3 TBES (1 FPEBC, 2 EBCs (1 weighted, 1 unweighted)) + + self.assertEqual(num_tbes, 3) + + def test_sharded_quantized_tbe_count(self) -> None: + set_propogate_device(True) + + model = TestSparseNN( + tables=self.tables, + weighted_tables=self.weighted_tables, + num_float_features=10, + dense_device=torch.device("cpu"), + sparse_device=torch.device("cpu"), + over_arch_clazz=TestOverArchRegroupModule, + ) + + per_table_weight_dtypes = {} + + for table in self.tables + self.weighted_tables: + # quint4x2 different than int8, which is default + per_table_weight_dtypes[table.name] = ( + torch.quint4x2 if table.name == "table_0" else torch.quint8 + ) + + model.eval() + _, local_batch = 
ModelInput.generate( + batch_size=16, + world_size=1, + num_float_features=10, + tables=self.tables, + weighted_tables=self.weighted_tables, + ) + + # with torch.inference_mode(): # TODO: Why does inference mode fail when using different quant data types + output = model(local_batch[0]) + + # Quantize the model and collect quantized weights + quantized_model = quantize_inference_model( + model, per_table_weight_dtype=per_table_weight_dtypes + ) + quantized_output = quantized_model(local_batch[0]) + table_to_weight = get_table_to_weights_from_tbe(quantized_model) + + # Shard the model, all weights are initialized back to 0, so have to reassign weights + sharded_quant_model, _ = shard_quant_model( + quantized_model, + world_size=1, + compute_device="cpu", + sharding_device="cpu", + ) + assign_weights_to_tbe(quantized_model, table_to_weight) + sharded_quant_output = sharded_quant_model(local_batch[0]) + + # When world_size = 1, we should have 1 TBE per sharded, quantized ebc + self.assertTrue(len(sharded_quant_model.sparse.ebc.tbes) == 1) + self.assertTrue(len(sharded_quant_model.sparse.weighted_ebc.tbes) == 1) + + # Check the weights are close + self.assertTrue(torch.allclose(output, quantized_output, atol=1e-3)) + self.assertTrue(torch.allclose(output, sharded_quant_output, atol=1e-3)) + + # Check the sizes are correct + expected_num_embeddings = {} + + for table in self.tables: + expected_num_embeddings[table.name] = table.num_embeddings + + for module in quantized_model.modules(): + if module.__class__.__name__ == "IntNBitTableBatchedEmbeddingBagsCodegen": + for i, spec in enumerate(module.embedding_specs): + if spec[0] in expected_num_embeddings: + # We only expect the first table to be quantized to int4 due to test set up + if spec[0] == "table_0": + self.assertEqual(spec[3], SparseType.INT4) + else: + self.assertEqual(spec[3], SparseType.INT8) + + # Check sizes are equal + self.assertEqual( + module.split_embedding_weights()[i][0].size(0), + expected_num_embeddings[spec[0]], + ) + self.assertEqual( + spec[1], + expected_num_embeddings[spec[0]], + ) + + def test_sharded_quantized_lengths_to_tbe(self) -> None: + set_propogate_device(True) + + fused_params: Dict[str, Any] = {FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: True} + sharders: List[ModuleSharder[torch.nn.Module]] = [ + cast( + ModuleSharder[torch.nn.Module], + QuantEmbeddingBagCollectionSharder(fused_params=fused_params), + ), + ] + + model = TestSparseNN( + tables=self.tables, + weighted_tables=self.weighted_tables, + num_float_features=10, + dense_device=torch.device("cpu"), + sparse_device=torch.device("cpu"), + over_arch_clazz=TestOverArchRegroupModule, + ) + + model.eval() + _, local_batch = ModelInput.generate( + batch_size=16, + world_size=1, + num_float_features=10, + tables=self.tables, + weighted_tables=self.weighted_tables, + ) + + # with torch.inference_mode(): # TODO: Why does inference mode fail when using different quant data types + output = model(local_batch[0]) + + # Quantize the model and collect quantized weights + quantized_model = quantize_inference_model(model) + quantized_output = quantized_model(local_batch[0]) + table_to_weight = get_table_to_weights_from_tbe(quantized_model) + + # Shard the model, all weights are initialized back to 0, so have to reassign weights + sharded_quant_model, _ = shard_quant_model( + quantized_model, + world_size=1, + compute_device="cpu", + sharding_device="cpu", + sharders=sharders, + ) + assign_weights_to_tbe(quantized_model, table_to_weight) + sharded_quant_output = 
sharded_quant_model(local_batch[0]) + + # When world_size = 1, we should have 1 TBE per sharded, quantized ebc + self.assertTrue(len(sharded_quant_model.sparse.ebc.tbes) == 1) + self.assertTrue(len(sharded_quant_model.sparse.weighted_ebc.tbes) == 1) + + # Check the weights are close + self.assertTrue(torch.allclose(output, quantized_output, atol=1e-3)) + self.assertTrue(torch.allclose(output, sharded_quant_output, atol=1e-3)) + + def test_quantized_tbe_count_different_pooling(self) -> None: + set_propogate_device(True) + + self.tables[0].pooling = PoolingType.MEAN + model = TestSparseNN( + tables=self.tables, + weighted_tables=self.weighted_tables, + num_float_features=10, + dense_device=torch.device("cpu"), + sparse_device=torch.device("cpu"), + over_arch_clazz=TestOverArchRegroupModule, + ) + + model.eval() + _, local_batch = ModelInput.generate( + batch_size=16, + world_size=1, + num_float_features=10, + tables=self.tables, + weighted_tables=self.weighted_tables, + ) + + model(local_batch[0]) + + # Quantize the model and collect quantized weights + quantized_model = quantize_inference_model(model) + # We should have 2 TBEs for unweighted ebc as the 2 tables here have different pooling types + self.assertTrue(len(quantized_model.sparse.ebc.tbes) == 2) + self.assertTrue(len(quantized_model.sparse.weighted_ebc.tbes) == 1) + # Changing this back + self.tables[0].pooling = PoolingType.SUM + + def test_fused_params_overwrite(self) -> None: + orig_value = DEFAULT_FUSED_PARAMS[FUSED_PARAM_REGISTER_TBE_BOOL] + + sharders = DEFAULT_SHARDERS + ebc_sharder = sharders[0] + ebc_fused_params = ebc_sharder.fused_params + ebc_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = -1 + + ec_sharder = sharders[1] + ec_fused_params = ec_sharder.fused_params + + # Make sure that overwrite of ebc_fused_params is not reflected in ec_fused_params + self.assertEqual(ec_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL], orig_value) + + # change it back to the original value because it modifies the global variable + # otherwise it will affect other tests + ebc_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = orig_value diff --git a/torchrec/inference/tests/test_modules.py b/torchrec/inference/tests/test_modules.py deleted file mode 100644 index c7f757a3a..000000000 --- a/torchrec/inference/tests/test_modules.py +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -#!/usr/bin/env python3 -# @nolint - -import torch - - -class Simple(torch.nn.Module): - def __init__(self, N: int, M: int) -> None: - super().__init__() - self.weight = torch.nn.Parameter(torch.ones(N, M)) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - output = self.weight + input - return output - - def set_weight(self, weight: torch.Tensor) -> None: - self.weight[:] = torch.nn.Parameter(weight) - - -class Nested(torch.nn.Module): - def __init__(self, N: int, M: int) -> None: - super().__init__() - self.simple = Simple(N, M) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return self.simple(input) diff --git a/torchrec/ir/schema.py b/torchrec/ir/schema.py new file mode 100644 index 000000000..9f970cd6f --- /dev/null +++ b/torchrec/ir/schema.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from dataclasses import dataclass +from typing import List, Optional, Tuple + +from torchrec.modules.embedding_configs import DataType, PoolingType + + +# Same as EmbeddingBagConfig but serializable +@dataclass +class EmbeddingBagConfigMetadata: + num_embeddings: int + embedding_dim: int + name: str + data_type: DataType + feature_names: List[str] + weight_init_max: Optional[float] + weight_init_min: Optional[float] + need_pos: bool + pooling: PoolingType + + +@dataclass +class EBCMetadata: + tables: List[EmbeddingBagConfigMetadata] + is_weighted: bool + device: Optional[str] + + +@dataclass +class FPEBCMetadata: + is_fp_collection: bool + features: List[str] + + +@dataclass +class PositionWeightedModuleMetadata: + max_feature_length: int + + +@dataclass +class PositionWeightedModuleCollectionMetadata: + max_feature_lengths: List[Tuple[str, int]] + + +@dataclass +class KTRegroupAsDictMetadata: + groups: List[List[str]] + keys: List[str] diff --git a/torchrec/ir/serializer.py b/torchrec/ir/serializer.py new file mode 100644 index 000000000..28d355c0a --- /dev/null +++ b/torchrec/ir/serializer.py @@ -0,0 +1,433 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import json +from typing import Any, Dict, List, Optional, Type + +import torch +from torch import nn +from torchrec.ir.schema import ( + EBCMetadata, + EmbeddingBagConfigMetadata, + FPEBCMetadata, + KTRegroupAsDictMetadata, + PositionWeightedModuleCollectionMetadata, + PositionWeightedModuleMetadata, +) + +from torchrec.ir.types import SerializerInterface +from torchrec.ir.utils import logging +from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.feature_processor_ import ( + FeatureProcessor, + FeatureProcessorsCollection, + PositionWeightedModule, + PositionWeightedModuleCollection, +) +from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection +from torchrec.modules.regroup import KTRegroupAsDict +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + + +logger: logging.Logger = logging.getLogger(__name__) + + +def embedding_bag_config_to_metadata( + table_config: EmbeddingBagConfig, +) -> EmbeddingBagConfigMetadata: + return EmbeddingBagConfigMetadata( + num_embeddings=table_config.num_embeddings, + embedding_dim=table_config.embedding_dim, + name=table_config.name, + data_type=table_config.data_type.value, + feature_names=table_config.feature_names, + weight_init_max=table_config.weight_init_max, + weight_init_min=table_config.weight_init_min, + need_pos=table_config.need_pos, + pooling=table_config.pooling.value, + ) + + +def embedding_metadata_to_config( + table_config: EmbeddingBagConfigMetadata, +) -> EmbeddingBagConfig: + return EmbeddingBagConfig( + num_embeddings=table_config.num_embeddings, + embedding_dim=table_config.embedding_dim, + name=table_config.name, + data_type=DataType(table_config.data_type), + feature_names=table_config.feature_names, + weight_init_max=table_config.weight_init_max, + weight_init_min=table_config.weight_init_min, + need_pos=table_config.need_pos, + 
pooling=PoolingType(table_config.pooling), + ) + + +def get_deserialized_device( + config_device: Optional[str], device: Optional[torch.device] +) -> Optional[torch.device]: + if config_device: + original_device = torch.device(config_device) + if device is None: + device = original_device + elif original_device.type != device.type: + logger.warning( + f"deserialized device={device} overrides the original device={original_device}" + ) + return device + + +def ebc_meta_forward( + ebc: EmbeddingBagCollection, + features: KeyedJaggedTensor, +) -> KeyedTensor: + batch_size = features.stride() + dims = ebc._lengths_per_embedding + arg_list = [ + features.values(), + features.weights_or_none(), + features.lengths_or_none(), + features.offsets_or_none(), + ] # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]` + outputs = torch.ops.torchrec.ir_emb_lookup(arg_list, batch_size, dims) + return KeyedTensor( + keys=ebc._embedding_names, + values=torch.cat(outputs, dim=1), + length_per_key=ebc._lengths_per_embedding, + ) + + +def fpebc_meta_forward( + fpebc: FeatureProcessedEmbeddingBagCollection, + features: KeyedJaggedTensor, +) -> KeyedTensor: + batch_size = features.stride() + ebc = fpebc._embedding_bag_collection + dims = ebc._lengths_per_embedding + arg_list = [ + features.values(), + features.weights_or_none(), + features.lengths_or_none(), + features.offsets_or_none(), + ] # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]` + outputs = torch.ops.torchrec.ir_emb_lookup(arg_list, batch_size, dims) + return KeyedTensor( + keys=ebc._embedding_names, + values=torch.cat(outputs, dim=1), + length_per_key=ebc._lengths_per_embedding, + ) + + +def kt_regroup_meta_forward( + op_module: KTRegroupAsDict, keyed_tensors: List[KeyedTensor] +) -> Dict[str, torch.Tensor]: + lengths_dict: Dict[str, int] = {} + batch_size = keyed_tensors[0].values().size(0) + for kt in keyed_tensors: + for key, length in zip(kt.keys(), kt.length_per_key()): + lengths_dict[key] = length + out_lengths: List[int] = [0] * len(op_module._groups) + for i, group in enumerate(op_module._groups): + out_lengths[i] = sum(lengths_dict[key] for key in group) + arg_list = [kt.values() for kt in keyed_tensors] + outputs = torch.ops.torchrec.ir_kt_regroup(arg_list, batch_size, out_lengths) + return dict(zip(op_module._keys, outputs)) + + +class JsonSerializer(SerializerInterface): + """ + Serializer for torch.export IR using json. 
+ """ + + module_to_serializer_cls: Dict[str, Type["JsonSerializer"]] = {} + _module_cls: Optional[Type[nn.Module]] = None + _children: Optional[List[str]] = None + + @classmethod + def children(cls, module: nn.Module) -> List[str]: + return [] if not cls._children else cls._children + + @classmethod + def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]: + raise NotImplementedError() + + @classmethod + def deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten_ep: Optional[nn.Module] = None, + ) -> nn.Module: + raise NotImplementedError() + + @classmethod + def swap_meta_forward(cls, module: nn.Module) -> None: + pass + + @classmethod + def encapsulate_module(cls, module: nn.Module) -> List[str]: + typename = type(module).__name__ + serializer = cls.module_to_serializer_cls.get(typename) + if serializer is None: + raise ValueError( + f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}" + ) + assert issubclass(serializer, JsonSerializer) + assert serializer._module_cls is not None + if not isinstance(module, serializer._module_cls): + raise ValueError( + f"Expected module to be of type {serializer._module_cls.__name__}, " + f"got {type(module)}" + ) + metadata_dict = serializer.serialize_to_dict(module) + raw_dict = {"typename": typename, "metadata_dict": metadata_dict} + ir_metadata_tensor = torch.frombuffer( + json.dumps(raw_dict).encode(), dtype=torch.uint8 + ) + module.register_buffer("ir_metadata", ir_metadata_tensor, persistent=False) + serializer.swap_meta_forward(module) + return serializer.children(module) + + @classmethod + def decapsulate_module( + cls, module: nn.Module, device: Optional[torch.device] = None + ) -> nn.Module: + raw_bytes = module.get_buffer("ir_metadata").numpy().tobytes() + raw_dict = json.loads(raw_bytes.decode()) + typename = raw_dict["typename"] + metadata_dict = raw_dict["metadata_dict"] + if typename not in cls.module_to_serializer_cls: + raise ValueError( + f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}" + ) + serializer = cls.module_to_serializer_cls[typename] + assert issubclass(serializer, JsonSerializer) + module = serializer.deserialize_from_dict(metadata_dict, device, module) + + if serializer._module_cls is None: + raise ValueError( + "Must assign a nn.Module to class static variable _module_cls" + ) + if not isinstance(module, serializer._module_cls): + raise ValueError( + f"Expected module to be of type {serializer._module_cls.__name__}, got {type(module)}" + ) + return module + + +class EBCJsonSerializer(JsonSerializer): + _module_cls = EmbeddingBagCollection + + @classmethod + def swap_meta_forward(cls, module: nn.Module) -> None: + assert isinstance(module, cls._module_cls) + # pyre-ignore + module.forward = ebc_meta_forward.__get__(module, cls._module_cls) + + @classmethod + def serialize_to_dict( + cls, + module: nn.Module, + ) -> Dict[str, Any]: + ebc_metadata = EBCMetadata( + tables=[ + embedding_bag_config_to_metadata(table_config) + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + for table_config in module.embedding_bag_configs() + ], + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. 
+ is_weighted=module.is_weighted(), + device=str(module.device), + ) + ebc_metadata_dict = ebc_metadata.__dict__ + ebc_metadata_dict["tables"] = [ + table_config.__dict__ for table_config in ebc_metadata_dict["tables"] + ] + return ebc_metadata_dict + + @classmethod + def deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten_ep: Optional[nn.Module] = None, + ) -> nn.Module: + tables = [ + EmbeddingBagConfigMetadata(**table_config) + for table_config in metadata_dict["tables"] + ] + + device = get_deserialized_device(metadata_dict.get("device"), device) + return EmbeddingBagCollection( + tables=[ + embedding_metadata_to_config(table_config) for table_config in tables + ], + is_weighted=metadata_dict["is_weighted"], + device=device, + ) + + +JsonSerializer.module_to_serializer_cls["EmbeddingBagCollection"] = EBCJsonSerializer + + +class PWMJsonSerializer(JsonSerializer): + _module_cls = PositionWeightedModule + + @classmethod + def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]: + metadata = PositionWeightedModuleMetadata( + max_feature_length=module.position_weight.shape[0], + ) + return metadata.__dict__ + + @classmethod + def deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten_ep: Optional[nn.Module] = None, + ) -> nn.Module: + metadata = PositionWeightedModuleMetadata(**metadata_dict) + return PositionWeightedModule(metadata.max_feature_length, device) + + +JsonSerializer.module_to_serializer_cls["PositionWeightedModule"] = PWMJsonSerializer + + +class PWMCJsonSerializer(JsonSerializer): + _module_cls = PositionWeightedModuleCollection + + @classmethod + def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]: + metadata = PositionWeightedModuleCollectionMetadata( + max_feature_lengths=[ # convert to list of tuples to preserve the order + (feature, len) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `items`. + for feature, len in module.max_feature_lengths.items() + ], + ) + return metadata.__dict__ + + @classmethod + def deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten_ep: Optional[nn.Module] = None, + ) -> nn.Module: + metadata = PositionWeightedModuleCollectionMetadata(**metadata_dict) + max_feature_lengths = { + feature: len for feature, len in metadata.max_feature_lengths + } + return PositionWeightedModuleCollection(max_feature_lengths, device) + + +JsonSerializer.module_to_serializer_cls["PositionWeightedModuleCollection"] = ( + PWMCJsonSerializer +) + + +class FPEBCJsonSerializer(JsonSerializer): + _module_cls = FeatureProcessedEmbeddingBagCollection + _children = ["_feature_processors", "_embedding_bag_collection"] + + @classmethod + def serialize_to_dict( + cls, + module: nn.Module, + ) -> Dict[str, Any]: + if isinstance(module._feature_processors, FeatureProcessorsCollection): + metadata = FPEBCMetadata( + is_fp_collection=True, + features=[], + ) + else: + metadata = FPEBCMetadata( + is_fp_collection=False, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `keys`. 
+ features=list(module._feature_processors.keys()), + ) + return metadata.__dict__ + + @classmethod + def deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten_ep: Optional[nn.Module] = None, + ) -> nn.Module: + metadata = FPEBCMetadata(**metadata_dict) + assert unflatten_ep is not None + if metadata.is_fp_collection: + feature_processors = unflatten_ep._feature_processors + assert isinstance(feature_processors, FeatureProcessorsCollection) + else: + feature_processors: dict[str, FeatureProcessor] = {} + for feature in metadata.features: + fp = getattr(unflatten_ep._feature_processors, feature) + assert isinstance(fp, FeatureProcessor) + feature_processors[feature] = fp + ebc = unflatten_ep._embedding_bag_collection + assert isinstance(ebc, EmbeddingBagCollection) + return FeatureProcessedEmbeddingBagCollection( + ebc, + feature_processors, + ) + + +JsonSerializer.module_to_serializer_cls["FeatureProcessedEmbeddingBagCollection"] = ( + FPEBCJsonSerializer +) + + +class KTRegroupAsDictJsonSerializer(JsonSerializer): + _module_cls = KTRegroupAsDict + + @classmethod + def swap_meta_forward(cls, module: nn.Module) -> None: + assert isinstance(module, cls._module_cls) + # pyre-ignore + module.forward = kt_regroup_meta_forward.__get__(module, cls._module_cls) + + @classmethod + def serialize_to_dict( + cls, + module: nn.Module, + ) -> Dict[str, Any]: + metadata = KTRegroupAsDictMetadata( + # pyre-fixme[6]: For 1st argument expected `List[str]` but got + # `Union[Module, Tensor]`. + keys=module._keys, + # pyre-fixme[6]: For 2nd argument expected `List[List[str]]` but got + # `Union[Module, Tensor]`. + groups=module._groups, + ) + return metadata.__dict__ + + @classmethod + def deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten_ep: Optional[nn.Module] = None, + ) -> nn.Module: + metadata = KTRegroupAsDictMetadata(**metadata_dict) + return KTRegroupAsDict( + keys=metadata.keys, + groups=metadata.groups, + ) + + +JsonSerializer.module_to_serializer_cls["KTRegroupAsDict"] = ( + KTRegroupAsDictJsonSerializer +) diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py new file mode 100644 index 000000000..31af19ec8 --- /dev/null +++ b/torchrec/ir/tests/test_serializer.py @@ -0,0 +1,749 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
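Before the serializer tests, a minimal sketch (not part of the patch) of the encapsulate -> export -> unflatten -> decapsulate round-trip that JsonSerializer enables. The wrapper module and KJT are illustrative; the helpers are those defined in torchrec/ir above.

# Sketch only: serialize EBC metadata into a buffer, export, then rebuild it.
import torch
from torch import nn
from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ebc = EmbeddingBagCollection(
            tables=[
                EmbeddingBagConfig(
                    name="t1", embedding_dim=4, num_embeddings=10, feature_names=["f1"]
                )
            ],
            is_weighted=False,
        )

    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
        return self.ebc(features).values()


model = TinyModel()
kjt = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1"],
    values=torch.tensor([0, 1, 2, 3]),
    offsets=torch.tensor([0, 2, 2, 4]),
)

# Swap each registered embedding module's forward for its meta forward and stash
# the config into an "ir_metadata" buffer, so export does not trace into lookups.
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
ep = torch.export.export(
    model,
    (kjt,),
    {},
    strict=False,
    preserve_module_call_signature=tuple(sparse_fqns),
)

# Unflatten and rebuild real EmbeddingBagCollections from the stored metadata.
deserialized = decapsulate_ir_modules(torch.export.unflatten(ep), JsonSerializer)
out = deserialized(kjt)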
+ +# pyre-strict + +#!/usr/bin/env python3 + +import copy +import unittest +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from torch import nn +from torchrec.ir.serializer import JsonSerializer + +from torchrec.ir.utils import ( + decapsulate_ir_modules, + encapsulate_ir_modules, + mark_dynamic_kjt, +) + +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.feature_processor_ import ( + PositionWeightedModule, + PositionWeightedModuleCollection, +) +from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection +from torchrec.modules.regroup import KTRegroupAsDict +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + + +class CompoundModule(nn.Module): + def __init__( + self, + ebc: EmbeddingBagCollection, + comp: Optional["CompoundModule"] = None, + mlist: List[Union[EmbeddingBagCollection, "CompoundModule"]] = [], + ) -> None: + super().__init__() + self.ebc = ebc + self.comp = comp + self.list = nn.ModuleList(mlist) + + def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: + res = self.comp(features) if self.comp else [] + res.append(self.ebc(features).values()) + for m in self.list: + if isinstance(m, CompoundModule): + res.extend(m(features)) + else: + res.append(m(features).values()) + return res + + +class CompoundModuleSerializer(JsonSerializer): + _module_cls = CompoundModule + + @classmethod + def children(cls, module: nn.Module) -> List[str]: + children = ["ebc", "list"] + if module.comp is not None: + children += ["comp"] + return children + + @classmethod + def serialize_to_dict( + cls, + module: nn.Module, + ) -> Dict[str, Any]: + return {} + + @classmethod + def deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten_ep: Optional[nn.Module] = None, + ) -> nn.Module: + assert unflatten_ep is not None + ebc = unflatten_ep.ebc + comp = getattr(unflatten_ep, "comp", None) + i = 0 + mlist = [] + while hasattr(unflatten_ep.list, str(i)): + mlist.append(getattr(unflatten_ep.list, str(i))) + i += 1 + # pyre-fixme[6]: For 1st argument expected `EmbeddingBagCollection` but got + # `Union[Module, Tensor]`. + # pyre-fixme[6]: For 2nd argument expected `Optional[CompoundModule]` but + # got `Union[Module, Tensor]`. 
+ return CompoundModule(ebc, comp, mlist) + + +class TestJsonSerializer(unittest.TestCase): + # in the model we have 5 duplicated EBCs, 1 fpEBC with fpCollection, and 1 fpEBC with fpDict + def generate_model(self) -> nn.Module: + class Model(nn.Module): + def __init__(self, ebc, fpebc1, fpebc2): + super().__init__() + self.ebc1 = ebc + self.ebc2 = copy.deepcopy(ebc) + self.ebc3 = copy.deepcopy(ebc) + self.ebc4 = copy.deepcopy(ebc) + self.ebc5 = copy.deepcopy(ebc) + self.fpebc1 = fpebc1 + self.fpebc2 = fpebc2 + + def forward( + self, + features: KeyedJaggedTensor, + ) -> List[torch.Tensor]: + kt1 = self.ebc1(features) + kt2 = self.ebc2(features) + kt3 = self.ebc3(features) + kt4 = self.ebc4(features) + kt5 = self.ebc5(features) + + fpebc1_res = self.fpebc1(features) + fpebc2_res = self.fpebc2(features) + res: List[torch.Tensor] = [] + for kt in [kt1, kt2, kt3, kt4, kt5, fpebc1_res, fpebc2_res]: + res.extend(KeyedTensor.regroup([kt], [[key] for key in kt.keys()])) + return res + + tb1_config = EmbeddingBagConfig( + name="t1", + embedding_dim=3, + num_embeddings=10, + feature_names=["f1"], + ) + tb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=4, + num_embeddings=10, + feature_names=["f2"], + ) + tb3_config = EmbeddingBagConfig( + name="t3", + embedding_dim=5, + num_embeddings=10, + feature_names=["f3"], + ) + + ebc = EmbeddingBagCollection( + tables=[tb1_config, tb2_config, tb3_config], + is_weighted=False, + ) + max_feature_lengths = {"f1": 100, "f2": 100} + + fpebc1 = FeatureProcessedEmbeddingBagCollection( + EmbeddingBagCollection( + tables=[tb1_config, tb2_config], + is_weighted=True, + ), + PositionWeightedModuleCollection( + max_feature_lengths=max_feature_lengths, + ), + ) + fpebc2 = FeatureProcessedEmbeddingBagCollection( + EmbeddingBagCollection( + tables=[tb1_config, tb3_config], + is_weighted=True, + ), + { + "f1": PositionWeightedModule(max_feature_length=10), + "f3": PositionWeightedModule(max_feature_length=20), + }, + ) + + model = Model(ebc, fpebc1, fpebc2) + + return model + + def generate_model_for_vbe_kjt(self) -> nn.Module: + class Model(nn.Module): + def __init__(self, ebc): + super().__init__() + self.ebc1 = ebc + + def forward( + self, + features: KeyedJaggedTensor, + ) -> List[torch.Tensor]: + kt1 = self.ebc1(features) + res: List[torch.Tensor] = [] + + for kt in [kt1]: + res.extend(KeyedTensor.regroup([kt], [[key] for key in kt.keys()])) + + return res + + config1 = EmbeddingBagConfig( + name="t1", + embedding_dim=3, + num_embeddings=10, + feature_names=["f1"], + ) + config2 = EmbeddingBagConfig( + name="t2", + embedding_dim=4, + num_embeddings=10, + feature_names=["f2"], + ) + ebc = EmbeddingBagCollection( + tables=[config1, config2], + is_weighted=False, + ) + + model = Model(ebc) + + return model + + def test_serialize_deserialize_ebc(self) -> None: + model = self.generate_model() + id_list_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + values=torch.tensor([0, 1, 2, 3, 2, 3]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]), + ) + + eager_out = model(id_list_features) + + # Serialize EBC + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) + ep = torch.export.export( + model, + (id_list_features,), + {}, + strict=False, + # Allows KJT to not be unflattened and run a forward on unflattened EP + preserve_module_call_signature=(tuple(sparse_fqns)), + ) + + # Run forward on ExportedProgram + ep_output = ep.module()(id_list_features) + + for i, tensor in enumerate(ep_output): + 
self.assertEqual(eager_out[i].shape, tensor.shape) + + # Deserialize EBC + unflatten_ep = torch.export.unflatten(ep) + deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer) + + # check EBC config + for i in range(5): + ebc_name = f"ebc{i + 1}" + self.assertIsInstance( + getattr(deserialized_model, ebc_name), EmbeddingBagCollection + ) + + for deserialized, orginal in zip( + getattr(deserialized_model, ebc_name).embedding_bag_configs(), + getattr(model, ebc_name).embedding_bag_configs(), + ): + self.assertEqual(deserialized.name, orginal.name) + self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim) + self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings) + self.assertEqual(deserialized.feature_names, orginal.feature_names) + + # check FPEBC config + for i in range(2): + fpebc_name = f"fpebc{i + 1}" + assert isinstance( + getattr(deserialized_model, fpebc_name), + FeatureProcessedEmbeddingBagCollection, + ) + + for deserialized, orginal in zip( + getattr( + deserialized_model, fpebc_name + )._embedding_bag_collection.embedding_bag_configs(), + getattr( + model, fpebc_name + )._embedding_bag_collection.embedding_bag_configs(), + ): + self.assertEqual(deserialized.name, orginal.name) + self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim) + self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings) + self.assertEqual(deserialized.feature_names, orginal.feature_names) + + # Run forward on deserialized model and compare the output + deserialized_model.load_state_dict(model.state_dict()) + deserialized_out = deserialized_model(id_list_features) + + self.assertEqual(len(deserialized_out), len(eager_out)) + for deserialized, orginal in zip(deserialized_out, eager_out): + self.assertEqual(deserialized.shape, orginal.shape) + self.assertTrue(torch.allclose(deserialized, orginal)) + + @unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.") + def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None: + model = self.generate_model_for_vbe_kjt() + id_list_features = KeyedJaggedTensor( + keys=["f1", "f2"], + values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]), + lengths=torch.tensor([3, 3, 2]), + stride_per_key_per_rank=[[2], [1]], + inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])), + ) + + eager_out = model(id_list_features) + + # Serialize EBC + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) + ep = torch.export.export( + model, + (id_list_features,), + {}, + strict=False, + # Allows KJT to not be unflattened and run a forward on unflattened EP + preserve_module_call_signature=(tuple(sparse_fqns)), + ) + + # Run forward on ExportedProgram + ep_output = ep.module()(id_list_features) + + for i, tensor in enumerate(ep_output): + self.assertEqual(eager_out[i].shape, tensor.shape) + + # Deserialize EBC + unflatten_ep = torch.export.unflatten(ep) + deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer) + + # check EBC config + for i in range(5): + ebc_name = f"ebc{i + 1}" + self.assertIsInstance( + getattr(deserialized_model, ebc_name), EmbeddingBagCollection + ) + + for deserialized, orginal in zip( + getattr(deserialized_model, ebc_name).embedding_bag_configs(), + getattr(model, ebc_name).embedding_bag_configs(), + ): + self.assertEqual(deserialized.name, orginal.name) + self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim) + self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings) + self.assertEqual(deserialized.feature_names, 
orginal.feature_names) + + # check FPEBC config + for i in range(2): + fpebc_name = f"fpebc{i + 1}" + assert isinstance( + getattr(deserialized_model, fpebc_name), + FeatureProcessedEmbeddingBagCollection, + ) + + for deserialized, orginal in zip( + getattr( + deserialized_model, fpebc_name + )._embedding_bag_collection.embedding_bag_configs(), + getattr( + model, fpebc_name + )._embedding_bag_collection.embedding_bag_configs(), + ): + self.assertEqual(deserialized.name, orginal.name) + self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim) + self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings) + self.assertEqual(deserialized.feature_names, orginal.feature_names) + + # Run forward on deserialized model and compare the output + deserialized_model.load_state_dict(model.state_dict()) + deserialized_out = deserialized_model(id_list_features) + + self.assertEqual(len(deserialized_out), len(eager_out)) + for deserialized, orginal in zip(deserialized_out, eager_out): + self.assertEqual(deserialized.shape, orginal.shape) + self.assertTrue(torch.allclose(deserialized, orginal)) + + def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None: + model = self.generate_model() + feature1 = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + values=torch.tensor([0, 1, 2, 3, 2, 3]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]), + ) + + feature2 = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + values=torch.tensor([0, 1, 2, 3, 2, 3, 4]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]), + ) + eager_out = model(feature2) + + # Serialize EBC + collection = mark_dynamic_kjt(feature1) + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) + ep = torch.export.export( + model, + (feature1,), + {}, + dynamic_shapes=collection.dynamic_shapes(model, (feature1,)), + strict=False, + # Allows KJT to not be unflattened and run a forward on unflattened EP + preserve_module_call_signature=tuple(sparse_fqns), + ) + + # Run forward on ExportedProgram + ep_output = ep.module()(feature2) + + # other asserts + for i, tensor in enumerate(ep_output): + self.assertEqual(eager_out[i].shape, tensor.shape) + + # Deserialize EBC + unflatten_ep = torch.export.unflatten(ep) + deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer) + deserialized_model.load_state_dict(model.state_dict()) + + # Run forward on deserialized model + deserialized_out = deserialized_model(feature2) + + for i, tensor in enumerate(deserialized_out): + self.assertEqual(eager_out[i].shape, tensor.shape) + assert torch.allclose(eager_out[i], tensor) + + def test_ir_emb_lookup_device(self) -> None: + model = self.generate_model() + # pyre-fixme[16]: `Module` has no attribute `fpebc1`. + model.fpebc1 = copy.deepcopy(model.ebc1) + # pyre-fixme[16]: `Module` has no attribute `fpebc2`. 
+ model.fpebc2 = copy.deepcopy(model.ebc1) + feature1 = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + values=torch.tensor([0, 1, 2, 3, 2, 3]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]), + ) + + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) + for device in ["cpu", "cuda", "meta"]: + if device == "cuda" and not torch.cuda.is_available(): + continue + device = torch.device(device) + outputs = model.to(device)(feature1.to(device)) + for output in outputs: + self.assertEqual(output.device.type, device.type) + + def test_deserialized_device(self) -> None: + model = self.generate_model() + id_list_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + values=torch.tensor([0, 1, 2, 3, 2, 3]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]), + ) + + # Serialize EBC + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) + ep = torch.export.export( + model, + (id_list_features,), + {}, + strict=False, + # Allows KJT to not be unflattened and run a forward on unflattened EP + preserve_module_call_signature=(tuple(sparse_fqns)), + ) + + # Deserialize EBC on different devices (, , ) + for device in ["cpu", "cuda", "meta"]: + if device == "cuda" and not torch.cuda.is_available(): + continue + device = torch.device(device) + unflatten_ep = torch.export.unflatten(ep) + deserialized_model = decapsulate_ir_modules( + unflatten_ep, JsonSerializer, device + ) + for name, m in deserialized_model.named_modules(): + if hasattr(m, "device"): + assert m.device.type == device.type, f"{name} should be on {device}" + for name, param in deserialized_model.named_parameters(): + # TODO: we don't support FPEBC yet, so we skip the FPEBC params + if "_feature_processors" in name: + continue + assert param.device.type == device.type, f"{name} should be on {device}" + + def test_compound_module(self) -> None: + tb1_config = EmbeddingBagConfig( + name="t1", + embedding_dim=4, + num_embeddings=10, + feature_names=["f1"], + ) + tb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=4, + num_embeddings=10, + feature_names=["f2"], + ) + tb3_config = EmbeddingBagConfig( + name="t3", + embedding_dim=4, + num_embeddings=10, + feature_names=["f3"], + ) + ebc: Callable[[], EmbeddingBagCollection] = lambda: EmbeddingBagCollection( + tables=[tb1_config, tb2_config, tb3_config], + is_weighted=False, + ) + + class MyModel(nn.Module): + def __init__(self, comp: CompoundModule) -> None: + super().__init__() + self.comp = comp + + def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: + return self.comp(features) + + model = MyModel( + CompoundModule( + ebc=ebc(), + comp=CompoundModule(ebc(), CompoundModule(ebc(), mlist=[ebc(), ebc()])), + mlist=[ebc(), CompoundModule(ebc(), CompoundModule(ebc()))], + ) + ) + id_list_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + values=torch.tensor([0, 1, 2, 3, 2, 3]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]), + ) + + eager_out = model(id_list_features) + + JsonSerializer.module_to_serializer_cls["CompoundModule"] = ( + CompoundModuleSerializer + ) + # Serialize + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) + ep = torch.export.export( + model, + (id_list_features,), + {}, + strict=False, + # Allows KJT to not be unflattened and run a forward on unflattened EP + preserve_module_call_signature=(tuple(sparse_fqns)), + ) + + ep_output = ep.module()(id_list_features) + self.assertEqual(len(ep_output), len(eager_out)) + for x, y in zip(ep_output, eager_out): 
+ self.assertEqual(x.shape, y.shape) + + # Deserialize + unflatten_ep = torch.export.unflatten(ep) + deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer) + # Check if Compound Module is deserialized correctly + self.assertIsInstance(deserialized_model.comp, CompoundModule) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `comp`. + self.assertIsInstance(deserialized_model.comp.comp, CompoundModule) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `comp`. + self.assertIsInstance(deserialized_model.comp.comp.comp, CompoundModule) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `list`. + self.assertIsInstance(deserialized_model.comp.list[1], CompoundModule) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `list`. + self.assertIsInstance(deserialized_model.comp.list[1].comp, CompoundModule) + + deserialized_model.load_state_dict(model.state_dict()) + # Run forward on deserialized model + deserialized_out = deserialized_model(id_list_features) + self.assertEqual(len(deserialized_out), len(eager_out)) + for x, y in zip(deserialized_out, eager_out): + self.assertTrue(torch.allclose(x, y)) + + def test_regroup_as_dict_module(self) -> None: + class Model(nn.Module): + def __init__(self, ebc, fpebc, regroup): + super().__init__() + self.ebc = ebc + self.fpebc = fpebc + self.regroup = regroup + + def forward( + self, + features: KeyedJaggedTensor, + ) -> Dict[str, torch.Tensor]: + kt1 = self.ebc(features) + kt2 = self.fpebc(features) + return self.regroup([kt1, kt2]) + + tb1_config = EmbeddingBagConfig( + name="t1", + embedding_dim=3, + num_embeddings=10, + feature_names=["f1", "f2"], + ) + tb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=4, + num_embeddings=10, + feature_names=["f3", "f4"], + ) + tb3_config = EmbeddingBagConfig( + name="t3", + embedding_dim=5, + num_embeddings=10, + feature_names=["f5"], + ) + + ebc = EmbeddingBagCollection( + tables=[tb1_config, tb3_config], + is_weighted=False, + ) + max_feature_lengths = {"f3": 100, "f4": 100} + fpebc = FeatureProcessedEmbeddingBagCollection( + EmbeddingBagCollection( + tables=[tb2_config], + is_weighted=True, + ), + PositionWeightedModuleCollection( + max_feature_lengths=max_feature_lengths, + ), + ) + regroup = KTRegroupAsDict([["f1", "f3", "f5"], ["f2", "f4"]], ["odd", "even"]) + model = Model(ebc, fpebc, regroup) + + id_list_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3", "f4", "f5"], + values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]), + offsets=torch.tensor([0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]), + ) + self.assertFalse(model.regroup._is_inited) + + # Serialize EBC + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) + ep = torch.export.export( + model, + (id_list_features,), + {}, + strict=False, + # Allows KJT to not be unflattened and run a forward on unflattened EP + preserve_module_call_signature=(tuple(sparse_fqns)), + ) + + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_is_inited`. + self.assertFalse(model.regroup._is_inited) + eager_out = model(id_list_features) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_is_inited`. 
+ self.assertFalse(model.regroup._is_inited) + + # Run forward on ExportedProgram + ep_output = ep.module()(id_list_features) + for key in eager_out.keys(): + self.assertEqual(ep_output[key].shape, eager_out[key].shape) + # Deserialize EBC + unflatten_ep = torch.export.unflatten(ep) + deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_is_inited`. + self.assertFalse(deserialized_model.regroup._is_inited) + deserialized_out = deserialized_model(id_list_features) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `_is_inited`. + self.assertTrue(deserialized_model.regroup._is_inited) + for key in eager_out.keys(): + self.assertEqual(deserialized_out[key].shape, eager_out[key].shape) + + def test_key_order_with_ebc_and_regroup(self) -> None: + tb1_config = EmbeddingBagConfig( + name="t1", + embedding_dim=3, + num_embeddings=10, + feature_names=["f1"], + ) + tb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=4, + num_embeddings=10, + feature_names=["f2"], + ) + tb3_config = EmbeddingBagConfig( + name="t3", + embedding_dim=5, + num_embeddings=10, + feature_names=["f3"], + ) + id_list_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3", "f4", "f5"], + values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]), + offsets=torch.tensor([0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]), + ) + ebc1 = EmbeddingBagCollection( + tables=[tb1_config, tb2_config, tb3_config], + is_weighted=False, + ) + ebc2 = EmbeddingBagCollection( + tables=[tb1_config, tb3_config, tb2_config], + is_weighted=False, + ) + ebc2.load_state_dict(ebc1.state_dict()) + regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"]) + + class mySparse(nn.Module): + def __init__(self, ebc, regroup): + super().__init__() + self.ebc = ebc + self.regroup = regroup + + def forward( + self, + features: KeyedJaggedTensor, + ) -> Dict[str, torch.Tensor]: + return self.regroup([self.ebc(features)]) + + class myModel(nn.Module): + def __init__(self, ebc, regroup): + super().__init__() + self.sparse = mySparse(ebc, regroup) + + def forward( + self, + features: KeyedJaggedTensor, + ) -> Dict[str, torch.Tensor]: + return self.sparse(features) + + model = myModel(ebc1, regroup) + eager_out = model(id_list_features) + + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) + ep = torch.export.export( + model, + (id_list_features,), + {}, + strict=False, + # Allows KJT to not be unflattened and run a forward on unflattened EP + preserve_module_call_signature=(tuple(sparse_fqns)), + ) + unflatten_ep = torch.export.unflatten(ep) + deserialized_model = decapsulate_ir_modules( + unflatten_ep, + JsonSerializer, + short_circuit_pytree_ebc_regroup=True, + finalize_interpreter_modules=True, + ) + + # we export the model with ebc1 and unflatten the model, + # and then swap with ebc2 (you can think this as the the sharding process + # resulting a shardedEBC), so that we can mimic the key-order change + # pyre-fixme[16]: `Module` has no attribute `ebc`. + # pyre-fixme[16]: `Tensor` has no attribute `ebc`. 
+ deserialized_model.sparse.ebc = ebc2 + + deserialized_out = deserialized_model(id_list_features) + for key in eager_out.keys(): + torch.testing.assert_close(deserialized_out[key], eager_out[key]) diff --git a/torchrec/ir/types.py b/torchrec/ir/types.py new file mode 100644 index 000000000..7dc1695b9 --- /dev/null +++ b/torchrec/ir/types.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +import abc +from typing import Any, Dict, List, Optional + +import torch + +from torch import nn + + +class SerializerInterface(abc.ABC): + """ + Interface for Serializer classes for torch.export IR. + """ + + @classmethod + @property + def module_to_serializer_cls(cls) -> Dict[str, Any]: + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def encapsulate_module(cls, module: nn.Module) -> List[str]: + # Take the eager embedding module and encapsulate the module, including serialization + # and meta_forward-swapping, then returns a list of children (fqns) which needs further encapsulation + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def decapsulate_module( + cls, module: nn.Module, device: Optional[torch.device] = None + ) -> nn.Module: + # Take the eager embedding module and decapsulate it by removing serialization and meta_forward-swapping + raise NotImplementedError diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py new file mode 100644 index 000000000..6e9367e43 --- /dev/null +++ b/torchrec/ir/utils.py @@ -0,0 +1,360 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +import logging +import operator +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Type + +import torch + +from torch import nn +from torch.export import Dim, ShapesCollection +from torch.export.dynamic_shapes import _Dim as DIM +from torch.export.unflatten import InterpreterModule +from torch.fx import Node +from torchrec.ir.types import SerializerInterface +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection +from torchrec.modules.regroup import KTRegroupAsDict +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +# TODO: Replace the default interface with the python dataclass interface +DEFAULT_SERIALIZER_CLS = SerializerInterface +DYNAMIC_DIMS: Dict[str, int] = defaultdict(int) +logger: logging.Logger = logging.getLogger(__name__) + + +def get_device(tensors: List[Optional[torch.Tensor]]) -> Optional[torch.device]: + """ + Returns the device of the first non-None tensor in the list. 
+ """ + for t in tensors: + if t is not None: + return t.device + return None + + +@torch.library.custom_op("torchrec::ir_emb_lookup", mutates_args={}) +def ir_emb_lookup_impl( + tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int] +) -> List[torch.Tensor]: + device = get_device(tensors) + logger.info(f"torch.ops.torchrec.ir_emb_lookup -> ({batch_size}, {dims}) {device}") + return [torch.empty(batch_size, dim, device=device) for dim in dims] + + +@torch.library.register_fake("torchrec::ir_emb_lookup") +def ir_emb_lookup_fake( + tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int] +) -> List[torch.Tensor]: + device = get_device(tensors) + logger.info(f"ir_emb_lookup_fake -> ({batch_size}, {dims}) {device}") + return [torch.empty(batch_size, dim, device=device) for dim in dims] + + +@torch.library.custom_op("torchrec::ir_kt_regroup", mutates_args={}) +def ir_kt_regroup_impl( + tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int] +) -> List[torch.Tensor]: + device = get_device(tensors) + logger.info(f"torch.ops.torchrec.ir_kt_regroup -> ({batch_size}, {dims}) {device}") + return [torch.empty(batch_size, dim, device=device) for dim in dims] + + +@torch.library.register_fake("torchrec::ir_kt_regroup") +def ir_kt_regroup_fake( + tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int] +) -> List[torch.Tensor]: + device = get_device(tensors) + logger.info(f"ir_kt_regroup_fake -> ({batch_size}, {dims}) {device}") + return [torch.empty(batch_size, dim, device=device) for dim in dims] + + +@torch.library.custom_op("torchrec::ir_dynamic_batch_emb_lookup", mutates_args={}) +def ir_dynamic_batch_emb_lookup_impl( + tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int] +) -> List[torch.Tensor]: + device = get_device(tensors) + logger.info( + f"torch.ops.torchrec.ir_dynamic_batch_emb_lookup -> ({batch_size}, {dims}) {device}" + ) + return [torch.empty(batch_size, dim, device=device) for dim in dims] + + +@torch.library.register_fake("torchrec::ir_dynamic_batch_emb_lookup") +def ir_dynamic_batch_emb_lookup_fake( + tensors: List[Optional[torch.Tensor]], batch_dize: int, dims: List[int] +) -> List[torch.Tensor]: + device = get_device(tensors) + batch_size = torch.library.get_ctx().new_dynamic_size() + logger.info(f"ir_dynamic_batch_emb_lookup_fake -> ({batch_size}, {dims}) {device}") + return [torch.empty(batch_size, dim, device=device) for dim in dims] + + +def encapsulate_ir_modules( + module: nn.Module, + serializer: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS, + fqn: str = "", +) -> Tuple[nn.Module, List[str]]: + """ + Takes a module and encapsulate its embedding modules and serializes them to the module buffer. + Returns the modified module and a list of fqns that had the buffer added, which is needed for torch.export + The encapsulation is done by using meta_forward function provided by the serializer + to replace the module's original forward function. 
+ """ + preserve_fqns: List[str] = [] # fqns of the serialized modules + children: List[str] = [] # fqns of the children that need further serialization + # handle current module, and find the children which need further serialization + if type(module).__name__ in serializer.module_to_serializer_cls: + children = serializer.encapsulate_module(module) + preserve_fqns.append(fqn) + else: + # if the module is not of type serializer, then we check all its children + children = [child for child, _ in module.named_children()] + + # handle child modules recursively + for child in children: + submodule = module.get_submodule(child) + child_fqn = f"{fqn}.{child}" if len(fqn) > 0 else child + _, fqns = encapsulate_ir_modules(submodule, serializer, child_fqn) + preserve_fqns.extend(fqns) + return module, preserve_fqns + + +def decapsulate_ir_modules( + module: nn.Module, + serializer: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS, + device: Optional[torch.device] = None, + finalize_interpreter_modules: bool = False, + short_circuit_pytree_ebc_regroup: bool = False, +) -> nn.Module: + """ + Takes a module and decapsulate its embedding modules by retrieving the buffer. + Returns the module with restored embedding (sub) modules. + """ + for child_fqn, child in module.named_children(): + # perform deserialization on the children first, so that we can replace the child module with + # the deserialized module, and then replace it in the parent + child = decapsulate_ir_modules( + module=child, serializer=serializer, device=device + ) + # replace the child module with deserialized one if applicable + setattr(module, child_fqn, child) + + # only deserialize if the module has ir_metadata buffer, otherwise return as is + # we use "ir_metadata" as a convention to identify the deserializable module + if "ir_metadata" in dict(module.named_buffers()): + module = serializer.decapsulate_module(module, device) + + if short_circuit_pytree_ebc_regroup: + module = _short_circuit_pytree_ebc_regroup(module) + assert finalize_interpreter_modules, "need finalize_interpreter_modules=True" + + if finalize_interpreter_modules: + for mod in module.modules(): + if isinstance(mod, InterpreterModule): + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + mod.finalize() + + return module + + +def _get_dim(name: str, min: Optional[int] = None, max: Optional[int] = None) -> DIM: + """ + Returns a Dim object with the given name and min/max. If the name is not unique, it will append a suffix to the name. + """ + dim = f"{name}_{DYNAMIC_DIMS[name]}" + DYNAMIC_DIMS[name] += 1 + return Dim(dim, min=min, max=max) + + +def mark_dynamic_kjt( + kjt: KeyedJaggedTensor, + shapes_collection: Optional[ShapesCollection] = None, + variable_length: bool = False, + vlen: Optional[DIM] = None, + llen: Optional[DIM] = None, +) -> ShapesCollection: + """ + Makes the given KJT dynamic. If it's not variable length, it will only have + one dynamic dimension, which is the length of the values (and weights). + If it is variable length, then the lengths and offsets will be dynamic. + + If a shapes collection is provided, it will be updated with the new shapes, + otherwise a new shapes collection will be created. A passed-in shapes_collection is + useful if you have multiple KJTs or other dynamic shapes that you want to trace. + + If a dynamic dim/name is provided, it will directly use that dim/name. Otherwise, + it will use the default name "vlen" for values, and "llen", "lofs" if variable length. 
+ A passed-in dynamic dim is useful if the dynamic dim is already used in other places. + + Args: + kjt (KeyedJaggedTensor): The KJT to make dynamic. + shapes_collection (Optional[ShapesCollection]): The collection to update. + variable_length (bool): Whether the KJT is variable length. + vlen (Optional[DIM]): The dynamic length for the values. If it's None, it will use the default name "vlen". + llen (Optional[DIM]): The dynamic length for the lengths, it's only used when variable_length is true. If it's None, it will use the default name "llen". + batch_size (Optional[DIM]): The dynamic length for the batch_size, it's only used when variable_length and mark_batch_size are both true. + """ + + def _has_dim(t: Optional[torch.Tensor]) -> bool: + return t is not None and t.dim() > 0 + + if shapes_collection is None: + shapes_collection = ShapesCollection() + vlen = _get_dim("vlen") if vlen is None else vlen + + if _has_dim(kjt._values): + if kjt._values.numel() == 0: + # if the values is empty, we need to set the shape to (2,) to make it compatible with dynamic shape + # a 0-size dynamic shape will cause error in torch.export. + # logically when the values is empty, the lengths and offsets should all be zero-value tensors. + # And this makes the actual values irrelavent to the downstream process. + kjt._values = torch.ones( + 2, device=kjt._values.device, dtype=kjt._values.dtype + ) + shapes_collection[kjt._values] = (vlen,) + if _has_dim(kjt._weights): + shapes_collection[kjt._weights] = (vlen,) + if variable_length: + llen = _get_dim("llen") if llen is None else llen + if _has_dim(kjt._lengths): + shapes_collection[kjt._lengths] = (llen,) + if _has_dim(kjt._offsets): + shapes_collection[kjt._offsets] = (llen + 1,) + return shapes_collection + + +def move_to_copy_nodes_to_device( + unflattened_module: nn.Module, + device: torch.device, +) -> nn.Module: + """ + Moves all the copy nodes to the given device. + """ + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `nodes`. + for nodes in unflattened_module.graph.nodes: + if "_to_copy" in nodes.name: + new_kwargs = {} + for k, v in nodes.kwargs.items(): + if isinstance(v, torch.device): + v = device + new_kwargs[k] = v + nodes.kwargs = new_kwargs + + return unflattened_module + + +def _short_circuit_pytree_ebc_regroup(module: nn.Module) -> nn.Module: + """ + Bypass pytree flatten and unflatten function between EBC and KTRegroupAsDict to avoid key-order issue. + https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/ + EBC ==> (out-going) pytree.flatten ==> tensors and specs ==> (in-coming) pytree.unflatten ==> KTRegroupAsDict + """ + ebc_fqns: List[str] = [] + regroup_fqns: List[str] = [] + for fqn, m in module.named_modules(): + if isinstance(m, FeatureProcessedEmbeddingBagCollection): + ebc_fqns.append(fqn) + elif isinstance(m, EmbeddingBagCollection): + if len(ebc_fqns) > 0 and fqn.startswith(ebc_fqns[-1]): + continue + ebc_fqns.append(fqn) + elif isinstance(m, KTRegroupAsDict): + regroup_fqns.append(fqn) + if len(ebc_fqns) == len(regroup_fqns) == 0: + # nothing happens if there is no EBC or KTRegroupAsDict (e.g., the PEA case) + return module + elif len(regroup_fqns) == 0: + # model only contains EBCs, KT (from EBC) pytree.flatten has performance impact + logger.warning( + "Expect perf impact if KTRegroupAsDict is not used together with EBCs." 
+ ) + return module + elif len(ebc_fqns) == 0: + # model only contains KTRegroupAsDict, KTs are not from EBC, need to be careful + logger.warning("KTRegroupAsDict is not from EBC, need to be careful.") + return module + else: + return prune_pytree_flatten_unflatten( + module, in_fqns=regroup_fqns, out_fqns=ebc_fqns + ) + + +def prune_pytree_flatten_unflatten( + module: nn.Module, in_fqns: List[str], out_fqns: List[str] +) -> nn.Module: + """ + Remove pytree flatten and unflatten function between the given in_fqns and out_fqns. + "preserved module" ==> (out-going) pytree.flatten ==> [tensors and specs] + [tensors and specs] ==> (in-coming) pytree.unflatten ==> "preserved module" + """ + + def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]: + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `nodes`. + for node in mod.graph.nodes: + if node.op == "call_module" and node.target == fqn: + return mod, node + assert "." in fqn, f"can't find {fqn} in the graph of {mod}" + curr, fqn = fqn.split(".", maxsplit=1) + mod = getattr(mod, curr) + return _get_graph_node(mod, fqn) + + # remove tree_unflatten from the in_fqns (in-coming nodes) + for fqn in in_fqns: + submodule, node = _get_graph_node(module, fqn) + assert len(node.args) == 1 + getitem_getitem: Node = node.args[0] # pyre-ignore[9] + assert ( + getitem_getitem.op == "call_function" + and getitem_getitem.target == operator.getitem + ) + tree_unflatten_getitem = node.args[0].args[0] # pyre-ignore[16] + assert ( + tree_unflatten_getitem.op == "call_function" + and tree_unflatten_getitem.target == operator.getitem + ) + tree_unflatten = tree_unflatten_getitem.args[0] + assert ( + tree_unflatten.op == "call_function" + and tree_unflatten.target == torch.utils._pytree.tree_unflatten + ) + logger.info(f"Removing tree_unflatten from {fqn}") + input_nodes = tree_unflatten.args[0] + node.args = (input_nodes,) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `eliminate_dead_code`. + submodule.graph.eliminate_dead_code() + + # remove tree_flatten_spec from the out_fqns (out-going nodes) + for fqn in out_fqns: + submodule, node = _get_graph_node(module, fqn) + users = list(node.users.keys()) + assert ( + len(users) == 1 + and users[0].op == "call_function" + and users[0].target == torch.fx._pytree.tree_flatten_spec + ) + tree_flatten_users = list(users[0].users.keys()) + assert ( + len(tree_flatten_users) == 1 + and tree_flatten_users[0].op == "call_function" + and tree_flatten_users[0].target == operator.getitem + ) + logger.info(f"Removing tree_flatten_spec from {fqn}") + getitem_node = tree_flatten_users[0] + getitem_node.replace_all_uses_with(node) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `eliminate_dead_code`. + submodule.graph.eliminate_dead_code() + return module diff --git a/torchrec/linter/module_linter.py b/torchrec/linter/module_linter.py index 0c3a097a4..6ce79ed0c 100644 --- a/torchrec/linter/module_linter.py +++ b/torchrec/linter/module_linter.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import ast import json from argparse import ArgumentParser, Namespace @@ -35,7 +37,9 @@ def print_error_message( """ lint_item = { "path": python_path, + # pyre-fixme[16]: `AST` has no attribute `lineno`. "line": node.lineno, + # pyre-fixme[16]: `AST` has no attribute `col_offset`. 
"char": node.col_offset + 1, "severity": severity, "name": name, diff --git a/torchrec/linter/tests/test_module_linter.py b/torchrec/linter/tests/test_module_linter.py index 3c12c9c53..f87191c5a 100644 --- a/torchrec/linter/tests/test_module_linter.py +++ b/torchrec/linter/tests/test_module_linter.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest from unittest.mock import patch diff --git a/torchrec/metrics/README.md b/torchrec/metrics/README.md index 6a08f20f6..c33b6bc08 100644 --- a/torchrec/metrics/README.md +++ b/torchrec/metrics/README.md @@ -193,7 +193,7 @@ ne = NEMetric( window_size=512, fused_update_limit=0, ) -labels, predictions, weights = parse_task_model_outputs(tasks, model_output) +labels, predictions, weights, _ = parse_task_model_outputs(tasks, model_output) ne.update( predictions=predictions, labels=labels, diff --git a/torchrec/metrics/accuracy.py b/torchrec/metrics/accuracy.py new file mode 100644 index 000000000..95537aab3 --- /dev/null +++ b/torchrec/metrics/accuracy.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +THRESHOLD = "threshold" + + +def compute_accuracy( + accuracy_sum: torch.Tensor, weighted_num_samples: torch.Tensor +) -> torch.Tensor: + return torch.where( + weighted_num_samples == 0.0, 0.0, accuracy_sum / weighted_num_samples + ).double() + + +def compute_accuracy_sum( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + threshold: float = 0.5, +) -> torch.Tensor: + predictions = predictions.double() + return torch.sum(weights * ((predictions >= threshold) == labels), dim=-1) + + +def get_accuracy_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: Optional[torch.Tensor], + threshold: float = 0.5, +) -> Dict[str, torch.Tensor]: + if weights is None: + weights = torch.ones_like(predictions) + return { + "accuracy_sum": compute_accuracy_sum(labels, predictions, weights, threshold), + "weighted_num_samples": torch.sum(weights, dim=-1), + } + + +class AccuracyMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for Accuracy. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + + Args: + threshold (float): If provided, computes accuracy metrics cutting off at + the specified threshold. 
+ """ + + def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._add_state( + "accuracy_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_num_samples", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._threshold: float = threshold + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None: + raise RecMetricException( + "Inputs 'predictions' should not be None for AccuracyMetricComputation update" + ) + states = get_accuracy_states(labels, predictions, weights, self._threshold) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + reports = [ + MetricComputationReport( + name=MetricName.ACCURACY, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_accuracy( + cast(torch.Tensor, self.accuracy_sum), + cast(torch.Tensor, self.weighted_num_samples), + ), + ), + MetricComputationReport( + name=MetricName.ACCURACY, + metric_prefix=MetricPrefix.WINDOW, + value=compute_accuracy( + self.get_window_state("accuracy_sum"), + self.get_window_state("weighted_num_samples"), + ), + ), + ] + return reports + + +class AccuracyMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.ACCURACY + _computation_class: Type[RecMetricComputation] = AccuracyMetricComputation diff --git a/torchrec/metrics/auc.py b/torchrec/metrics/auc.py index 77150b26a..5026ee1e2 100644 --- a/torchrec/metrics/auc.py +++ b/torchrec/metrics/auc.py @@ -5,9 +5,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, cast, List, Optional, Type +# pyre-strict + +import logging +from functools import partial +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type import torch +import torch.distributed as dist +from torchmetrics.utilities.distributed import gather_all_tensors +from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix from torchrec.metrics.rec_metric import ( MetricComputationReport, @@ -16,34 +23,144 @@ RecMetricException, ) + +logger: logging.Logger = logging.getLogger(__name__) + PREDICTIONS = "predictions" LABELS = "labels" WEIGHTS = "weights" +GROUPING_KEYS = "grouping_keys" +REQUIRED_INPUTS = "required_inputs" + + +def _concat_if_needed( + predictions: List[torch.Tensor], + labels: List[torch.Tensor], + weights: List[torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + This check exists because of how the state is organized due to quirks in RecMetrics. + Since we do not do tensor concatenatation in the compute or update call, there are cases (in non-distributed settings) + where the tensors from updates are not concatted into a single tensor. Which is determined by the length of the list. 
+ """ + preds_t, labels_t, weights_t = None, None, None + if len(predictions) > 1: + preds_t = torch.cat(predictions, dim=-1) + labels_t = torch.cat(labels, dim=-1) + weights_t = torch.cat(weights, dim=-1) + else: + preds_t = predictions[0] + labels_t = labels[0] + weights_t = weights[0] + + return preds_t, labels_t, weights_t + + +def _compute_auc_helper( + predictions: torch.Tensor, + labels: torch.Tensor, + weights: torch.Tensor, + apply_bin: bool = False, +) -> torch.Tensor: + sorted_indices = torch.argsort(predictions, descending=True, dim=-1) + sorted_labels = torch.index_select(labels, dim=0, index=sorted_indices) + if apply_bin: + # TODO - [add flag to set bining dyamically] for use with soft labels, >=0.039 --> 1, <0.039 --> 0 + sorted_labels = torch.ge(sorted_labels, 0.039).to(dtype=sorted_labels.dtype) + sorted_weights = torch.index_select(weights, dim=0, index=sorted_indices) + cum_fp = torch.cumsum(sorted_weights * (1.0 - sorted_labels), dim=0) + cum_tp = torch.cumsum(sorted_weights * sorted_labels, dim=0) + auc = torch.where( + cum_fp[-1] * cum_tp[-1] == 0, + 0.5, # 0.5 is the no-signal default value for auc. + torch.trapz(cum_tp, cum_fp) / cum_fp[-1] / cum_tp[-1], + ) + return auc def compute_auc( - n_tasks: int, predictions: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor + n_tasks: int, + predictions: List[torch.Tensor], + labels: List[torch.Tensor], + weights: List[torch.Tensor], + apply_bin: bool = False, ) -> torch.Tensor: - # The return values are sorted_predictions, sorted_index but only - # sorted_predictions is needed. - _, sorted_indices = torch.sort(predictions, descending=True, dim=-1) + """ + Computes AUC (Area Under the Curve) for binary classification. + + Args: + n_tasks (int): number of tasks. + predictions (List[torch.Tensor]): tensor of size (n_tasks, n_examples). + labels (List[torch.Tensor]): tensor of size (n_tasks, n_examples). + weights (List[torch.Tensor]): tensor of size (n_tasks, n_examples). + """ + preds_t, labels_t, weights_t = _concat_if_needed(predictions, labels, weights) aucs = [] - for sorted_indices_i, labels_i, weights_i in zip(sorted_indices, labels, weights): - sorted_labels = torch.index_select(labels_i, dim=0, index=sorted_indices_i) - sorted_weights = torch.index_select(weights_i, dim=0, index=sorted_indices_i) - cum_fp = torch.cumsum(sorted_weights * (1.0 - sorted_labels), dim=0) - cum_tp = torch.cumsum(sorted_weights * sorted_labels, dim=0) - auc = torch.where( - cum_fp[-1] * cum_tp[-1] == 0, - 0.5, # 0.5 is the no-signal default value for auc. - torch.trapz(cum_tp, cum_fp) / cum_fp[-1] / cum_tp[-1], - ) + for predictions_i, labels_i, weights_i in zip(preds_t, labels_t, weights_t): + auc = _compute_auc_helper(predictions_i, labels_i, weights_i, apply_bin) aucs.append(auc.view(1)) return torch.cat(aucs) -def _state_reduction(state: List[torch.Tensor]) -> List[torch.Tensor]: - return [torch.cat(state, dim=1)] +def compute_auc_per_group( + n_tasks: int, + predictions: List[torch.Tensor], + labels: List[torch.Tensor], + weights: List[torch.Tensor], + grouping_keys: torch.Tensor, +) -> torch.Tensor: + """ + Computes AUC (Area Under the Curve) for binary classification for groups of predictions/labels. 
+ Args: + n_tasks (int): number of tasks + predictions (List[torch.Tensor]): tensor of size (n_tasks, n_examples) + labels (List[torch.Tensor]: tensor of size (n_tasks, n_examples) + weights (List[torch.Tensor]): tensor of size (n_tasks, n_examples) + grouping_keys (torch.Tensor): tensor of size (n_examples,) + + Returns: + torch.Tensor: tensor of size (n_tasks,), average of AUCs per group. + """ + preds_t, labels_t, weights_t = _concat_if_needed(predictions, labels, weights) + aucs = [] + if grouping_keys.numel() != 0 and grouping_keys[0] == -1: + # we added padding as the first elements during init to avoid floating point exception in sync() + # removing the paddings to avoid numerical errors. + grouping_keys = grouping_keys[1:] + + # get unique group indices + group_indices = torch.unique(grouping_keys) + + for predictions_i, labels_i, weights_i in zip(preds_t, labels_t, weights_t): + # Loop over each group + auc_groups_sum = torch.tensor([0], dtype=torch.float32) + for group_idx in group_indices: + # get predictions, labels, and weights for this group + group_mask = grouping_keys == group_idx + grouped_predictions = predictions_i[group_mask] + grouped_labels = labels_i[group_mask] + grouped_weights = weights_i[group_mask] + + auc = _compute_auc_helper( + grouped_predictions, grouped_labels, grouped_weights + ) + auc_groups_sum = auc_groups_sum.to(auc.device) + auc_groups_sum += auc.view(1) + avg_auc = ( + auc_groups_sum / len(group_indices) + if len(group_indices) > 0 + else torch.tensor([0.5], dtype=torch.float32) + ) + aucs.append(avg_auc) + return torch.cat(aucs) + + +def _state_reduction(state: List[torch.Tensor], dim: int = 1) -> List[torch.Tensor]: + return [torch.cat(state, dim=dim)] + + +# pyre-ignore +_grouping_keys_state_reduction = partial(_state_reduction, dim=0) class AUCMetricComputation(RecMetricComputation): @@ -52,10 +169,29 @@ class AUCMetricComputation(RecMetricComputation): The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail. + Args: + grouped_auc (bool): If True, computes AUC per group and returns average AUC across all groups. + The `grouping_keys` is provided during state updates along with predictions, labels, weights. + This feature is currently not enabled for `fused_update_limit`. """ - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__( + self, + *args: Any, + grouped_auc: bool = False, + apply_bin: bool = False, + fused_update_limit: int = 0, + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) + if grouped_auc and fused_update_limit > 0: + raise RecMetricException( + "Grouped AUC and Fused Update Limit cannot be enabled together yet." 
+ ) + + self._grouped_auc: bool = grouped_auc + self._apply_bin: bool = apply_bin + self._num_samples: int = 0 self._add_state( PREDICTIONS, [], @@ -77,6 +213,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: dist_reduce_fx=_state_reduction, persistent=False, ) + if self._grouped_auc: + self._add_state( + GROUPING_KEYS, + [], + add_window_state=False, + dist_reduce_fx=_grouping_keys_state_reduction, + persistent=False, + ) self._init_states() # The states values are set to empty lists in __init__() and reset(), and then we @@ -90,16 +234,18 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def _init_states(self) -> None: if len(getattr(self, PREDICTIONS)) > 0: return - + self._num_samples = 0 getattr(self, PREDICTIONS).append( - torch.zeros((self._n_tasks, 1), dtype=torch.double, device=self.device) + torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device) ) getattr(self, LABELS).append( - torch.zeros((self._n_tasks, 1), dtype=torch.double, device=self.device) + torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device) ) getattr(self, WEIGHTS).append( - torch.zeros((self._n_tasks, 1), dtype=torch.double, device=self.device) + torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device) ) + if self._grouped_auc: + getattr(self, GROUPING_KEYS).append(torch.tensor([-1], device=self.device)) def update( self, @@ -107,47 +253,122 @@ def update( predictions: Optional[torch.Tensor], labels: torch.Tensor, weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], ) -> None: + """ + Args: + predictions (torch.Tensor): tensor of size (n_task, n_examples) + labels (torch.Tensor): tensor of size (n_task, n_examples) + weights (torch.Tensor): tensor of size (n_task, n_examples) + grouping_key (torch.Tensor): Optional tensor of size (1, n_examples) that specifies the groups of + predictions/labels per batch. If provided, the AUC metric also + computes AUC per group and returns the average AUC across all groups. + """ if predictions is None or weights is None: raise RecMetricException( "Inputs 'predictions' and 'weights' should not be None for AUCMetricComputation update" ) - predictions = predictions.double() - labels = labels.double() - weights = weights.double() - num_samples = getattr(self, PREDICTIONS)[0].size(-1) + predictions = predictions.float() + labels = labels.float() + weights = weights.float() batch_size = predictions.size(-1) - start_index = max(num_samples + batch_size - self._window_size, 0) + start_index = max(self._num_samples + batch_size - self._window_size, 0) + # Using `self.predictions =` will cause Pyre errors. 
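+        # Note on the list-based window handling below: each update appends its batch
+        # as a separate tensor; once the running sample count exceeds the window size,
+        # whole batches are popped from the front and the oldest surviving batch is
+        # sliced, so concatenation is deferred until compute/sync time.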
- getattr(self, PREDICTIONS)[0] = torch.cat( - [ - cast(torch.Tensor, getattr(self, PREDICTIONS)[0])[:, start_index:], - predictions, - ], - dim=-1, - ) - getattr(self, LABELS)[0] = torch.cat( - [cast(torch.Tensor, getattr(self, LABELS)[0])[:, start_index:], labels], - dim=-1, - ) - getattr(self, WEIGHTS)[0] = torch.cat( - [cast(torch.Tensor, getattr(self, WEIGHTS)[0])[:, start_index:], weights], - dim=-1, - ) + w_preds = getattr(self, PREDICTIONS) + w_labels = getattr(self, LABELS) + w_weights = getattr(self, WEIGHTS) + + # remove init states + if self._num_samples == 0: + for lst in [w_preds, w_labels, w_weights]: + lst.pop(0) + + w_preds.append(predictions) + w_labels.append(labels) + w_weights.append(weights) + + self._num_samples += batch_size + + while self._num_samples > self._window_size: + diff = self._num_samples - self._window_size + if diff > w_preds[0].size(-1): + self._num_samples -= w_preds[0].size(-1) + # Remove the first element from predictions, labels, and weights + for lst in [w_preds, w_labels, w_weights]: + lst.pop(0) + else: + # Update the first element of predictions, labels, and weights + # Off by one potentially - keeping legacy behaviour + for lst in [w_preds, w_labels, w_weights]: + lst[0] = lst[0][:, diff:] + # if empty tensor, remove it + if torch.numel(lst[0]) == 0: + lst.pop(0) + self._num_samples -= diff + + if self._grouped_auc: + if REQUIRED_INPUTS not in kwargs or ( + (grouping_keys := kwargs[REQUIRED_INPUTS].get(GROUPING_KEYS)) is None + ): + raise RecMetricException( + f"Input '{GROUPING_KEYS}' are required for AUCMetricComputation grouped update" + ) + getattr(self, GROUPING_KEYS)[0] = torch.cat( + [ + cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0])[start_index:], + grouping_keys.squeeze(), + ], + dim=0, + ) def _compute(self) -> List[MetricComputationReport]: - return [ + reports = [] + reports.append( MetricComputationReport( name=MetricName.AUC, metric_prefix=MetricPrefix.WINDOW, value=compute_auc( self._n_tasks, - cast(torch.Tensor, getattr(self, PREDICTIONS)[0]), - cast(torch.Tensor, getattr(self, LABELS)[0]), - cast(torch.Tensor, getattr(self, WEIGHTS)[0]), + cast(List[torch.Tensor], getattr(self, PREDICTIONS)), + cast(List[torch.Tensor], getattr(self, LABELS)), + cast(List[torch.Tensor], getattr(self, WEIGHTS)), + self._apply_bin, ), ) - ] + ) + + if self._grouped_auc: + reports.append( + MetricComputationReport( + name=MetricName.GROUPED_AUC, + metric_prefix=MetricPrefix.WINDOW, + value=compute_auc_per_group( + self._n_tasks, + cast(List[torch.Tensor], getattr(self, PREDICTIONS)), + cast(List[torch.Tensor], getattr(self, LABELS)), + cast(List[torch.Tensor], getattr(self, WEIGHTS)), + cast(torch.Tensor, getattr(self, GROUPING_KEYS))[0], + ), + ) + ) + return reports + + def _sync_dist( + self, + dist_sync_fn: Callable = gather_all_tensors, # pyre-ignore[24] + process_group: Optional[Any] = None, # pyre-ignore[2] + ) -> None: + """ + This function is overridden from torchmetric.Metric, since for AUC we want to concat the tensors + right before the allgather collective is called. 
It directly changes the attributes/states, which + is ok because end of function sets the attributes to reduced values + """ + for attr in self._reductions: # pragma: no cover + val = getattr(self, attr) + if isinstance(val, list) and len(val) > 1: + setattr(self, attr, [torch.cat(val, dim=-1)]) + super()._sync_dist(dist_sync_fn, process_group) def reset(self) -> None: super().reset() @@ -157,3 +378,38 @@ def reset(self) -> None: class AUCMetric(RecMetric): _namespace: MetricNamespace = MetricNamespace.AUC _computation_class: Type[RecMetricComputation] = AUCMetricComputation + + def __init__( + self, + world_size: int, + my_rank: int, + batch_size: int, + tasks: List[RecTaskInfo], + compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, + window_size: int = 100, + fused_update_limit: int = 0, + compute_on_all_ranks: bool = False, + should_validate_update: bool = False, + process_group: Optional[dist.ProcessGroup] = None, + **kwargs: Dict[str, Any], + ) -> None: + super().__init__( + world_size=world_size, + my_rank=my_rank, + batch_size=batch_size, + tasks=tasks, + compute_mode=compute_mode, + window_size=window_size, + fused_update_limit=fused_update_limit, + compute_on_all_ranks=compute_on_all_ranks, + should_validate_update=should_validate_update, + process_group=process_group, + **kwargs, + ) + if kwargs.get("grouped_auc"): + self._required_inputs.add(GROUPING_KEYS) + if self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION: + logging.warning( + f"compute_mode FUSED_TASKS_AND_STATES_COMPUTATION can't support {self._namespace} yet " + "because its states are not 1D Tensors. Only FUSED_TASKS_COMPUTATION will take effect." + ) diff --git a/torchrec/metrics/auprc.py b/torchrec/metrics/auprc.py new file mode 100644 index 000000000..ed99417d2 --- /dev/null +++ b/torchrec/metrics/auprc.py @@ -0,0 +1,371 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import logging +from functools import partial +from typing import Any, cast, Dict, List, Optional, Type + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +logger: logging.Logger = logging.getLogger(__name__) + +PREDICTIONS = "predictions" +LABELS = "labels" +WEIGHTS = "weights" +GROUPING_KEYS = "grouping_keys" +REQUIRED_INPUTS = "required_inputs" + + +def _riemann_integral(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Riemann integral approximates the area of each cell with a rectangle positioned at the egde. 
+ It is conventionally used rather than trapezoid approximation, which uses a rectangle positioned in the + center""" + return -torch.sum((x[1:] - x[:-1]) * y[:-1]) + + +def _compute_auprc_helper( + predictions: torch.Tensor, + labels: torch.Tensor, + weights: torch.Tensor, +) -> torch.Tensor: + sorted_indices = torch.argsort(predictions, descending=True, dim=-1) + + threshold = torch.index_select(predictions, dim=0, index=sorted_indices) + + sorted_labels = torch.index_select(labels, dim=0, index=sorted_indices) + + sorted_weights = torch.index_select(weights, dim=0, index=sorted_indices) + + mask = F.pad(threshold.diff(dim=0) != 0, [0, 1], value=1.0) + num_tp = torch.cumsum(sorted_weights * sorted_labels, dim=0)[mask] + num_fp = torch.cumsum(sorted_weights * (1.0 - sorted_labels), dim=0)[mask] + + precision = (num_tp / (num_tp + num_fp)).flip(0) + recall = (num_tp / num_tp[-1]).flip(0) + + # The last precision and recall values are 1.0 and 0.0 without a corresponding threshold. + # This ensures that the graph starts on the y-axis. + precision = torch.cat([precision, precision.new_ones(1)]) + recall = torch.cat([recall, recall.new_zeros(1)]) + + # If recalls are NaNs, set NaNs to 1.0s. + if torch.isnan(recall[0]): + recall = torch.nan_to_num(recall, 1.0) + + auprc = _riemann_integral(recall, precision) + return auprc + + +def compute_auprc( + n_tasks: int, + predictions: torch.Tensor, + labels: torch.Tensor, + weights: torch.Tensor, +) -> torch.Tensor: + """ + Computes AUPRC (Area Under the Curve) for binary classification. + + Args: + n_tasks (int): number of tasks. + predictions (torch.Tensor): tensor of size (n_tasks, n_examples). + labels (torch.Tensor): tensor of size (n_tasks, n_examples). + weights (torch.Tensor): tensor of size (n_tasks, n_examples). + """ + auprcs = [] + for predictions_i, labels_i, weights_i in zip(predictions, labels, weights): + auprc = _compute_auprc_helper(predictions_i, labels_i, weights_i) + auprcs.append(auprc.view(1)) + return torch.cat(auprcs) + + +def compute_auprc_per_group( + n_tasks: int, + predictions: torch.Tensor, + labels: torch.Tensor, + weights: torch.Tensor, + grouping_keys: torch.Tensor, +) -> torch.Tensor: + """ + Computes AUPRC (Area Under the Curve) for binary classification for groups of predictions/labels. + Args: + n_tasks (int): number of tasks + predictions (torch.Tensor): tensor of size (n_tasks, n_examples) + labels (torch.Tensor): tensor of size (n_tasks, n_examples) + weights (torch.Tensor): tensor of size (n_tasks, n_examples) + grouping_keys (torch.Tensor): tensor of size (n_examples,) + + Returns: + torch.Tensor: tensor of size (n_tasks,), average of AUPRCs per group. + """ + auprcs = [] + if grouping_keys.numel() != 0 and grouping_keys[0] == -1: + # we added padding as the first elements during init to avoid floating point exception in sync() + # removing the paddings to avoid numerical errors. 
+ grouping_keys = grouping_keys[1:] + predictions = predictions[:, 1:] + labels = labels[:, 1:] + weights = weights[:, 1:] + + # get unique group indices + group_indices = torch.unique(grouping_keys) + + for predictions_i, labels_i, weights_i in zip(predictions, labels, weights): + # Loop over each group + auprc_groups_sum = torch.tensor([0], dtype=torch.float32) + for group_idx in group_indices: + # get predictions, labels, and weights for this group + group_mask = grouping_keys == group_idx + grouped_predictions = predictions_i[group_mask] + grouped_labels = labels_i[group_mask] + grouped_weights = weights_i[group_mask] + + auprc = _compute_auprc_helper( + grouped_predictions, grouped_labels, grouped_weights + ) + auprc_groups_sum = auprc_groups_sum.to(auprc.device) + auprc_groups_sum += auprc.view(1) + avg_auprc = ( + auprc_groups_sum / len(group_indices) + if len(group_indices) > 0 + else torch.tensor([0.5], dtype=torch.float32) + ) + auprcs.append(avg_auprc) + return torch.cat(auprcs) + + +def _state_reduction(state: List[torch.Tensor], dim: int = 1) -> List[torch.Tensor]: + return [torch.cat(state, dim=dim)] + + +# pyre-ignore +_grouping_keys_state_reduction = partial(_state_reduction, dim=0) + + +class AUPRCMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for AUPRC, i.e. Area Under the Curve. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + Args: + grouped_auprc (bool): If True, computes AUPRC per group and returns average AUPRC across all groups. + The `grouping_keys` is provided during state updates along with predictions, labels, weights. + This feature is currently not enabled for `fused_update_limit`. + """ + + def __init__( + self, + *args: Any, + grouped_auprc: bool = False, + fused_update_limit: int = 0, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if grouped_auprc and fused_update_limit > 0: + raise RecMetricException( + "Grouped AUPRC and Fused Update Limit cannot be enabled together yet." 
+ ) + + self._grouped_auprc: bool = grouped_auprc + self._add_state( + PREDICTIONS, + [], + add_window_state=False, + dist_reduce_fx=_state_reduction, + persistent=False, + ) + self._add_state( + LABELS, + [], + add_window_state=False, + dist_reduce_fx=_state_reduction, + persistent=False, + ) + self._add_state( + WEIGHTS, + [], + add_window_state=False, + dist_reduce_fx=_state_reduction, + persistent=False, + ) + if self._grouped_auprc: + self._add_state( + GROUPING_KEYS, + [], + add_window_state=False, + dist_reduce_fx=_grouping_keys_state_reduction, + persistent=False, + ) + self._init_states() + + # The states values are set to empty lists in __init__() and reset(), and then we + # add a size (self._n_tasks, 1) tensor to each of the list as the initial values + # This is to bypass the limitation of state aggregation in TorchMetrics sync() when + # we try to checkpoint the states before update() + # The reason for using lists here is to avoid automatically stacking the tensors from + # all the trainers into one tensor in sync() + # The reason for using non-empty tensors as the first elements is to avoid the + # floating point exception thrown in sync() for aggregating empty tensors + def _init_states(self) -> None: + if len(getattr(self, PREDICTIONS)) > 0: + return + + getattr(self, PREDICTIONS).append( + torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device) + ) + getattr(self, LABELS).append( + torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device) + ) + getattr(self, WEIGHTS).append( + torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device) + ) + if self._grouped_auprc: + getattr(self, GROUPING_KEYS).append(torch.tensor([-1], device=self.device)) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + """ + Args: + predictions (torch.Tensor): tensor of size (n_task, n_examples) + labels (torch.Tensor): tensor of size (n_task, n_examples) + weights (torch.Tensor): tensor of size (n_task, n_examples) + grouping_key (torch.Tensor): Optional tensor of size (1, n_examples) that specifies the groups of + predictions/labels per batch. If provided, the PR AUC metric also + computes PR AUC per group and returns the average PR AUC across all groups. + """ + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for AUPRCMetricComputation update" + ) + predictions = predictions.float() + labels = labels.float() + weights = weights.float() + num_samples = getattr(self, PREDICTIONS)[0].size(-1) + batch_size = predictions.size(-1) + start_index = max(num_samples + batch_size - self._window_size, 0) + # Using `self.predictions =` will cause Pyre errors. 
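+        # The window is enforced here by slicing the accumulated state from
+        # start_index before concatenating the new batch, so at most window_size
+        # samples are kept along the last dimension.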
+ getattr(self, PREDICTIONS)[0] = torch.cat( + [ + cast(torch.Tensor, getattr(self, PREDICTIONS)[0])[:, start_index:], + predictions, + ], + dim=-1, + ) + getattr(self, LABELS)[0] = torch.cat( + [cast(torch.Tensor, getattr(self, LABELS)[0])[:, start_index:], labels], + dim=-1, + ) + getattr(self, WEIGHTS)[0] = torch.cat( + [cast(torch.Tensor, getattr(self, WEIGHTS)[0])[:, start_index:], weights], + dim=-1, + ) + if self._grouped_auprc: + if REQUIRED_INPUTS not in kwargs or ( + (grouping_keys := kwargs[REQUIRED_INPUTS].get(GROUPING_KEYS)) is None + ): + raise RecMetricException( + f"Input '{GROUPING_KEYS}' are required for AUPRCMetricComputation grouped update" + ) + getattr(self, GROUPING_KEYS)[0] = torch.cat( + [ + cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0])[start_index:], + grouping_keys.squeeze(), + ], + dim=0, + ) + + def _compute(self) -> List[MetricComputationReport]: + reports = [ + MetricComputationReport( + name=MetricName.AUPRC, + metric_prefix=MetricPrefix.WINDOW, + value=compute_auprc( + self._n_tasks, + cast(torch.Tensor, getattr(self, PREDICTIONS)[0]), + cast(torch.Tensor, getattr(self, LABELS)[0]), + cast(torch.Tensor, getattr(self, WEIGHTS)[0]), + ), + ) + ] + if self._grouped_auprc: + reports.append( + MetricComputationReport( + name=MetricName.GROUPED_AUPRC, + metric_prefix=MetricPrefix.WINDOW, + value=compute_auprc_per_group( + self._n_tasks, + cast(torch.Tensor, getattr(self, PREDICTIONS)[0]), + cast(torch.Tensor, getattr(self, LABELS)[0]), + cast(torch.Tensor, getattr(self, WEIGHTS)[0]), + cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0]), + ), + ) + ) + return reports + + def reset(self) -> None: + super().reset() + self._init_states() + + +class AUPRCMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.AUPRC + _computation_class: Type[RecMetricComputation] = AUPRCMetricComputation + + def __init__( + self, + world_size: int, + my_rank: int, + batch_size: int, + tasks: List[RecTaskInfo], + compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, + window_size: int = 100, + fused_update_limit: int = 0, + compute_on_all_ranks: bool = False, + should_validate_update: bool = False, + process_group: Optional[dist.ProcessGroup] = None, + **kwargs: Dict[str, Any], + ) -> None: + super().__init__( + world_size=world_size, + my_rank=my_rank, + batch_size=batch_size, + tasks=tasks, + compute_mode=compute_mode, + window_size=window_size, + fused_update_limit=fused_update_limit, + compute_on_all_ranks=compute_on_all_ranks, + should_validate_update=should_validate_update, + process_group=process_group, + **kwargs, + ) + if kwargs.get("grouped_auprc"): + self._required_inputs.add(GROUPING_KEYS) + if self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION: + logging.warning( + f"compute_mode FUSED_TASKS_AND_STATES_COMPUTATION can't support {self._namespace} yet " + "because its states are not 1D Tensors. Only FUSED_TASKS_COMPUTATION will take effect." + ) diff --git a/torchrec/metrics/cali_free_ne.py b/torchrec/metrics/cali_free_ne.py new file mode 100644 index 000000000..82983f611 --- /dev/null +++ b/torchrec/metrics/cali_free_ne.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) +from torchrec.pt2.utils import pt2_compile_callable + + +def compute_cross_entropy( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + eta: float, +) -> torch.Tensor: + predictions = predictions.double() + predictions.clamp_(min=eta, max=1 - eta) + cross_entropy = -weights * labels * torch.log2(predictions) - weights * ( + 1.0 - labels + ) * torch.log2(1.0 - predictions) + return cross_entropy + + +def _compute_cross_entropy_norm( + mean_label: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, +) -> torch.Tensor: + mean_label = mean_label.double() + mean_label.clamp_(min=eta, max=1 - eta) + return -pos_labels * torch.log2(mean_label) - neg_labels * torch.log2( + 1.0 - mean_label + ) + + +@torch.fx.wrap +def _compute_ne( + ce_sum: torch.Tensor, + weighted_num_samples: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, +) -> torch.Tensor: + # Goes into this block if all elements in weighted_num_samples > 0 + weighted_num_samples = weighted_num_samples.double().clamp(min=eta) + mean_label = pos_labels / weighted_num_samples + ce_norm = _compute_cross_entropy_norm(mean_label, pos_labels, neg_labels, eta) + return ce_sum / ce_norm + + +def compute_cali_free_ne( + ce_sum: torch.Tensor, + weighted_num_samples: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + weighted_sum_predictions: torch.Tensor, + eta: float, + allow_missing_label_with_zero_weight: bool = False, +) -> torch.Tensor: + if allow_missing_label_with_zero_weight and not weighted_num_samples.all(): + # If nan were to occur, return a dummy value instead of nan if + # allow_missing_label_with_zero_weight is True + return torch.tensor([eta]) + raw_ne = _compute_ne( + ce_sum=ce_sum, + weighted_num_samples=weighted_num_samples, + pos_labels=pos_labels, + neg_labels=neg_labels, + eta=eta, + ) + return raw_ne / ( + -pos_labels * torch.log2(weighted_sum_predictions / weighted_num_samples) + - (weighted_num_samples - pos_labels) + * torch.log2(1 - (weighted_sum_predictions / weighted_num_samples)) + ) + + +def get_cali_free_ne_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + eta: float, +) -> Dict[str, torch.Tensor]: + cross_entropy = compute_cross_entropy( + labels, + predictions, + weights, + eta, + ) + return { + "cross_entropy_sum": torch.sum(cross_entropy, dim=-1), + "weighted_num_samples": torch.sum(weights, dim=-1), + "pos_labels": torch.sum(weights * labels, dim=-1), + "neg_labels": torch.sum(weights * (1.0 - labels), dim=-1), + "weighted_sum_predictions": torch.sum(weights * predictions, dim=-1), + } + + +class CaliFreeNEMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for CaliFree NE, i.e. Normalized Entropy. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + + Args: + allow_missing_label_with_zero_weight (bool): allow missing label to have weight 0, instead of throwing exception. 
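+    Example (illustrative sketch with toy tensors; the values are made up)::
+
+        labels = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
+        predictions = torch.tensor([[0.7, 0.2, 0.6, 0.4]])
+        weights = torch.ones_like(predictions)
+        states = get_cali_free_ne_states(labels, predictions, weights, eta=1e-12)
+        cali_free_ne = compute_cali_free_ne(
+            states["cross_entropy_sum"],
+            states["weighted_num_samples"],
+            states["pos_labels"],
+            states["neg_labels"],
+            states["weighted_sum_predictions"],
+            eta=1e-12,
+        )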
+ """ + + def __init__( + self, + *args: Any, + allow_missing_label_with_zero_weight: bool = False, + **kwargs: Any, + ) -> None: + self._allow_missing_label_with_zero_weight: bool = ( + allow_missing_label_with_zero_weight + ) + super().__init__(*args, **kwargs) + self._add_state( + "cross_entropy_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_num_samples", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "pos_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "neg_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_sum_predictions", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self.eta = 1e-12 + + @pt2_compile_callable + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for CaliFreeNEMetricComputation update" + ) + states = get_cali_free_ne_states(labels, predictions, weights, self.eta) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + reports = [ + MetricComputationReport( + name=MetricName.CALI_FREE_NE, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_cali_free_ne( + cast(torch.Tensor, self.cross_entropy_sum), + cast(torch.Tensor, self.weighted_num_samples), + cast(torch.Tensor, self.pos_labels), + cast(torch.Tensor, self.neg_labels), + cast(torch.Tensor, self.weighted_sum_predictions), + self.eta, + self._allow_missing_label_with_zero_weight, + ), + ), + MetricComputationReport( + name=MetricName.CALI_FREE_NE, + metric_prefix=MetricPrefix.WINDOW, + value=compute_cali_free_ne( + self.get_window_state("cross_entropy_sum"), + self.get_window_state("weighted_num_samples"), + self.get_window_state("pos_labels"), + self.get_window_state("neg_labels"), + self.get_window_state("weighted_sum_predictions"), + self.eta, + self._allow_missing_label_with_zero_weight, + ), + ), + ] + return reports + + +class CaliFreeNEMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.CALI_FREE_NE + _computation_class: Type[RecMetricComputation] = CaliFreeNEMetricComputation diff --git a/torchrec/metrics/calibration.py b/torchrec/metrics/calibration.py index 7ba4ea201..3ef4b861a 100644 --- a/torchrec/metrics/calibration.py +++ b/torchrec/metrics/calibration.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + from typing import Any, cast, Dict, List, Optional, Type import torch @@ -16,6 +18,7 @@ RecMetricException, ) + CALIBRATION_NUM = "calibration_num" CALIBRATION_DENOM = "calibration_denom" @@ -69,6 +72,7 @@ def update( predictions: Optional[torch.Tensor], labels: torch.Tensor, weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], ) -> None: if predictions is None or weights is None: raise RecMetricException( diff --git a/torchrec/metrics/calibration_with_recalibration.py b/torchrec/metrics/calibration_with_recalibration.py new file mode 100644 index 000000000..fc7c594b9 --- /dev/null +++ b/torchrec/metrics/calibration_with_recalibration.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, Dict, Optional, Type + +import torch +from torchrec.metrics.calibration import ( + CalibrationMetricComputation, + get_calibration_states, +) +from torchrec.metrics.metrics_namespace import MetricNamespace +from torchrec.metrics.rec_metric import ( + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +CALIBRATION_NUM = "calibration_num" +CALIBRATION_DENOM = "calibration_denom" + + +class RecalibratedCalibrationMetricComputation(CalibrationMetricComputation): + r""" + This class implements the RecMetricComputation for Calibration that is required to correctly estimate eval NE if negative downsampling was used during training. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + """ + + def __init__( + self, *args: Any, recalibration_coefficient: float = 1.0, **kwargs: Any + ) -> None: + self._recalibration_coefficient: float = recalibration_coefficient + super().__init__(*args, **kwargs) + self._add_state( + CALIBRATION_NUM, + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + CALIBRATION_DENOM, + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + + def _recalibrate( + self, + predictions: torch.Tensor, + calibration_coef: Optional[torch.Tensor], + ) -> torch.Tensor: + if calibration_coef is not None: + predictions = predictions / ( + predictions + (1.0 - predictions) / calibration_coef + ) + return predictions + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for CalibrationMetricComputation update" + ) + predictions = self._recalibrate( + predictions, self._recalibration_coefficient * torch.ones_like(predictions) + ) + num_samples = predictions.shape[-1] + for state_name, state_value in get_calibration_states( + labels, predictions, weights + ).items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + +class RecalibratedCalibrationMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.RECALIBRATED_CALIBRATION + _computation_class: Type[RecMetricComputation] = ( + RecalibratedCalibrationMetricComputation + ) diff --git a/torchrec/metrics/ctr.py 
b/torchrec/metrics/ctr.py index 90d463d86..bc1088899 100644 --- a/torchrec/metrics/ctr.py +++ b/torchrec/metrics/ctr.py @@ -5,9 +5,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Any, cast, Dict, List, Optional, Type import torch + from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix from torchrec.metrics.rec_metric import ( MetricComputationReport, @@ -16,6 +19,7 @@ RecMetricException, ) + CTR_NUM = "ctr_num" CTR_DENOM = "ctr_denom" @@ -65,6 +69,7 @@ def update( predictions: Optional[torch.Tensor], labels: torch.Tensor, weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], ) -> None: if predictions is None or weights is None: raise RecMetricException( diff --git a/torchrec/metrics/gauc.py b/torchrec/metrics/gauc.py new file mode 100644 index 000000000..a42a36e1c --- /dev/null +++ b/torchrec/metrics/gauc.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch + +from torch.autograd.profiler import record_function +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +def compute_gauc_3d( + predictions: torch.Tensor, + labels: torch.Tensor, + weights: torch.Tensor, +) -> Dict[str, torch.Tensor]: + """Both predictions and labels are 3-d tensors in shape [n_task, n_group, n_sample].""" + + n_task, n_group, n_sample = predictions.size() + max_len = max(n_task, n_group, n_sample) + # Pre-register an arange to avoid multiple cpu=>gpu assignment. + pre_arange = torch.arange(max_len, device=predictions.device) + + with record_function("## gauc_argsort ##"): + sorted_indices = torch.argsort(predictions, dim=-1) + task_indices = ( + pre_arange[:n_task][:, None, None] + .expand(n_task, n_group, n_sample) + .contiguous() + .view(-1) + ) + group_indices = ( + pre_arange[:n_group][None, :, None] + .expand(n_task, n_group, n_sample) + .contiguous() + .view(-1) + ) + sample_indices = sorted_indices.contiguous().view(-1) + sorted_labels = labels[task_indices, group_indices, sample_indices].view( + n_task, n_group, n_sample + ) + sorted_weights = weights[task_indices, group_indices, sample_indices].view( + n_task, n_group, n_sample + ) + + with record_function("## gauc_calculation ##"): + pos_mask = sorted_labels + neg_mask = 1 - sorted_labels + + # cumulative negative *weight* that appear **before** each position + cum_neg_weight = torch.cumsum(sorted_weights * neg_mask, dim=-1) + + # contribution of every positive example: w_pos * (sum w_neg ranked lower) + contrib = pos_mask * sorted_weights * cum_neg_weight + numerator = contrib.sum(-1) # [n_task, n_group] + + w_pos = (pos_mask * sorted_weights).sum(-1) # [n_task, n_group] + w_neg = (neg_mask * sorted_weights).sum(-1) # [n_task, n_group] + denominator = w_pos * w_neg + + auc = numerator / (denominator + 1e-10) + + # Skip identical prediction sessions. + identical_prediction_mask = ~( + torch.all( + torch.logical_or( + predictions == predictions[:, :, 0:1], + predictions == 0, # avoid padding zeros. 
+ ), + dim=-1, + ) + ) + # Skip identical label(all 0s/1s) sessions. + identical_label_mask = (w_pos > 0) & (w_neg > 0) + auc_mask = identical_label_mask * identical_prediction_mask + auc *= auc_mask + num_effective_samples = auc_mask.sum(-1) # [n_task] + auc = auc.sum(-1) # [n_task] + return {"auc_sum": auc, "num_samples": num_effective_samples} + + +def to_3d( + tensor_2d: torch.Tensor, seq_lengths: torch.Tensor, max_length: int +) -> torch.Tensor: + offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(seq_lengths) + return torch.ops.fbgemm.jagged_2d_to_dense(tensor_2d, offsets, max_length) + + +@torch.compiler.disable +def get_auc_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + num_candidates: torch.Tensor, +) -> Dict[str, torch.Tensor]: + + # predictions, labels: [n_task, n_sample] + max_length = int(num_candidates.max().item()) + predictions_perm = predictions.permute(1, 0) + labels_perm = labels.permute(1, 0) + weights_perm = weights.permute(1, 0) + predictions_3d = to_3d(predictions_perm, num_candidates, max_length).permute( + 2, 0, 1 + ) + labels_3d = to_3d(labels_perm, num_candidates, max_length).permute(2, 0, 1) + weights_3d = to_3d(weights_perm, num_candidates, max_length).permute(2, 0, 1) + + return compute_gauc_3d( + predictions_3d, + labels_3d, + weights_3d, + ) + + +@torch.fx.wrap +def compute_window_auc( + auc: torch.Tensor, + num_samples: torch.Tensor, +) -> Dict[str, torch.Tensor]: + # [n_task] + return { + "gauc": (auc + 1e-9) / (num_samples + 2e-9), + "num_samples": num_samples, + } + + +class GAUCMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for GAUC, i.e. Session AUC. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. 
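+    Example (illustrative sketch; the flat tensors below describe two sessions with
+    3 and 2 candidates for a single task, and all values are toy numbers)::
+
+        num_candidates = torch.tensor([3, 2])
+        predictions = torch.tensor([[0.9, 0.2, 0.5, 0.4, 0.7]])
+        labels = torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0]])
+        weights = torch.ones_like(predictions)
+        states = get_auc_states(labels, predictions, weights, num_candidates)
+        # states["auc_sum"] accumulates the per-session (grouped) AUCs and
+        # states["num_samples"] counts the sessions that were not skipped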
+ """ + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + self._add_state( + "auc_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "num_samples", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + num_candidates: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for GAUCMetricComputation update" + ) + + states = get_auc_states(labels, predictions, weights, num_candidates) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + reports = [ + MetricComputationReport( + name=MetricName.GAUC, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_window_auc( + cast(torch.Tensor, self.auc_sum), + cast(torch.Tensor, self.num_samples), + )["gauc"], + ), + MetricComputationReport( + name=MetricName.GAUC, + metric_prefix=MetricPrefix.WINDOW, + value=compute_window_auc( + self.get_window_state("auc_sum"), + self.get_window_state("num_samples"), + )["gauc"], + ), + MetricComputationReport( + name=MetricName.GAUC_NUM_SAMPLES, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_window_auc( + cast(torch.Tensor, self.auc_sum), + cast(torch.Tensor, self.num_samples), + )["num_samples"], + ), + MetricComputationReport( + name=MetricName.GAUC_NUM_SAMPLES, + metric_prefix=MetricPrefix.WINDOW, + value=compute_window_auc( + self.get_window_state("auc_sum"), + self.get_window_state("num_samples"), + )["num_samples"], + ), + ] + + return reports + + +class GAUCMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.GAUC + _computation_class: Type[RecMetricComputation] = GAUCMetricComputation diff --git a/torchrec/metrics/hindsight_target_pr.py b/torchrec/metrics/hindsight_target_pr.py new file mode 100644 index 000000000..800052ecf --- /dev/null +++ b/torchrec/metrics/hindsight_target_pr.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +TARGET_PRECISION = "target_precision" +THRESHOLD_GRANULARITY = 1000 + + +def compute_precision( + num_true_positives: torch.Tensor, num_false_positives: torch.Tensor +) -> torch.Tensor: + return torch.where( + num_true_positives + num_false_positives == 0.0, + 0.0, + num_true_positives / (num_true_positives + num_false_positives).double(), + ) + + +def compute_recall( + num_true_positives: torch.Tensor, num_false_negitives: torch.Tensor +) -> torch.Tensor: + return torch.where( + num_true_positives + num_false_negitives == 0.0, + 0.0, + num_true_positives / (num_true_positives + num_false_negitives), + ) + + +def compute_threshold_idx( + num_true_positives: torch.Tensor, + num_false_positives: torch.Tensor, + target_precision: float, +) -> int: + for i in range(THRESHOLD_GRANULARITY): + if ( + compute_precision(num_true_positives[i], num_false_positives[i]) + >= target_precision + ): + return i + + return THRESHOLD_GRANULARITY - 1 + + +def compute_true_pos_sum( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, +) -> torch.Tensor: + predictions = predictions.double() + tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double) + thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY) + for i, threshold in enumerate(thresholds): + tp_sum[i] = torch.sum(weights * ((predictions >= threshold) * labels), -1) + return tp_sum + + +def compute_false_pos_sum( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, +) -> torch.Tensor: + predictions = predictions.double() + fp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double) + thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY) + for i, threshold in enumerate(thresholds): + fp_sum[i] = torch.sum(weights * ((predictions >= threshold) * (1 - labels)), -1) + return fp_sum + + +def compute_false_neg_sum( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, +) -> torch.Tensor: + predictions = predictions.double() + fn_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double) + thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY) + for i, threshold in enumerate(thresholds): + fn_sum[i] = torch.sum(weights * ((predictions <= threshold) * labels), -1) + return fn_sum + + +def get_pr_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: Optional[torch.Tensor], +) -> Dict[str, torch.Tensor]: + if weights is None: + weights = torch.ones_like(predictions) + return { + "true_pos_sum": compute_true_pos_sum(labels, predictions, weights), + "false_pos_sum": compute_false_pos_sum(labels, predictions, weights), + "false_neg_sum": compute_false_neg_sum(labels, predictions, weights), + } + + +class HindsightTargetPRMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for Hingsight Target PR. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + + Args: + target_precision (float): If provided, computes the minimum threshold to achieve the target precision. 
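    Illustrative note (editor's sketch): the candidate thresholds come from
    ``torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY)``, so index ``i`` corresponds to a
    threshold of ``i / (THRESHOLD_GRANULARITY - 1)``, and ``compute_threshold_idx`` returns
    the smallest index whose precision reaches ``target_precision``. For example::

        >>> compute_precision(torch.tensor(8.0), torch.tensor(2.0))
        tensor(0.8000, dtype=torch.float64)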
+ """ + + def __init__( + self, *args: Any, target_precision: float = 0.5, **kwargs: Any + ) -> None: + super().__init__(*args, **kwargs) + self._add_state( + "true_pos_sum", + torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "false_pos_sum", + torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "false_neg_sum", + torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._target_precision: float = target_precision + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None: + raise RecMetricException( + "Inputs 'predictions' should not be None for HindsightTargetPRMetricComputation update" + ) + states = get_pr_states(labels, predictions, weights) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + true_pos_sum = cast(torch.Tensor, self.true_pos_sum) + false_pos_sum = cast(torch.Tensor, self.false_pos_sum) + false_neg_sum = cast(torch.Tensor, self.false_neg_sum) + threshold_idx = compute_threshold_idx( + true_pos_sum, + false_pos_sum, + self._target_precision, + ) + window_threshold_idx = compute_threshold_idx( + self.get_window_state("true_pos_sum"), + self.get_window_state("false_pos_sum"), + self._target_precision, + ) + reports = [ + MetricComputationReport( + name=MetricName.HINDSIGHT_TARGET_PR, + metric_prefix=MetricPrefix.LIFETIME, + value=torch.Tensor(threshold_idx), + ), + MetricComputationReport( + name=MetricName.HINDSIGHT_TARGET_PR, + metric_prefix=MetricPrefix.WINDOW, + value=torch.Tensor(window_threshold_idx), + ), + MetricComputationReport( + name=MetricName.HINDSIGHT_TARGET_PRECISION, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_precision( + true_pos_sum[threshold_idx], + false_pos_sum[threshold_idx], + ), + ), + MetricComputationReport( + name=MetricName.HINDSIGHT_TARGET_PRECISION, + metric_prefix=MetricPrefix.WINDOW, + value=compute_precision( + self.get_window_state("true_pos_sum")[window_threshold_idx], + self.get_window_state("false_pos_sum")[window_threshold_idx], + ), + ), + MetricComputationReport( + name=MetricName.HINDSIGHT_TARGET_RECALL, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_recall( + true_pos_sum[threshold_idx], + false_neg_sum[threshold_idx], + ), + ), + MetricComputationReport( + name=MetricName.HINDSIGHT_TARGET_RECALL, + metric_prefix=MetricPrefix.WINDOW, + value=compute_recall( + self.get_window_state("true_pos_sum")[window_threshold_idx], + self.get_window_state("false_neg_sum")[window_threshold_idx], + ), + ), + ] + return reports + + +class HindsightTargetPRMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.HINDSIGHT_TARGET_PR + _computation_class: Type[RecMetricComputation] = HindsightTargetPRMetricComputation diff --git a/torchrec/metrics/mae.py b/torchrec/metrics/mae.py new file mode 100644 index 000000000..bdb4562ab --- /dev/null +++ b/torchrec/metrics/mae.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +ERROR_SUM = "error_sum" +WEIGHTED_NUM_SAMPES = "weighted_num_samples" + + +def compute_mae( + error_sum: torch.Tensor, weighted_num_samples: torch.Tensor +) -> torch.Tensor: + return torch.where( + weighted_num_samples == 0.0, 0.0, error_sum / weighted_num_samples + ).double() + + +def compute_error_sum( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor +) -> torch.Tensor: + predictions = predictions.double() + return torch.sum(weights * torch.abs(labels - predictions), dim=-1) + + +def get_mae_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor +) -> Dict[str, torch.Tensor]: + return { + "error_sum": compute_error_sum(labels, predictions, weights), + "weighted_num_samples": torch.sum(weights, dim=-1), + } + + +class MAEMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for MAE, i.e. Mean Absolute Error. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._add_state( + "error_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_num_samples", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for MAEMetricComputation update" + ) + states = get_mae_states(labels, predictions, weights) + num_samples = predictions.shape[-1] + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + return [ + MetricComputationReport( + name=MetricName.MAE, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_mae( + cast(torch.Tensor, self.error_sum), + cast(torch.Tensor, self.weighted_num_samples), + ), + ), + MetricComputationReport( + name=MetricName.MAE, + metric_prefix=MetricPrefix.WINDOW, + value=compute_mae( + self.get_window_state(ERROR_SUM), + self.get_window_state(WEIGHTED_NUM_SAMPES), + ), + ), + ] + + +class MAEMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.MAE + _computation_class: Type[RecMetricComputation] = MAEMetricComputation diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py index 70bbd4a76..8ca849152 100644 --- a/torchrec/metrics/metric_module.py +++ b/torchrec/metrics/metric_module.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + #!/usr/bin/env python3 import abc @@ -15,15 +17,26 @@ import torch import torch.distributed as dist import torch.nn as nn +from torch.profiler import record_function +from torchrec.metrics.accuracy import AccuracyMetric from torchrec.metrics.auc import AUCMetric +from torchrec.metrics.auprc import AUPRCMetric +from torchrec.metrics.cali_free_ne import CaliFreeNEMetric from torchrec.metrics.calibration import CalibrationMetric +from torchrec.metrics.calibration_with_recalibration import ( + RecalibratedCalibrationMetric, +) from torchrec.metrics.ctr import CTRMetric +from torchrec.metrics.hindsight_target_pr import HindsightTargetPRMetric +from torchrec.metrics.mae import MAEMetric from torchrec.metrics.metrics_config import ( + BatchSizeStage, MetricsConfig, RecMetricEnum, RecMetricEnumBase, RecTaskInfo, StateMetricEnum, + validate_batch_size_stages, ) from torchrec.metrics.metrics_namespace import ( compose_customized_metric_key, @@ -32,19 +45,63 @@ ) from torchrec.metrics.model_utils import parse_task_model_outputs from torchrec.metrics.mse import MSEMetric +from torchrec.metrics.multiclass_recall import MulticlassRecallMetric +from torchrec.metrics.ndcg import NDCGMetric from torchrec.metrics.ne import NEMetric +from torchrec.metrics.ne_positive import NEPositiveMetric +from torchrec.metrics.ne_with_recalibration import RecalibratedNEMetric +from torchrec.metrics.output import OutputMetric +from torchrec.metrics.precision import PrecisionMetric +from torchrec.metrics.precision_session import PrecisionSessionMetric +from torchrec.metrics.rauc import RAUCMetric from torchrec.metrics.rec_metric import RecMetric, RecMetricList +from torchrec.metrics.recall import RecallMetric +from torchrec.metrics.recall_session import RecallSessionMetric +from torchrec.metrics.scalar import ScalarMetric +from torchrec.metrics.segmented_ne import SegmentedNEMetric +from torchrec.metrics.serving_calibration import ServingCalibrationMetric +from torchrec.metrics.serving_ne import ServingNEMetric +from torchrec.metrics.tensor_weighted_avg import TensorWeightedAvgMetric from torchrec.metrics.throughput import ThroughputMetric +from torchrec.metrics.tower_qps import TowerQPSMetric +from torchrec.metrics.unweighted_ne import UnweightedNEMetric +from torchrec.metrics.weighted_avg import WeightedAvgMetric +from torchrec.metrics.xauc import XAUCMetric logger: logging.Logger = logging.getLogger(__name__) REC_METRICS_MAPPING: Dict[RecMetricEnumBase, Type[RecMetric]] = { RecMetricEnum.NE: NEMetric, + RecMetricEnum.NE_POSITIVE: NEPositiveMetric, + RecMetricEnum.SEGMENTED_NE: SegmentedNEMetric, + RecMetricEnum.RECALIBRATED_NE: RecalibratedNEMetric, + RecMetricEnum.RECALIBRATED_CALIBRATION: RecalibratedCalibrationMetric, RecMetricEnum.CTR: CTRMetric, RecMetricEnum.CALIBRATION: CalibrationMetric, RecMetricEnum.AUC: AUCMetric, + RecMetricEnum.AUPRC: AUPRCMetric, + RecMetricEnum.RAUC: RAUCMetric, RecMetricEnum.MSE: MSEMetric, + RecMetricEnum.MAE: MAEMetric, + RecMetricEnum.MULTICLASS_RECALL: MulticlassRecallMetric, + RecMetricEnum.WEIGHTED_AVG: WeightedAvgMetric, + RecMetricEnum.TOWER_QPS: TowerQPSMetric, + RecMetricEnum.RECALL_SESSION_LEVEL: RecallSessionMetric, + RecMetricEnum.PRECISION_SESSION_LEVEL: PrecisionSessionMetric, + RecMetricEnum.ACCURACY: AccuracyMetric, + RecMetricEnum.NDCG: NDCGMetric, + RecMetricEnum.XAUC: XAUCMetric, + RecMetricEnum.SCALAR: ScalarMetric, + RecMetricEnum.PRECISION: PrecisionMetric, + RecMetricEnum.RECALL: RecallMetric, + RecMetricEnum.SERVING_NE: ServingNEMetric, + 
RecMetricEnum.SERVING_CALIBRATION: ServingCalibrationMetric, + RecMetricEnum.OUTPUT: OutputMetric, + RecMetricEnum.TENSOR_WEIGHTED_AVG: TensorWeightedAvgMetric, + RecMetricEnum.CALI_FREE_NE: CaliFreeNEMetric, + RecMetricEnum.UNWEIGHTED_NE: UnweightedNEMetric, + RecMetricEnum.HINDSIGHT_TARGET_PR: HindsightTargetPRMetric, } @@ -52,9 +109,6 @@ MODEL_METRIC_LABEL: str = "model" -MEMORY_AVG_WARNING_PERCENTAGE = 20 -MEMORY_AVG_WARNING_WARMUP = 100 - MetricValue = Union[torch.Tensor, float] @@ -89,7 +143,7 @@ class RecMetricModule(nn.Module): throughput_metric (Optional[ThroughputMetric]): the ThroughputMetric. state_metrics (Optional[Dict[str, StateMetric]]): the dict of StateMetrics. compute_interval_steps (int): the intervals between two compute calls in the unit of batch number - memory_usage_limit_mb (float): the memory usage limit for OOM check + memory_usage_limit_mb (float): [Unused] the memory usage limit for OOM check Call Args: Not supported. @@ -120,8 +174,6 @@ class RecMetricModule(nn.Module): rec_metrics: RecMetricList throughput_metric: Optional[ThroughputMetric] state_metrics: Dict[str, StateMetric] - memory_usage_limit_mb: float - memory_usage_mb_avg: float oom_count: int compute_count: int last_compute_time: float @@ -138,6 +190,7 @@ def __init__( compute_interval_steps: int = 100, min_compute_interval: float = 0.0, max_compute_interval: float = float("inf"), + # Unused, but needed for backwards compatibility. TODO: Remove from callsites memory_usage_limit_mb: float = 512, ) -> None: super().__init__() @@ -148,8 +201,6 @@ def __init__( self.trained_batches: int = 0 self.batch_size = batch_size self.world_size = world_size - self.memory_usage_limit_mb = memory_usage_limit_mb - self.memory_usage_mb_avg = 0.0 self.oom_count = 0 self.compute_count = 0 @@ -173,61 +224,39 @@ def __init__( ) self.last_compute_time = -1.0 - def get_memory_usage(self) -> int: - r"""Total memory of unique RecMetric tensors in bytes""" - total = {} - for metric in self.rec_metrics.rec_metrics: - total.update(metric.get_memory_usage()) - return sum(total.values()) - - def check_memory_usage(self, compute_count: int) -> None: - memory_usage_mb = self.get_memory_usage() / (10**6) - if memory_usage_mb > self.memory_usage_limit_mb: - self.oom_count += 1 - logger.warning( - f"MetricModule is using {memory_usage_mb}MB. " - f"This is larger than the limit{self.memory_usage_limit_mb}MB. " - f"This is the f{self.oom_count}th OOM." - ) - - if ( - compute_count > MEMORY_AVG_WARNING_WARMUP - and memory_usage_mb - > self.memory_usage_mb_avg * ((100 + MEMORY_AVG_WARNING_PERCENTAGE) / 100) - ): - logger.warning( - f"MetricsModule is using more than {MEMORY_AVG_WARNING_PERCENTAGE}% of " - f"the average memory usage. Current usage: {memory_usage_mb}MB." - ) - - self.memory_usage_mb_avg = ( - self.memory_usage_mb_avg * (compute_count - 1) + memory_usage_mb - ) / compute_count - - def _update_rec_metrics(self, model_out: Dict[str, torch.Tensor]) -> None: + def _update_rec_metrics( + self, model_out: Dict[str, torch.Tensor], **kwargs: Any + ) -> None: r"""the internal update function to parse the model output. Override this function if the implementation cannot support the model output format. 
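    Illustrative call (editor's sketch; the key names are the ``RecTaskInfo`` defaults and
    are assumptions for this example): metrics that declare required inputs (e.g. the
    session tensor for NDCG) now receive them through the ``required_inputs`` kwarg
    assembled here::

        metric_module.update(
            {
                "label": labels,            # RecTaskInfo.label_name
                "prediction": predictions,  # RecTaskInfo.prediction_name
                "weight": weights,          # RecTaskInfo.weight_name
                "session_id": session_ids,  # forwarded as a required input
            }
        )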
""" if self.rec_metrics and self.rec_tasks: - labels, predictions, weights = parse_task_model_outputs( - self.rec_tasks, model_out + labels, predictions, weights, required_inputs = parse_task_model_outputs( + self.rec_tasks, model_out, self.get_required_inputs() ) + if required_inputs: + kwargs["required_inputs"] = required_inputs + self.rec_metrics.update( - predictions=predictions, labels=labels, weights=weights + predictions=predictions, + labels=labels, + weights=weights, + **kwargs, ) - def update(self, model_out: Dict[str, torch.Tensor]) -> None: + def update(self, model_out: Dict[str, torch.Tensor], **kwargs: Any) -> None: r"""update() is called per batch, usually right after forward() to update the local states of metrics based on the model_output. Throughput.update() is also called due to the implementation sliding window throughput. """ - self._update_rec_metrics(model_out) - if self.throughput_metric: - self.throughput_metric.update() - self.trained_batches += 1 + with record_function("## RecMetricModule:update ##"): + self._update_rec_metrics(model_out, **kwargs) + if self.throughput_metric: + self.throughput_metric.update() + self.trained_batches += 1 def _adjust_compute_interval(self) -> None: """ @@ -287,22 +316,21 @@ def compute(self) -> Dict[str, MetricValue]: right before logging the metrics results to the data sink. """ self.compute_count += 1 - self.check_memory_usage(self.compute_count) - ret: Dict[str, MetricValue] = {} - if self.rec_metrics: - self._adjust_compute_interval() - ret.update(self.rec_metrics.compute()) - if self.throughput_metric: - ret.update(self.throughput_metric.compute()) - if self.state_metrics: - for namespace, component in self.state_metrics.items(): - ret.update( - { - f"{compose_customized_metric_key(namespace, metric_name)}": metric_value - for metric_name, metric_value in component.get_metrics().items() - } - ) + with record_function("## RecMetricModule:compute ##"): + if self.rec_metrics: + self._adjust_compute_interval() + ret.update(self.rec_metrics.compute()) + if self.throughput_metric: + ret.update(self.throughput_metric.compute()) + if self.state_metrics: + for namespace, component in self.state_metrics.items(): + ret.update( + { + f"{compose_customized_metric_key(namespace, metric_name)}": metric_value + for metric_name, metric_value in component.get_metrics().items() + } + ) return ret def local_compute(self) -> Dict[str, MetricValue]: @@ -323,6 +351,9 @@ def unsync(self) -> None: def reset(self) -> None: self.rec_metrics.reset() + def get_required_inputs(self) -> Optional[List[str]]: + return self.rec_metrics.get_required_inputs() + def _generate_rec_metrics( metrics_config: MetricsConfig, @@ -338,6 +369,8 @@ def _generate_rec_metrics( if metric_def and metric_def.arguments is not None: kwargs = metric_def.arguments + kwargs["enable_pt2_compile"] = metrics_config.enable_pt2_compile + rec_tasks: List[RecTaskInfo] = [] if metric_def.rec_tasks and metric_def.rec_task_indices: raise ValueError( @@ -386,9 +419,9 @@ def _generate_state_metrics( ) -> Dict[str, StateMetric]: state_metrics: Dict[str, StateMetric] = {} for metric_enum in metrics_config.state_metrics: - metric_namespace: Optional[ - MetricNamespace - ] = STATE_METRICS_NAMESPACE_MAPPING.get(metric_enum, None) + metric_namespace: Optional[MetricNamespace] = ( + STATE_METRICS_NAMESPACE_MAPPING.get(metric_enum, None) + ) if metric_namespace is None: raise ValueError(f"Unknown StateMetrics {metric_enum}") full_namespace = compose_metric_namespace( @@ -407,15 +440,25 @@ def 
generate_metric_module( state_metrics_mapping: Dict[StateMetricEnum, StateMetric], device: torch.device, process_group: Optional[dist.ProcessGroup] = None, + batch_size_stages: Optional[List[BatchSizeStage]] = None, ) -> RecMetricModule: rec_metrics = _generate_rec_metrics( metrics_config, world_size, my_rank, batch_size, process_group ) + """ + Batch_size_stages currently only used by ThroughputMetric to ensure total_example correct so + different training jobs have aligned mertics. + TODO: update metrics other than ThroughputMetric if it has dependency on batch_size + """ + validate_batch_size_stages(batch_size_stages) + if metrics_config.throughput_metric: throughput_metric = ThroughputMetric( batch_size=batch_size, world_size=world_size, window_seconds=metrics_config.throughput_metric.window_size, + warmup_steps=metrics_config.throughput_metric.warmup_steps, + batch_size_stages=batch_size_stages, ) else: throughput_metric = None diff --git a/torchrec/metrics/metrics_config.py b/torchrec/metrics/metrics_config.py index cc3fe9a8e..0428e2412 100644 --- a/torchrec/metrics/metrics_config.py +++ b/torchrec/metrics/metrics_config.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from dataclasses import dataclass, field from enum import Enum from typing import Any, Dict, List, Optional @@ -18,10 +20,47 @@ class RecMetricEnumBase(StrValueMixin, Enum): class RecMetricEnum(RecMetricEnumBase): NE = "ne" + NE_POSITIVE = "ne_positive" + RECALIBRATED_NE = "recalibrated_ne" + RECALIBRATED_CALIBRATION = "recalibrated_calibration" + SEGMENTED_NE = "segmented_ne" + LOG_LOSS = "log_loss" CTR = "ctr" AUC = "auc" + AUPRC = "auprc" + RAUC = "rauc" CALIBRATION = "calibration" MSE = "mse" + MAE = "mae" + MULTICLASS_RECALL = "multiclass_recall" + RECALL_SESSION_LEVEL = "recall_session_level" + PRECISION_SESSION_LEVEL = "precision_session_level" + WEIGHTED_AVG = "weighted_avg" + TOWER_QPS = "tower_qps" + ACCURACY = "accuracy" + NDCG = "ndcg" + XAUC = "xauc" + SCALAR = "scalar" + PRECISION = "precision" + RECALL = "recall" + SERVING_NE = "serving_ne" + SERVING_CALIBRATION = "serving_calibration" + OUTPUT = "output" + TENSOR_WEIGHTED_AVG = "tensor_weighted_avg" + CALI_FREE_NE = "cali_free_ne" + UNWEIGHTED_NE = "unweighted_ne" + HINDSIGHT_TARGET_PR = "hindsight_target_pr" + + +@dataclass(unsafe_hash=True, eq=True) +class SessionMetricDef: + # hyperparameters required for session level metrics + # session_var_name: name of session tensor in the model_out + # top_threshold: predictiones ranked in top "top_threshold" are considered as positive + # run_ranking_of_labels: if True, labels are also ranked as predictions + session_var_name: str + top_threshold: Optional[int] = None + run_ranking_of_labels: bool = False @dataclass(unsafe_hash=True, eq=True) @@ -30,22 +69,34 @@ class RecTaskInfo: label_name: str = "label" prediction_name: str = "prediction" weight_name: str = "weight" + session_metric_def: Optional[SessionMetricDef] = ( + None # used for session level metrics + ) + is_negative_task: bool = False + tensor_name: Optional[str] = None + weighted: bool = True class RecComputeMode(Enum): """This Enum lists the supported computation modes for RecMetrics. FUSED_TASKS_COMPUTATION indicates that RecMetrics will fuse the computation - for multiple tasks of the same metric. This can be used by modules where the - outputs of all the tasks are vectorized. + for multiple tasks of the same metric. 
This can be used by modules where the + outputs of all the tasks are vectorized. + FUSED_TASKS_AND_STATES_COMPUTATION fuse both the tasks (same as FUSED_TASKS_COMPUTATION) + and states (e.g. calibration_num and calibration_denom for caliration) of the + same metric. This currently only supports 1D state tensors (e.g. when all state + tensors are of the same (n_tasks) shape). """ FUSED_TASKS_COMPUTATION = 1 UNFUSED_TASKS_COMPUTATION = 2 + FUSED_TASKS_AND_STATES_COMPUTATION = 3 _DEFAULT_WINDOW_SIZE = 10_000_000 _DEFAULT_THROUGHPUT_WINDOW_SECONDS = 100 +_DEFAULT_THROUGHPUT_WARMUP_STEPS = 100 @dataclass @@ -59,7 +110,8 @@ class RecMetricDef: RecTask information stored in the parent ``MetricsConfig``. Only one of the two fields should be specified. rec_task_indices (List[int]): see the doscstring of ``rec_tasks``. - window_size (int): the window size for this metric. + window_size (int): the window size for this metric. Note that this is global window size. + The local window size is window_size / world_size, and must be larger than batch size. arguments (Optional[Dict[str, Any]]): any propritary arguments to be used by this Metric. """ @@ -78,6 +130,7 @@ class StateMetricEnum(StrValueMixin, Enum): @dataclass class ThroughputDef: window_size: int = _DEFAULT_THROUGHPUT_WINDOW_SECONDS + warmup_steps: int = _DEFAULT_THROUGHPUT_WARMUP_STEPS @dataclass @@ -116,6 +169,7 @@ class MetricsConfig: should_validate_update (bool): whether to check the inputs of update() and skip update if the inputs are invalid. Invalid inputs include the case where all examples have 0 weights for a batch. + enable_pt2_compile (bool): whether to enable PT2 compilation for metrics. """ rec_tasks: List[RecTaskInfo] = field(default_factory=list) @@ -129,6 +183,7 @@ class MetricsConfig: max_compute_interval: float = float("inf") compute_on_all_ranks: bool = False should_validate_update: bool = False + enable_pt2_compile: bool = False DefaultTaskInfo = RecTaskInfo( @@ -158,3 +213,35 @@ class MetricsConfig: throughput_metric=None, state_metrics=[], ) + + +@dataclass +class BatchSizeStage: + """ + BatchSizeStage class for defining the variable batch size stage. + For a List[BatchSizeStage], the max_iter should be in ascending order, and the last one should have max_iter=None + Attributes + ---------- + batch_size(int): A multiple of base_batch_size + max_iter(int): The maximum number of iterations for the stage. + When previous BatchSizeStage.max_iters < iter <= max_iters, the stage is effective. + Max_iter is the absolute train iteration count, not the relative count within each stage + """ + + batch_size: int = 0 + max_iters: Optional[int] = 0 + + +def validate_batch_size_stages( + batch_size_stages: Optional[List[BatchSizeStage]], +) -> None: + if not batch_size_stages: + return + + if len(batch_size_stages) == 0: + raise ValueError("Batch size stages should not be empty") + + if batch_size_stages[-1].max_iters is not None: + raise ValueError( + f"Batch size stages last stage should have max_iters = None, but get {batch_size_stages[-1].max_iters}" + ) diff --git a/torchrec/metrics/metrics_namespace.py b/torchrec/metrics/metrics_namespace.py index dff0bdacd..e952d368c 100644 --- a/torchrec/metrics/metrics_namespace.py +++ b/torchrec/metrics/metrics_namespace.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + """ The namespace definition and genration APIs for rec metrics. 
@@ -21,6 +23,7 @@ """ from enum import Enum +from typing import Optional class StrValueMixin: @@ -37,14 +40,52 @@ class MetricName(MetricNameBase): DEFAULT = "" NE = "ne" + NE_POSITIVE = "ne_positive" + SEGMENTED_NE = "segmented_ne" + LOG_LOSS = "logloss" THROUGHPUT = "throughput" TOTAL_EXAMPLES = "total_examples" + ATTEMPT_EXAMPLES = "attempt_examples" + BATCH_SIZE = "batch_size" CTR = "ctr" CALIBRATION = "calibration" MSE = "mse" + MAE = "mae" RMSE = "rmse" AUC = "auc" + GAUC = "gauc" + AUPRC = "auprc" + RAUC = "rauc" + GROUPED_AUC = "grouped_auc" + GROUPED_AUPRC = "grouped_auprc" + GROUPED_RAUC = "grouped_rauc" + RECALL_SESSION_LEVEL = "recall_session_level" + PRECISION_SESSION_LEVEL = "precision_session_level" MULTICLASS_RECALL = "multiclass_recall" + WEIGHTED_AVG = "weighted_avg" + TOWER_QPS = "qps" + ACCURACY = "accuracy" + NDCG = "ndcg" + XAUC = "xauc" + SCALAR = "scalar" + OUTPUT = "output" + + GAUC_NUM_SAMPLES = "gauc_num_samples" + TOTAL_POSITIVE_EXAMPLES = "total_positive_examples" + TOTAL_NEGATIVE_EXAMPLES = "total_negative_examples" + PRECISION = "precision" + RECALL = "recall" + + SERVING_NE = "serving_ne" + SERVING_CALIBRATION = "serving_calibration" + TENSOR_WEIGHTED_AVG = "tensor_weighted_avg" + + CALI_FREE_NE = "cali_free_ne" + UNWEIGHTED_NE = "unweighted_ne" + + HINDSIGHT_TARGET_PR = "hindsight_target_pr" + HINDSIGHT_TARGET_PRECISION = "hindsight_target_precision" + HINDSIGHT_TARGET_RECALL = "hindsight_target_recall" class MetricNamespaceBase(StrValueMixin, Enum): @@ -55,22 +96,56 @@ class MetricNamespace(MetricNamespaceBase): DEFAULT = "" NE = "ne" + NE_POSITIVE = "ne_positive" + SEGMENTED_NE = "segmented_ne" + RECALIBRATED_NE = "recalibrated_ne" + RECALIBRATED_CALIBRATION = "recalibrated_calibration" THROUGHPUT = "throughput" CTR = "ctr" CALIBRATION = "calibration" MSE = "mse" AUC = "auc" + GAUC = "gauc" + AUPRC = "auprc" + RAUC = "rauc" + MAE = "mae" + ACCURACY = "accuracy" OPTIMIZERS = "optimizers" MODEL_CONFIGURATOR = "model_configurator" MULTICLASS_RECALL = "multiclass_recall" + WEIGHTED_AVG = "weighted_avg" + RECALL_SESSION_LEVEL = "recall_session_level" + PRECISION_SESSION_LEVEL = "precision_session_level" + + TOWER_QPS = "qps" + NDCG = "ndcg" + XAUC = "xauc" + + SCALAR = "scalar" + + PRECISION = "precision" + RECALL = "recall" + + SERVING_NE = "serving_ne" + SERVING_CALIBRATION = "serving_calibration" + + OUTPUT = "output" + TENSOR_WEIGHTED_AVG = "tensor_weighted_avg" + + CALI_FREE_NE = "cali_free_ne" + UNWEIGHTED_NE = "unweighted_ne" + + HINDSIGHT_TARGET_PR = "hindsight_target_pr" + class MetricPrefix(StrValueMixin, Enum): DEFAULT = "" LIFETIME = "lifetime_" WINDOW = "window_" + ATTEMPT = "attempt_" def task_wildcard_metrics_pattern( @@ -97,12 +172,13 @@ def compose_metric_namespace( def compose_customized_metric_key( namespace: str, metric_name: str, + description: Optional[str] = None, ) -> str: r"""Get the metric key. The input are unrestricted (string) namespace and metric_name. This API should only be used by compose_metric_key() and state metrics as the keys of state metrics are unknown. 
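    Illustrative output (editor's sketch): the key format stays
    ``"<namespace>|<metric_name><description>"`` with the new optional suffix::

        >>> compose_customized_metric_key("ne-my_task", "lifetime_ne")
        'ne-my_task|lifetime_ne'
        >>> compose_customized_metric_key("ne-my_task", "lifetime_ne", description="-agg")
        'ne-my_task|lifetime_ne-agg'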
""" - return f"{namespace}|{metric_name}" + return f"{namespace}|{metric_name}{description or ''}" def compose_metric_key( @@ -110,8 +186,11 @@ def compose_metric_key( task_name: str, metric_name: MetricNameBase, metric_prefix: MetricPrefix = MetricPrefix.DEFAULT, + description: Optional[str] = None, ) -> str: r"""Get the metric key based on the input parameters""" return compose_customized_metric_key( - compose_metric_namespace(namespace, task_name), f"{metric_prefix}{metric_name}" + compose_metric_namespace(namespace, task_name), + f"{metric_prefix}{metric_name}", + description, ) diff --git a/torchrec/metrics/model_utils.py b/torchrec/metrics/model_utils.py index bd78d2954..eafe5aacb 100644 --- a/torchrec/metrics/model_utils.py +++ b/torchrec/metrics/model_utils.py @@ -5,12 +5,39 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + +import logging from typing import Dict, List, Optional, Tuple import torch from torchrec.metrics.rec_metric import RecTaskInfo +logger: logging.Logger = logging.getLogger(__name__) + + +def session_ids_to_tensor( + session_ids: List[str], + device: Optional[torch.device] = None, +) -> torch.Tensor: + """ + This function is used to prepare model outputs with session_ids as List[str] to tensor to be consumed by the Metric computation + """ + curr_id = 1 + session_lengths_list = [0] + + for i, session in enumerate(session_ids[:-1]): + if session == session_ids[i + 1]: + session_lengths_list.append(curr_id) + else: + session_lengths_list.append(curr_id) + curr_id += 1 + + session_lengths_list.append(curr_id) + return torch.tensor(session_lengths_list[1:], device=device) + + def is_empty_signals( labels: torch.Tensor, predictions: torch.Tensor, @@ -41,13 +68,48 @@ def parse_model_outputs( if not is_empty_signals(labels, predictions, weights): if labels.dim() == predictions.dim(): - assert (torch.numel(labels) == torch.numel(predictions)) and ( - torch.numel(labels) == torch.numel(weights) - ), ( - "Expect the same number of elements in labels, predictions, and weights. " - f"Instead got {torch.numel(labels)}, {torch.numel(predictions)}, " - f"{torch.numel(weights)}" + # For vector valued label and prediction we should have shapes + # labels.size() == (batch_size, dim_vector_valued_label) + # predictions.size() == (batch_size, dim_vector_valued_prediction) + # weights.size() == (batch_size,) + is_vector_valued_label_and_prediction = ( + (labels.dim() == 2) + and (weights.dim() == 1) + and (labels.size()[0] == predictions.size()[0]) + and (labels.size()[0] == weights.size()[0]) ) + if is_vector_valued_label_and_prediction: + logger.warning( + f""" + Vector valued labels and predictions are provided. + + For vector valued label and prediction we should have shapes + labels.shape: (batch_size, dim_vector_valued_label) + predictions.shape: (batch_size, dim_vector_valued_prediction) + weights.shape: (batch_size,) + + The provided labels, predictions and weights comply with the conditions for vector valued labels and predictions. + These conditions are: + 1. labels.dim() == 2 + 2. predictions.dim() == 2 + 3. weights.dim() == 1 + 4. labels.size()[0] == predictions.size()[0] + 5. 
labels.size()[0] == weights.size()[0] + + The shapes of labels, predictions and weights are: + labels.shape == {labels.shape}, + predictions.shape == {predictions.shape}, + weights.shape == {weights.shape} + """ + ) + else: + assert (torch.numel(labels) == torch.numel(predictions)) and ( + torch.numel(labels) == torch.numel(weights) + ), ( + "Expect the same number of elements in labels, predictions, and weights. " + f"Instead got {torch.numel(labels)}, {torch.numel(predictions)}, " + f"{torch.numel(weights)}" + ) else: # For multiclass models, labels.size() = (batch_size), and predictions.size() = (batch_size, number_of_classes) assert torch.numel(labels) == torch.numel(predictions) / predictions.size()[ -1 @@ -62,12 +124,43 @@ def parse_model_outputs( return labels, predictions, weights +def parse_required_inputs( + model_out: Dict[str, torch.Tensor], + required_inputs_list: List[str], + ndcg_transform_input: bool = False, + device: Optional[torch.device] = None, +) -> Dict[str, torch.Tensor]: + required_inputs: Dict[str, torch.Tensor] = {} + for feature in required_inputs_list: + # convert feature defined from config only + if ndcg_transform_input: + model_out[feature] = ( + # pyre-ignore[6] + session_ids_to_tensor(model_out[feature], device=device) + if isinstance(model_out[feature], list) + else model_out[feature] + ) + required_inputs[feature] = model_out[feature].squeeze() + assert isinstance(required_inputs[feature], torch.Tensor) + return required_inputs + + def parse_task_model_outputs( - tasks: List[RecTaskInfo], model_out: Dict[str, torch.Tensor] -) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + tasks: List[RecTaskInfo], + model_out: Dict[str, torch.Tensor], + required_inputs_list: Optional[List[str]] = None, +) -> Tuple[ + Dict[str, torch.Tensor], + Dict[str, torch.Tensor], + Dict[str, torch.Tensor], + Dict[str, torch.Tensor], +]: all_labels: Dict[str, torch.Tensor] = {} all_predictions: Dict[str, torch.Tensor] = {} all_weights: Dict[str, torch.Tensor] = {} + all_required_inputs: Dict[str, torch.Tensor] = {} + # Convert session_ids to tensor if NDCG metric + ndcg_transform_input = False for task in tasks: labels, predictions, weights = parse_model_outputs( task.label_name, task.prediction_name, task.weight_name, model_out @@ -81,4 +174,15 @@ def parse_task_model_outputs( if torch.numel(labels) > 0: all_labels[task.name] = labels - return all_labels, all_predictions, all_weights + if task.name and task.name.startswith("ndcg"): + ndcg_transform_input = True + + if required_inputs_list is not None: + all_required_inputs = parse_required_inputs( + model_out, + required_inputs_list, + ndcg_transform_input, + device=labels.device, + ) + + return all_labels, all_predictions, all_weights, all_required_inputs diff --git a/torchrec/metrics/mse.py b/torchrec/metrics/mse.py index 79612eac6..371f86daa 100644 --- a/torchrec/metrics/mse.py +++ b/torchrec/metrics/mse.py @@ -5,9 +5,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
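# NOTE (editor's illustrative aside, not part of the patch): the session_ids_to_tensor
# helper added in model_utils.py above turns a run-length list of string session ids into
# integer group ids, e.g. (assumed inputs):
#   session_ids_to_tensor(["s1", "s1", "s2", "s2", "s2", "s3"])
#   -> tensor([1, 1, 2, 2, 2, 3])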
+# pyre-strict + from typing import Any, cast, Dict, List, Optional, Type import torch + from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix from torchrec.metrics.rec_metric import ( MetricComputationReport, @@ -84,6 +87,7 @@ def update( predictions: Optional[torch.Tensor], labels: torch.Tensor, weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], ) -> None: if predictions is None or weights is None: raise RecMetricException( diff --git a/torchrec/metrics/multiclass_recall.py b/torchrec/metrics/multiclass_recall.py index 02b017055..50545dc43 100644 --- a/torchrec/metrics/multiclass_recall.py +++ b/torchrec/metrics/multiclass_recall.py @@ -5,9 +5,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + +import logging from typing import Any, cast, Dict, List, Optional, Type import torch +from torchrec.metrics.metrics_config import RecComputeMode from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix from torchrec.metrics.rec_metric import ( @@ -18,10 +22,13 @@ ) -def get_true_positives_list( - predictions: Optional[torch.Tensor], +logger: logging.Logger = logging.getLogger(__name__) + + +def compute_true_positives_at_k( + predictions: torch.Tensor, labels: torch.Tensor, - weights: Optional[torch.Tensor], + weights: torch.Tensor, n_classes: int, ) -> torch.Tensor: """ @@ -42,58 +49,58 @@ def get_true_positives_list( >>> predictions = torch.tensor([[0.9, 0.1, 0, 0, 0], [0.1, 0.2, 0.25, 0.15, 0.3], [0, 1.0, 0, 0, 0], [0, 0, 0.2, 0.7, 0.1]]) >>> labels = torch.tensor([0, 3, 1, 2]) - >>> weights = torch.tensor([1, 2, 2, 1]) + >>> weights = torch.tensor([1, 0.25, 0.5, 0.25]) >>> n_classes = 5 >>> true_positives_list = compute_multiclass_k_sum(predictions, labels, n_classes) >>> true_positives_list - tensor([3., 4., 4., 6., 6.]) + tensor([1.5000, 1.7500, 1.7500, 2.0000, 2.0000]) """ ranks = torch.argsort(predictions, dim=-1, descending=True) true_positives = ( - torch.zeros(1) + torch.zeros(1, device=predictions.device) if predictions.ndim == 2 - else torch.zeros(predictions.shape[0], 1) + else torch.zeros(predictions.shape[0], 1, device=predictions.device) ) - true_positives_list = torch.tensor([]) + true_positives_list = torch.tensor([], device=predictions.device) for k in range(n_classes): mask = torch.unsqueeze(labels, dim=-1) == ranks[..., k : k + 1] - mask *= torch.unsqueeze(weights, dim=-1) + mask = mask * torch.unsqueeze(weights, dim=-1) true_positives += mask.sum(dim=-2) true_positives_list = torch.cat((true_positives_list, true_positives), dim=-1) return true_positives_list -def compute_multiclass_recall_at_k_sum( - tp_sum: torch.Tensor, +def compute_multiclass_recall_at_k( + tp_at_k: torch.Tensor, total_weights: torch.Tensor, ) -> torch.Tensor: - return tp_sum / torch.unsqueeze(total_weights, dim=-1) + return tp_at_k / torch.unsqueeze(total_weights, dim=-1) def get_multiclass_recall_states( - predictions: Optional[torch.Tensor], + predictions: torch.Tensor, labels: torch.Tensor, - weights: Optional[torch.Tensor], + weights: torch.Tensor, n_classes: int, ) -> Dict[str, torch.Tensor]: - true_positives_list = get_true_positives_list( + true_positives_at_k_sum = compute_true_positives_at_k( predictions, labels, weights, n_classes ) return { - "tp_sum": true_positives_list, + "tp_at_k": true_positives_at_k_sum, "total_weights": torch.sum(weights, dim=-1), } class MulticlassRecallMetricComputation(RecMetricComputation): 
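# NOTE (editor's illustrative aside, not part of the patch): with the rename below, the
# computation pops "number_of_classes" (previously "n_classes") from its kwargs, so a
# config would pass it through the metric arguments, e.g. (assumed values):
#   RecMetricDef(rec_tasks=[...], arguments={"number_of_classes": 5})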
def __init__(self, *args: Any, **kwargs: Any) -> None: - self._n_classes: int = kwargs.pop("n_classes") + self._n_classes: int = kwargs.pop("number_of_classes") super().__init__(*args, **kwargs) self._add_state( - "tp_sum", + "tp_at_k", torch.zeros(self._n_tasks, self._n_classes, dtype=torch.double), add_window_state=True, dist_reduce_fx="sum", @@ -113,13 +120,15 @@ def update( predictions: Optional[torch.Tensor], labels: torch.Tensor, weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], ) -> None: if predictions is None or weights is None: raise RecMetricException( "Inputs 'predictions' and 'weights' should not be None for MulticlassRecallMetricComputation update" ) - - states = get_multiclass_recall_states(predictions, labels, self._n_classes) + states = get_multiclass_recall_states( + predictions, labels, weights, self._n_classes + ) num_samples = predictions.shape[-2] for state_name, state_value in states.items(): state = getattr(self, state_name) @@ -131,16 +140,16 @@ def _compute(self) -> List[MetricComputationReport]: MetricComputationReport( name=MetricName.MULTICLASS_RECALL, metric_prefix=MetricPrefix.LIFETIME, - value=compute_multiclass_recall_at_k_sum( - cast(torch.Tensor, self.tp_sum), + value=compute_multiclass_recall_at_k( + cast(torch.Tensor, self.tp_at_k), cast(torch.Tensor, self.total_weights), ), ), MetricComputationReport( name=MetricName.MULTICLASS_RECALL, metric_prefix=MetricPrefix.WINDOW, - value=compute_multiclass_recall_at_k_sum( - self.get_window_state("tp_sum"), + value=compute_multiclass_recall_at_k( + self.get_window_state("tp_at_k"), self.get_window_state("total_weights"), ), ), @@ -150,3 +159,11 @@ def _compute(self) -> List[MetricComputationReport]: class MulticlassRecallMetric(RecMetric): _namespace: MetricNamespace = MetricNamespace.MULTICLASS_RECALL _computation_class: Type[RecMetricComputation] = MulticlassRecallMetricComputation + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + if self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION: + logging.warning( + f"compute_mode FUSED_TASKS_AND_STATES_COMPUTATION can't support {self._namespace} yet " + "because its states are not 1D Tensors. Only FUSED_TASKS_COMPUTATION will take effect." + ) diff --git a/torchrec/metrics/ndcg.py b/torchrec/metrics/ndcg.py new file mode 100644 index 000000000..718e208f2 --- /dev/null +++ b/torchrec/metrics/ndcg.py @@ -0,0 +1,487 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union + +import torch +from torch import distributed as dist +from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +SUM_NDCG = "sum_ndcg" +NUM_SESSIONS = "num_sessions" +REQUIRED_INPUTS = "required_inputs" +SESSION_KEY = "session_id" + + +def _validate_model_outputs( + *, + predictions: torch.Tensor, + labels: torch.Tensor, + weights: torch.Tensor, + session_ids: torch.Tensor, +) -> None: + # Sanity check dimensions. 
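# NOTE (editor's illustrative aside, not part of the patch): all four inputs are expected
# as [n_task, n_examples] tensors; session_ids is the same session tensor replicated per
# task, which is why only session_ids[0] is used downstream.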
+ assert predictions.shape == labels.shape == weights.shape == session_ids.shape + assert ( + predictions.dim() == 2 and predictions.shape[0] > 0 and predictions.shape[1] > 0 + ) + assert (session_ids[0] == session_ids).all() + + +def _get_adjusted_ndcg_inputs( + *, + predictions: torch.Tensor, + labels: torch.Tensor, + weights: torch.Tensor, + session_ids: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Remove all single-length sessions from all variables. + """ + # Get unique session IDs and their corresponding indices to put them in range (O, N]. + _, converted_session_ids, session_lengths = session_ids[0].unique( + return_inverse=True, return_counts=True + ) + + example_to_length = torch.gather( + session_lengths, + dim=-1, + index=converted_session_ids.type(torch.int64), + ) + example_corresponds_to_session_with_length_greater_than_one = example_to_length > 1 + + # Remove all single-length sessions. + return ( + predictions[:, example_corresponds_to_session_with_length_greater_than_one], + labels[:, example_corresponds_to_session_with_length_greater_than_one], + weights[:, example_corresponds_to_session_with_length_greater_than_one], + converted_session_ids[ + example_corresponds_to_session_with_length_greater_than_one + ], + ) + + +def _get_ndcg_states( + *, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + session_ids: torch.Tensor, + exponential_gain: bool, + k: int = -1, # In case we want to support NDCG @ K in the future. + report_ndcg_as_decreasing_curve: bool = True, + remove_single_length_sessions: bool = False, + scale_by_weights_tensor: bool = False, +) -> Dict[str, torch.Tensor]: + """ + Normalized Discounted Cumulative Gain (NDCG) @ k. + + TODO(@venkatrsrinivas): Refactor into smaller helper functions :) + """ + # Remove all single-length sessions from all variables. + if remove_single_length_sessions: + ( + adjusted_predictions, + adjusted_labels, + adjusted_weights, + adjusted_session_ids, + ) = _get_adjusted_ndcg_inputs( + predictions=predictions, + labels=labels, + weights=weights, + session_ids=session_ids, + ) + else: + ( + adjusted_predictions, + adjusted_labels, + adjusted_weights, + adjusted_session_ids, + ) = (predictions, labels, weights, session_ids[0]) + + # If we are scaling by weights, then do that before NDCG computation. + if scale_by_weights_tensor: + adjusted_labels = adjusted_weights * adjusted_labels + adjusted_predictions = adjusted_weights * adjusted_predictions + + # Helper variables for all reshaping below. + num_tasks, batch_size = adjusted_labels.shape + + # Get unique session IDs and their corresponding indices to put them in range (O, N]. + ( + unique_session_ids, + converted_session_ids, + session_lengths, + ) = adjusted_session_ids.unique(return_inverse=True, return_counts=True) + + # Healthy assertion that we are trimming sessions correctly. + if remove_single_length_sessions: + assert (session_lengths > 1).all() + + num_sessions = unique_session_ids.shape[0] + + # Return early => no state update if there are no sessions. + if num_sessions == 0: + return {} + + max_session_length = torch.max(session_lengths) + max_session_length = ( + max_session_length + if k == -1 + else torch.min(torch.tensor(k), max_session_length) + ) + + # Convert session IDs to [num_tasks, num_sessions] from [num_sessions,]. + expanded_session_ids = converted_session_ids.expand(num_tasks, -1) + + # Sort labels by themselves and also by predictions. 
+ sorted_labels_by_labels, sorted_labels_indices = adjusted_labels.sort( + descending=True, dim=-1 + ) + _, sorted_predictions_indices = adjusted_predictions.sort(descending=True, dim=-1) + sorted_labels_by_predictions = torch.gather( + adjusted_labels, + dim=-1, + index=sorted_predictions_indices, + ) + + # Expand these to be [num_task, num_sessions, batch_size] for masking to handle later. + expanded_sorted_labels_by_labels = sorted_labels_by_labels.unsqueeze(1).expand( + (num_tasks, num_sessions, batch_size) + ) + expanded_sorted_labels_by_predictions = sorted_labels_by_predictions.unsqueeze( + 1 + ).expand((num_tasks, num_sessions, batch_size)) + + # Make sure to correspondingly sort session IDs according to how we sorted labels above. + session_ids_by_sorted_labels = torch.gather( + expanded_session_ids, + dim=-1, + index=sorted_labels_indices, + ) + session_ids_by_sorted_predictions = torch.gather( + expanded_session_ids, + dim=-1, + index=sorted_predictions_indices, + ) + + # Helper variable to track every session ID's examples for every task. + task_to_session_to_examples = ( + torch.arange(num_sessions) + .view(1, num_sessions, 1) + .expand(num_tasks, -1, batch_size) + ).to(device=labels.device) + + # Figure out after sorting which example indices belong to which session. + sorted_session_ids_by_labels_mask = ( + task_to_session_to_examples == session_ids_by_sorted_labels.unsqueeze(1) + ).long() + sorted_session_ids_by_predictions_mask = ( + task_to_session_to_examples == session_ids_by_sorted_predictions.unsqueeze(1) + ).long() + + # Get the ranks (1, N] for each example in each session for every task. + label_by_label_ranks = (sorted_session_ids_by_labels_mask).cumsum(dim=-1) + label_by_prediction_ranks = (sorted_session_ids_by_predictions_mask).cumsum(dim=-1) + + # Compute coresponding discount factors (according to sorting). + ( + discounts_for_label_by_label, + discounts_for_label_by_prediction, + ) = torch.reciprocal(torch.log2(label_by_label_ranks + 1)), torch.reciprocal( + torch.log2(label_by_prediction_ranks + 1) + ) + + # Account for edge cases and when we want to compute NDCG @ K. + ( + discounts_for_label_by_label[label_by_label_ranks <= 0], + discounts_for_label_by_prediction[label_by_prediction_ranks <= 0], + ) = ( + 0.0, + 0.0, + ) + ( + discounts_for_label_by_label[label_by_label_ranks > max_session_length], + discounts_for_label_by_prediction[ + label_by_prediction_ranks > max_session_length + ], + ) = ( + 0.0, + 0.0, + ) + + # Apply mask => to correctly compute ideal and observed gains before applying discounts. + ideal_gains = expanded_sorted_labels_by_labels * sorted_session_ids_by_labels_mask + observed_gains = ( + expanded_sorted_labels_by_predictions * sorted_session_ids_by_predictions_mask + ) + + # Apply exponential gain if applicable. + ideal_gains = torch.exp2(ideal_gains) - 1.0 if exponential_gain else ideal_gains + observed_gains = ( + torch.exp2(observed_gains) - 1.0 if exponential_gain else observed_gains + ) + + # Apply discounts and sum. + ideal_dcg = torch.sum(ideal_gains * discounts_for_label_by_label, dim=-1) + ideal_dcg[ideal_dcg == 0] = 1e-6 # Avoid division by 0. 
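# NOTE (editor's illustrative aside, not part of the patch): the quantities above follow
# DCG = sum_i gain_i / log2(rank_i + 1) and NDCG = DCG_by_prediction / DCG_by_label.
# E.g. with linear gains, labels [1, 0, 1] ranked by predictions as [0, 1, 1]:
#   ideal DCG    = 1/log2(2) + 1/log2(3)              ~= 1.631
#   observed DCG = 0/log2(2) + 1/log2(3) + 1/log2(4)  ~= 1.131
#   NDCG         ~= 1.131 / 1.631 ~= 0.693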
+ + observed_dcg = torch.sum( + observed_gains * discounts_for_label_by_prediction, + dim=-1, + ) + ndcg = observed_dcg / ideal_dcg + + max_weights = ( + torch.zeros((num_tasks, num_sessions), dtype=weights.dtype) + .to(device=adjusted_weights.device) + .scatter_reduce_( + dim=-1, + index=expanded_session_ids, + src=adjusted_weights, # [num_tasks, batch_size] + reduce="amax", + ) + ) + + # Scale NDCG by max weight per session. + ndcg_report = (1 - ndcg) if report_ndcg_as_decreasing_curve else ndcg + ndcg_report = ndcg_report.to(device=labels.device) + + # If we aren't scaling gains by weight tensor, + # just scale by max_weight per session to match weird production logic. + if not scale_by_weights_tensor: + ndcg_report *= max_weights + + final_ndcg_report = torch.sum( + ndcg_report, dim=-1 + ) # Sum over num_sessions for losses => [num_tasks] + + return { + SUM_NDCG: final_ndcg_report, + NUM_SESSIONS: torch.full((num_tasks,), fill_value=num_sessions).to( + device=converted_session_ids.device + ), + } + + +def _compute_ndcg( + *, sum_ndcg: torch.Tensor, num_sessions: torch.Tensor +) -> torch.Tensor: + return sum_ndcg / num_sessions + + +class NDCGComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for NDCG @ K + (i.e., Normalized Discounted Cumulative Gain @ K). + + Specially this reports (1 - NDCG) so that TensorBoard + can capture a decreasing "loss" as opposed to an increasing "gain" + to visualize similarly to normalized entropy (NE) / pointwise measures. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + """ + + def __init__( + self, + *args: Any, + exponential_gain: bool = False, + session_key: str = SESSION_KEY, + k: int = -1, + report_ndcg_as_decreasing_curve: bool = True, + remove_single_length_sessions: bool = False, + scale_by_weights_tensor: bool = False, + is_negative_task_mask: Optional[List[bool]] = None, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._exponential_gain: bool = exponential_gain + self._session_key: str = session_key + self._k: int = k + self._remove_single_length_sessions: bool = remove_single_length_sessions + self._is_negative_task_mask: Optional[List[bool]] = is_negative_task_mask + self._report_ndcg_as_decreasing_curve: bool = report_ndcg_as_decreasing_curve + self._scale_by_weights_tensor: bool = scale_by_weights_tensor + self._add_state( + SUM_NDCG, + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + NUM_SESSIONS, + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + """ + Arguments: + predictions: Tensor of size (n_task, n_examples) + labels: Tensor of size (n_task, n_examples) + weights: Tensor of size (n_task, n_examples) + Returns: + Nothing => updates state. + """ + if ( + REQUIRED_INPUTS not in kwargs + or self._session_key not in kwargs[REQUIRED_INPUTS] + ): + raise RecMetricException( + f"{self._session_key=} should be in {kwargs=} as input. It is required to calculate NDCG loss." 
+ ) + + session_ids = kwargs[REQUIRED_INPUTS][self._session_key] + + if predictions is None or weights is None or session_ids is None: + raise RecMetricException( + "Inputs 'predictions', 'weights' and 'session_ids' should not be None for NDCGMetricComputation update" + ) + + # Apply negative scaling to predictions so that + # we can accurately compute NDCG for negative tasks + # (e.g., NDCG_p(skip) prefers to have label => 0 + # towards the top of the relevant ranked list > label = 1) + # Or maybe, we want to compute NDCG p(skip) the same way as NDCG p(like). + # In either case, this mask gives us the full control. + if self._is_negative_task_mask is not None: + predictions[self._is_negative_task_mask, :] = ( + 1 - predictions[self._is_negative_task_mask, :] + ) + labels[self._is_negative_task_mask, :] = ( + 1 - labels[self._is_negative_task_mask, :] + ) + + _validate_model_outputs( + predictions=predictions, + labels=labels, + weights=weights, + session_ids=session_ids, + ) + + predictions = predictions.double() + labels = labels.double() + weights = weights.double() + + # Calculate NDCG loss at current iterations. + states = _get_ndcg_states( + labels=labels, + predictions=predictions, + weights=weights, + session_ids=session_ids, + exponential_gain=self._exponential_gain, + remove_single_length_sessions=self._remove_single_length_sessions, + report_ndcg_as_decreasing_curve=self._report_ndcg_as_decreasing_curve, + k=self._k, + scale_by_weights_tensor=self._scale_by_weights_tensor, + ) + + # Update based on the new states. + for state_name, state_value in states.items(): + state = getattr(self, state_name).to(labels.device) + state += state_value + self._aggregate_window_state(state_name, state_value, predictions.shape[-1]) + + def _compute(self) -> List[MetricComputationReport]: + + return [ + MetricComputationReport( + name=MetricName.NDCG, + metric_prefix=MetricPrefix.LIFETIME, + value=_compute_ndcg( + sum_ndcg=cast(torch.Tensor, getattr(self, SUM_NDCG)), + num_sessions=cast(torch.Tensor, getattr(self, NUM_SESSIONS)), + ), + ), + MetricComputationReport( + name=MetricName.NDCG, + metric_prefix=MetricPrefix.WINDOW, + value=_compute_ndcg( + sum_ndcg=self.get_window_state(SUM_NDCG), + num_sessions=self.get_window_state(NUM_SESSIONS), + ), + ), + ] + + +class NDCGMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.NDCG + _computation_class: Type[RecMetricComputation] = NDCGComputation + + def __init__( + self, + world_size: int, + my_rank: int, + batch_size: int, + tasks: List[RecTaskInfo], + compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, + window_size: int = 100, + fused_update_limit: int = 0, + compute_on_all_ranks: bool = False, + should_validate_update: bool = False, + process_group: Optional[dist.ProcessGroup] = None, + **kwargs: Dict[str, Any], + ) -> None: + super().__init__( + world_size=world_size, + my_rank=my_rank, + batch_size=batch_size, + tasks=tasks, + compute_mode=compute_mode, + window_size=window_size, + fused_update_limit=fused_update_limit, + compute_on_all_ranks=compute_on_all_ranks, + should_validate_update=should_validate_update, + process_group=process_group, + **kwargs, + ) + # This is the required metadata to be enriched with + # the session ID information by loss wrappers, etc. + # This is set through the front-end configurations, + # => fallback back to "session_id" if not specified. 
+ if "session_key" not in kwargs: + self._required_inputs.add(SESSION_KEY) + else: + # pyre-ignore[6] + self._required_inputs.add(kwargs["session_key"]) + + def _get_task_kwargs( + self, task_config: Union[RecTaskInfo, List[RecTaskInfo]] + ) -> Dict[str, Any]: + all_task_info = ( + [task_config] if isinstance(task_config, RecTaskInfo) else task_config + ) + + # Just sanity in weird case if we have no tasks (should never happen). + if len(all_task_info) == 0: + return {} + return { + "is_negative_task_mask": [ + task_info.is_negative_task for task_info in all_task_info + ] + } diff --git a/torchrec/metrics/ne.py b/torchrec/metrics/ne.py index a3b179c7a..41f14a92e 100644 --- a/torchrec/metrics/ne.py +++ b/torchrec/metrics/ne.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Any, cast, Dict, List, Optional, Type import torch @@ -44,18 +46,38 @@ def _compute_cross_entropy_norm( ) +@torch.fx.wrap def compute_ne( ce_sum: torch.Tensor, weighted_num_samples: torch.Tensor, pos_labels: torch.Tensor, neg_labels: torch.Tensor, eta: float, + allow_missing_label_with_zero_weight: bool = False, ) -> torch.Tensor: + if allow_missing_label_with_zero_weight and not weighted_num_samples.all(): + # If nan were to occur, return a dummy value instead of nan if + # allow_missing_label_with_zero_weight is True + return torch.tensor([eta]) + + # Goes into this block if all elements in weighted_num_samples > 0 + weighted_num_samples = weighted_num_samples.double().clamp(min=eta) mean_label = pos_labels / weighted_num_samples ce_norm = _compute_cross_entropy_norm(mean_label, pos_labels, neg_labels, eta) return ce_sum / ce_norm +def compute_logloss( + ce_sum: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, +) -> torch.Tensor: + labels_sum = pos_labels + neg_labels + labels_sum.clamp_(min=eta) + return ce_sum / labels_sum + + def get_ne_states( labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor, eta: float ) -> Dict[str, torch.Tensor]: @@ -79,9 +101,22 @@ class NEMetricComputation(RecMetricComputation): The constructor arguments are defined in RecMetricComputation. See the docstring of RecMetricComputation for more detail. + + Args: + include_logloss (bool): return vanilla logloss as one of metrics results, on top of NE. 
""" - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__( + self, + *args: Any, + include_logloss: bool = False, + allow_missing_label_with_zero_weight: bool = False, + **kwargs: Any, + ) -> None: + self._include_logloss: bool = include_logloss + self._allow_missing_label_with_zero_weight: bool = ( + allow_missing_label_with_zero_weight + ) super().__init__(*args, **kwargs) self._add_state( "cross_entropy_sum", @@ -119,6 +154,7 @@ def update( predictions: Optional[torch.Tensor], labels: torch.Tensor, weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], ) -> None: if predictions is None or weights is None: raise RecMetricException( @@ -133,7 +169,7 @@ def update( self._aggregate_window_state(state_name, state_value, num_samples) def _compute(self) -> List[MetricComputationReport]: - return [ + reports = [ MetricComputationReport( name=MetricName.NE, metric_prefix=MetricPrefix.LIFETIME, @@ -143,6 +179,7 @@ def _compute(self) -> List[MetricComputationReport]: cast(torch.Tensor, self.pos_labels), cast(torch.Tensor, self.neg_labels), self.eta, + self._allow_missing_label_with_zero_weight, ), ), MetricComputationReport( @@ -154,9 +191,34 @@ def _compute(self) -> List[MetricComputationReport]: self.get_window_state("pos_labels"), self.get_window_state("neg_labels"), self.eta, + self._allow_missing_label_with_zero_weight, ), ), ] + if self._include_logloss: + reports += [ + MetricComputationReport( + name=MetricName.LOG_LOSS, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_logloss( + cast(torch.Tensor, self.cross_entropy_sum), + cast(torch.Tensor, self.pos_labels), + cast(torch.Tensor, self.neg_labels), + self.eta, + ), + ), + MetricComputationReport( + name=MetricName.LOG_LOSS, + metric_prefix=MetricPrefix.WINDOW, + value=compute_logloss( + self.get_window_state("cross_entropy_sum"), + self.get_window_state("pos_labels"), + self.get_window_state("neg_labels"), + self.eta, + ), + ), + ] + return reports class NEMetric(RecMetric): diff --git a/torchrec/metrics/ne_positive.py b/torchrec/metrics/ne_positive.py new file mode 100644 index 000000000..2d2147f3d --- /dev/null +++ b/torchrec/metrics/ne_positive.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +def compute_cross_entropy_positive( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + eta: float, +) -> torch.Tensor: + predictions = predictions.double() + predictions.clamp_(min=eta, max=1 - eta) + cross_entropy_positive = -weights * labels * torch.log2(predictions) + return cross_entropy_positive + + +def _compute_cross_entropy_norm( + mean_label: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, +) -> torch.Tensor: + mean_label = mean_label.double() + mean_label.clamp_(min=eta, max=1 - eta) + return -pos_labels * torch.log2(mean_label) - neg_labels * torch.log2( + 1.0 - mean_label + ) + + +@torch.fx.wrap +def compute_ne_positive( + ce_positive_sum: torch.Tensor, + weighted_num_samples: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, + allow_missing_label_with_zero_weight: bool = False, +) -> torch.Tensor: + if allow_missing_label_with_zero_weight and not weighted_num_samples.all(): + # If nan were to occur, return a dummy value instead of nan if + # allow_missing_label_with_zero_weight is True + return torch.tensor([eta]) + + # Goes into this block if all elements in weighted_num_samples > 0 + weighted_num_samples = weighted_num_samples.double().clamp(min=eta) + mean_label = pos_labels / weighted_num_samples + ce_norm = _compute_cross_entropy_norm(mean_label, pos_labels, neg_labels, eta) + return ce_positive_sum / ce_norm + + +def get_ne_positive_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor, eta: float +) -> Dict[str, torch.Tensor]: + cross_entropy_positive = compute_cross_entropy_positive( + labels, + predictions, + weights, + eta, + ) + return { + "cross_entropy_positive_sum": torch.sum(cross_entropy_positive, dim=-1), + "weighted_num_samples": torch.sum(weights, dim=-1), + "pos_labels": torch.sum(weights * labels, dim=-1), + "neg_labels": torch.sum(weights * (1.0 - labels), dim=-1), + } + + +class NEPositiveMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for NE positive, i.e. Normalized Entropy where label = 1 + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. 
+ """ + + def __init__( + self, + *args: Any, + allow_missing_label_with_zero_weight: bool = False, + **kwargs: Any, + ) -> None: + self._allow_missing_label_with_zero_weight: bool = ( + allow_missing_label_with_zero_weight + ) + super().__init__(*args, **kwargs) + self._add_state( + "cross_entropy_positive_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_num_samples", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "pos_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "neg_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self.eta = 1e-12 + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for NEMetricComputation update" + ) + states = get_ne_positive_states(labels, predictions, weights, self.eta) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + reports = [ + MetricComputationReport( + name=MetricName.NE_POSITIVE, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_ne_positive( + cast(torch.Tensor, self.cross_entropy_positive_sum), + cast(torch.Tensor, self.weighted_num_samples), + cast(torch.Tensor, self.pos_labels), + cast(torch.Tensor, self.neg_labels), + self.eta, + self._allow_missing_label_with_zero_weight, + ), + ), + MetricComputationReport( + name=MetricName.NE_POSITIVE, + metric_prefix=MetricPrefix.WINDOW, + value=compute_ne_positive( + self.get_window_state("cross_entropy_positive_sum"), + self.get_window_state("weighted_num_samples"), + self.get_window_state("pos_labels"), + self.get_window_state("neg_labels"), + self.eta, + self._allow_missing_label_with_zero_weight, + ), + ), + ] + return reports + + +class NEPositiveMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.NE_POSITIVE + _computation_class: Type[RecMetricComputation] = NEPositiveMetricComputation diff --git a/torchrec/metrics/ne_with_recalibration.py b/torchrec/metrics/ne_with_recalibration.py new file mode 100644 index 000000000..715ddd356 --- /dev/null +++ b/torchrec/metrics/ne_with_recalibration.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from typing import Any, Dict, Optional, Type + +import torch + +from torchrec.metrics.metrics_namespace import MetricNamespace +from torchrec.metrics.ne import get_ne_states, NEMetricComputation +from torchrec.metrics.rec_metric import ( + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +class RecalibratedNEMetricComputation(NEMetricComputation): + r""" + This class implements the recalibration for NE that is required to correctly estimate eval NE if negative downsampling was used during training. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + + Args: + include_logloss (bool): return vanilla logloss as one of metrics results, on top of NE. + """ + + def __init__( + self, + *args: Any, + include_logloss: bool = False, + allow_missing_label_with_zero_weight: bool = False, + recalibration_coefficient: float = 1.0, + **kwargs: Any, + ) -> None: + self._recalibration_coefficient: float = recalibration_coefficient + self._include_logloss: bool = include_logloss + self._allow_missing_label_with_zero_weight: bool = ( + allow_missing_label_with_zero_weight + ) + super().__init__(*args, **kwargs) + self._add_state( + "cross_entropy_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_num_samples", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "pos_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "neg_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self.eta = 1e-12 + + def _recalibrate( + self, + predictions: torch.Tensor, + calibration_coef: Optional[torch.Tensor], + ) -> torch.Tensor: + if calibration_coef is not None: + predictions = predictions / ( + predictions + (1.0 - predictions) / calibration_coef + ) + return predictions + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for RecalibratedNEMetricComputation update" + ) + + predictions = self._recalibrate( + predictions, self._recalibration_coefficient * torch.ones_like(predictions) + ) + states = get_ne_states(labels, predictions, weights, self.eta) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + +class RecalibratedNEMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.RECALIBRATED_NE + _computation_class: Type[RecMetricComputation] = RecalibratedNEMetricComputation diff --git a/torchrec/metrics/output.py b/torchrec/metrics/output.py new file mode 100644 index 000000000..6ca0f149f --- /dev/null +++ b/torchrec/metrics/output.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
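For reference, a standalone sketch (not part of the diff) of the recalibration transform used by _recalibrate above: a prediction p from a model trained with negative downsampling coefficient c is mapped back to the full-distribution scale via p / (p + (1 - p) / c). The coefficient value below is illustrative.

import torch

c = 0.1  # hypothetical negative downsampling coefficient
p = torch.tensor([0.5, 0.9])
recalibrated = p / (p + (1.0 - p) / c)  # tensor([0.0909..., 0.4736...])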
+ +# pyre-strict + +from typing import Any, Dict, List, Optional, Type + +import torch +from torch import distributed as dist + +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecComputeMode, + RecMetric, + RecMetricComputation, + RecMetricException, + RecTaskInfo, +) + + +class OutputMetricComputation(RecMetricComputation): + """ + Metric that logs whatever model outputs are given in kwargs + TODO - make this generic metric that can be used for any model output tensor + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._add_state( + "latest_imp", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=False, + dist_reduce_fx="sum", + persistent=False, + ) + self._add_state( + "total_latest_imp", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=False, + dist_reduce_fx="sum", + persistent=False, + ) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + required_list = ["latest_imp", "total_latest_imp"] + if "required_inputs" not in kwargs or not all( + item in kwargs["required_inputs"] for item in required_list + ): + raise RecMetricException( + "OutputMetricComputation requires 'latest_imp' and 'total_latest_imp' in kwargs" + ) + states = { + "latest_imp": kwargs["required_inputs"]["latest_imp"] + .float() + .mean(dim=-1, dtype=torch.double), + "total_latest_imp": kwargs["required_inputs"]["total_latest_imp"] + .float() + .mean(dim=-1, dtype=torch.double), + } + + for state_name, state_value in states.items(): + setattr(self, state_name, state_value) + + def _compute(self) -> List[MetricComputationReport]: + return [ + MetricComputationReport( + name=MetricName.OUTPUT, + metric_prefix=MetricPrefix.DEFAULT, + # pyre-fixme[6]: For 3rd argument expected `Tensor` but got + # `Union[Tensor, Module]`. + value=self.latest_imp, + description="_latest_imp", + ), + MetricComputationReport( + name=MetricName.OUTPUT, + metric_prefix=MetricPrefix.DEFAULT, + # pyre-fixme[6]: For 3rd argument expected `Tensor` but got + # `Union[Tensor, Module]`. 
+ value=self.total_latest_imp, + description="_total_latest_imp", + ), + ] + + +class OutputMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.OUTPUT + _computation_class: Type[RecMetricComputation] = OutputMetricComputation + + def __init__( + self, + world_size: int, + my_rank: int, + batch_size: int, + tasks: List[RecTaskInfo], + compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, + window_size: int = 100, + fused_update_limit: int = 0, + compute_on_all_ranks: bool = False, + should_validate_update: bool = False, + process_group: Optional[dist.ProcessGroup] = None, + **kwargs: Dict[str, Any], + ) -> None: + super().__init__( + world_size=world_size, + my_rank=my_rank, + batch_size=batch_size, + tasks=tasks, + compute_mode=compute_mode, + window_size=window_size, + fused_update_limit=fused_update_limit, + compute_on_all_ranks=compute_on_all_ranks, + should_validate_update=should_validate_update, + process_group=process_group, + **kwargs, + ) + self._required_inputs.add("latest_imp") + self._required_inputs.add("total_latest_imp") diff --git a/torchrec/metrics/precision.py b/torchrec/metrics/precision.py new file mode 100644 index 000000000..c069077bb --- /dev/null +++ b/torchrec/metrics/precision.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +THRESHOLD = "threshold" + + +def compute_precision( + num_true_positives: torch.Tensor, num_false_positives: torch.Tensor +) -> torch.Tensor: + return torch.where( + num_true_positives + num_false_positives == 0.0, + 0.0, + num_true_positives / (num_true_positives + num_false_positives).double(), + ) + + +def compute_true_pos_sum( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + threshold: float = 0.5, +) -> torch.Tensor: + predictions = predictions.double() + return torch.sum(weights * ((predictions >= threshold) * labels), dim=-1) + + +def compute_false_pos_sum( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + threshold: float = 0.5, +) -> torch.Tensor: + predictions = predictions.double() + return torch.sum(weights * ((predictions >= threshold) * (1 - labels)), dim=-1) + + +def get_precision_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: Optional[torch.Tensor], + threshold: float = 0.5, +) -> Dict[str, torch.Tensor]: + if weights is None: + weights = torch.ones_like(predictions) + return { + "true_pos_sum": compute_true_pos_sum(labels, predictions, weights, threshold), + "false_pos_sum": compute_false_pos_sum(labels, predictions, weights, threshold), + } + + +class PrecisionMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for Precision. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + + Args: + threshold (float): If provided, computes Precision metrics cutting off at + the specified threshold. 
+ """ + + def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._add_state( + "true_pos_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "false_pos_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._threshold: float = threshold + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None: + raise RecMetricException( + "Inputs 'predictions' should not be None for PrecisionMetricComputation update" + ) + states = get_precision_states(labels, predictions, weights, self._threshold) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + reports = [ + MetricComputationReport( + name=MetricName.PRECISION, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_precision( + cast(torch.Tensor, self.true_pos_sum), + cast(torch.Tensor, self.false_pos_sum), + ), + ), + MetricComputationReport( + name=MetricName.PRECISION, + metric_prefix=MetricPrefix.WINDOW, + value=compute_precision( + self.get_window_state("true_pos_sum"), + self.get_window_state("false_pos_sum"), + ), + ), + ] + return reports + + +class PrecisionMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.PRECISION + _computation_class: Type[RecMetricComputation] = PrecisionMetricComputation diff --git a/torchrec/metrics/precision_session.py b/torchrec/metrics/precision_session.py new file mode 100644 index 000000000..daa4864fc --- /dev/null +++ b/torchrec/metrics/precision_session.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import logging +from typing import Any, cast, Dict, List, Optional, Set, Type, Union + +import torch +from torch import distributed as dist +from torchrec.metrics.metrics_config import RecTaskInfo, SessionMetricDef +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecComputeMode, + RecMetric, + RecMetricComputation, + RecMetricException, +) +from torchrec.metrics.recall_session import ( + _calc_num_true_pos, + _validate_model_outputs, + ranking_within_session, +) + +logger: logging.Logger = logging.getLogger(__name__) + +NUM_TRUE_POS = "num_true_pos" +NUM_FALSE_POS = "num_false_pos" + + +def _calc_num_false_pos( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor +) -> torch.Tensor: + # predictions are expected to be 0 or 1 integers. + num_false_pos = torch.sum( + weights * (1 - labels) * (predictions == 1).double(), dim=-1 + ) + return num_false_pos + + +def _calc_precision( + num_true_pos: torch.Tensor, num_false_pos: torch.Tensor +) -> torch.Tensor: + # if num_true_pos + num_false_pos == 0 then we set precision = NaN by default. 
+ precision = torch.tensor([float("nan")]) + if (num_true_pos + num_false_pos).item() != 0: + precision = num_true_pos / (num_true_pos + num_false_pos) + else: + logger.warning( + "precision = NaN. Likely, it means that there were no positive predictions passed to the metric yet." + " Please, debug if you expect every batch to include positive predictions." + ) + return precision + + +class PrecisionSessionMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for precision on session level. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + + """ + + def __init__( + self, + *args: Any, + session_metric_def: SessionMetricDef, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._add_state( + NUM_TRUE_POS, + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + NUM_FALSE_POS, + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self.top_threshold: Optional[int] = session_metric_def.top_threshold + self.run_ranking_of_labels: bool = session_metric_def.run_ranking_of_labels + self.session_var_name: Optional[str] = session_metric_def.session_var_name + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + """ + Args: + predictions (torch.Tensor): tensor of size (n_task, n_examples) + labels (torch.Tensor): tensor of size (n_task, n_examples) + weights (torch.Tensor): tensor of size (n_task, n_examples) + session (torch.Tensor): Optional tensor of size (n_task, n_examples) that specifies the groups of + predictions/labels per batch. 
+ """ + + if ( + "required_inputs" not in kwargs + or self.session_var_name not in kwargs["required_inputs"] + ): + raise RecMetricException( + "Need the {} input to update the session metric".format( + self.session_var_name + ) + ) + # pyre-ignore + session = kwargs["required_inputs"][self.session_var_name] + if predictions is None or weights is None or session is None: + raise RecMetricException( + "Inputs 'predictions', 'weights' and 'session' should not be None for PrecisionSessionMetricComputation update" + ) + _validate_model_outputs(labels, predictions, weights, session) + + predictions = predictions.double() + labels = labels.double() + weights = weights.double() + + num_samples = predictions.shape[-1] + for state_name, state_value in self.get_precision_states( + labels=labels, predictions=predictions, weights=weights, session=session + ).items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + return [ + MetricComputationReport( + name=MetricName.PRECISION_SESSION_LEVEL, + metric_prefix=MetricPrefix.LIFETIME, + value=_calc_precision( + num_true_pos=cast(torch.Tensor, getattr(self, NUM_TRUE_POS)), + num_false_pos=cast(torch.Tensor, getattr(self, NUM_FALSE_POS)), + ), + ), + MetricComputationReport( + name=MetricName.PRECISION_SESSION_LEVEL, + metric_prefix=MetricPrefix.WINDOW, + value=_calc_precision( + num_true_pos=self.get_window_state(NUM_TRUE_POS), + num_false_pos=self.get_window_state(NUM_FALSE_POS), + ), + ), + ] + + def get_precision_states( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + session: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + predictions_ranked = ranking_within_session(predictions, session) + # pyre-fixme[58]: `<` is not supported for operand types `Tensor` and + # `Optional[int]`. + predictions_labels = (predictions_ranked < self.top_threshold).to(torch.int32) + if self.run_ranking_of_labels: + labels_ranked = ranking_within_session(labels, session) + # pyre-fixme[58]: `<` is not supported for operand types `Tensor` and + # `Optional[int]`. 
+ labels = (labels_ranked < self.top_threshold).to(torch.int32) + num_true_pos = _calc_num_true_pos(labels, predictions_labels, weights) + num_false_pos = _calc_num_false_pos(labels, predictions_labels, weights) + + return {NUM_TRUE_POS: num_true_pos, NUM_FALSE_POS: num_false_pos} + + +class PrecisionSessionMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.PRECISION_SESSION_LEVEL + _computation_class: Type[RecMetricComputation] = PrecisionSessionMetricComputation + + def __init__( + self, + world_size: int, + my_rank: int, + batch_size: int, + tasks: List[RecTaskInfo], + compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, + window_size: int = 100, + fused_update_limit: int = 0, + process_group: Optional[dist.ProcessGroup] = None, + **kwargs: Any, + ) -> None: + if compute_mode in [ + RecComputeMode.FUSED_TASKS_COMPUTATION, + RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + ]: + raise RecMetricException( + "Fused computation is not supported for precision session-level metrics" + ) + + if fused_update_limit > 0: + raise RecMetricException( + "Fused update is not supported for precision session-level metrics" + ) + for task in tasks: + if task.session_metric_def is None: + raise RecMetricException( + "Please, specify the session metric definition" + ) + session_metric_def = task.session_metric_def + if session_metric_def.top_threshold is None: + raise RecMetricException("Please, specify the top threshold") + + super().__init__( + world_size=world_size, + my_rank=my_rank, + batch_size=batch_size, + tasks=tasks, + compute_mode=compute_mode, + window_size=window_size, + fused_update_limit=fused_update_limit, + process_group=process_group, + **kwargs, + ) + + def _get_task_kwargs( + self, task_config: Union[RecTaskInfo, List[RecTaskInfo]] + ) -> Dict[str, Any]: + if isinstance(task_config, list): + raise RecMetricException("Session metric can only take one task at a time") + + if task_config.session_metric_def is None: + raise RecMetricException("Please, specify the session metric definition") + + return {"session_metric_def": task_config.session_metric_def} + + def _get_task_required_inputs( + self, task_config: Union[RecTaskInfo, List[RecTaskInfo]] + ) -> Set[str]: + if isinstance(task_config, list): + raise RecMetricException("Session metric can only take one task at a time") + + if task_config.session_metric_def is None: + raise RecMetricException("Please, specify the session metric definition") + + return ( + {task_config.session_metric_def.session_var_name} + if task_config.session_metric_def.session_var_name + else set() + ) diff --git a/torchrec/metrics/rauc.py b/torchrec/metrics/rauc.py new file mode 100644 index 000000000..cf28ed19a --- /dev/null +++ b/torchrec/metrics/rauc.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
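To make the thresholded counting above concrete, here is a minimal standalone sketch (not part of the diff) of weighted precision at a 0.5 cutoff, mirroring compute_true_pos_sum / compute_false_pos_sum from precision.py earlier in this patch; all values are illustrative.

import torch

labels = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
predictions = torch.tensor([[0.8, 0.6, 0.4, 0.2]])
weights = torch.ones_like(labels)
threshold = 0.5

predicted_pos = (predictions >= threshold).double()
true_pos = (weights * predicted_pos * labels).sum(dim=-1)           # 1.0
false_pos = (weights * predicted_pos * (1.0 - labels)).sum(dim=-1)  # 1.0
precision = true_pos / (true_pos + false_pos)                       # tensor([0.5000])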
+ +# pyre-strict + +import logging +from functools import partial +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type + +import torch +import torch.distributed as dist +from torchmetrics.utilities.distributed import gather_all_tensors +from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +logger: logging.Logger = logging.getLogger(__name__) + +PREDICTIONS = "predictions" +LABELS = "labels" +WEIGHTS = "weights" +GROUPING_KEYS = "grouping_keys" +REQUIRED_INPUTS = "required_inputs" + + +def _concat_if_needed( + predictions: List[torch.Tensor], + labels: List[torch.Tensor], + weights: List[torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + This check exists because of how the state is organized due to quirks in RecMetrics. + Since we do not do tensor concatenatation in the compute or update call, there are cases (in non-distributed settings) + where the tensors from updates are not concatted into a single tensor. Which is determined by the length of the list. + """ + preds_t, labels_t, weights_t = None, None, None + if len(predictions) > 1: + preds_t = torch.cat(predictions, dim=-1) + labels_t = torch.cat(labels, dim=-1) + weights_t = torch.cat(weights, dim=-1) + else: + preds_t = predictions[0] + labels_t = labels[0] + weights_t = weights[0] + + return preds_t, labels_t, weights_t + + +def count_reverse_pairs_divide_and_conquer(input: List[float]) -> float: + + n = len(input) + total_inversions = divide(input, 0, len(input) - 1) + + return total_inversions / (n * (n - 1) / 2) + + +def divide(input: List[float], low: int, high: int) -> int: + if low >= high: + return 0 + + mid = low + (high - low) // 2 + + left_inversions = divide(input, low, mid) + right_inversions = divide(input, mid + 1, high) + merge_inversions = conquer_and_count(input, low, mid, high) + + return left_inversions + right_inversions + merge_inversions + + +def conquer_and_count( + input: List[float], left_index: int, mid_index: int, right_index: int +) -> int: + left = input[left_index : mid_index + 1] + right = input[mid_index + 1 : right_index + 1] + + i, j, k, inversions = 0, 0, left_index, 0 + + while i < len(left) and j < len(right): + if left[i] <= right[j]: + input[k] = left[i] + i += 1 + else: + input[k] = right[j] + j += 1 + count = (mid_index + 1) - (left_index + i) + inversions += count + k += 1 + + while i < len(left): + input[k] = left[i] + i += 1 + k += 1 + + while j < len(right): + input[k] = right[j] + j += 1 + k += 1 + + return inversions + + +def _compute_rauc_helper( + predictions: torch.Tensor, + labels: torch.Tensor, + weights: torch.Tensor, +) -> torch.Tensor: + + array = [ + x + for x, _ in sorted( + zip(labels.tolist(), predictions.tolist()), key=lambda x: (x[1], x[0]) + ) + ] + + return torch.tensor(1 - count_reverse_pairs_divide_and_conquer(array)) + + +def compute_rauc( + n_tasks: int, + predictions: List[torch.Tensor], + labels: List[torch.Tensor], + weights: List[torch.Tensor], +) -> torch.Tensor: + """ + Computes RAUC (Regression AUC) for regression tasks. + + Args: + predictions (List[torch.Tensor]): tensor of size (n_tasks, n_examples). + labels (List[torch.Tensor]): tensor of size (n_tasks, n_examples). + weights (List[torch.Tensor]): tensor of size (n_tasks, n_examples). 
+ """ + + preds_t, labels_t, weights_t = _concat_if_needed(predictions, labels, weights) + raucs = [] + for predictions_i, labels_i, weights_i in zip(preds_t, labels_t, weights_t): + rauc = _compute_rauc_helper(predictions_i, labels_i, weights_i) + raucs.append(rauc.view(1)) + return torch.cat(raucs) + + +def compute_rauc_per_group( + n_tasks: int, + predictions: List[torch.Tensor], + labels: List[torch.Tensor], + weights: List[torch.Tensor], + grouping_keys: torch.Tensor, +) -> torch.Tensor: + """ + Computes RAUC (Regression AUC) for regression tasks for groups of predictions/labels. + Args: + n_tasks (int): number of tasks + predictions (List[torch.Tensor]): tensor of size (n_tasks, n_examples) + labels (List[torch.Tensor]: tensor of size (n_tasks, n_examples) + weights (List[torch.Tensor]): tensor of size (n_tasks, n_examples) + grouping_keys (torch.Tensor): tensor of size (n_examples,) + + Returns: + torch.Tensor: tensor of size (n_tasks,), average of RAUCs per group. + """ + preds_t, labels_t, weights_t = _concat_if_needed(predictions, labels, weights) + raucs = [] + if grouping_keys.numel() != 0 and grouping_keys[0] == -1: + # we added padding as the first elements during init to avoid floating point exception in sync() + # removing the paddings to avoid numerical errors. + grouping_keys = grouping_keys[1:] + + # get unique group indices + group_indices = torch.unique(grouping_keys) + + for predictions_i, labels_i, weights_i in zip(preds_t, labels_t, weights_t): + # Loop over each group + rauc_groups_sum = torch.tensor([0], dtype=torch.float32) + for group_idx in group_indices: + # get predictions, labels, and weights for this group + group_mask = grouping_keys == group_idx + grouped_predictions = predictions_i[group_mask] + grouped_labels = labels_i[group_mask] + grouped_weights = weights_i[group_mask] + + rauc = _compute_rauc_helper( + grouped_predictions, grouped_labels, grouped_weights + ) + rauc_groups_sum = rauc_groups_sum.to(rauc.device) + rauc_groups_sum += rauc.view(1) + avg_rauc = ( + rauc_groups_sum / len(group_indices) + if len(group_indices) > 0 + else torch.tensor([0.5], dtype=torch.float32) + ) + raucs.append(avg_rauc) + return torch.cat(raucs) + + +def _state_reduction(state: List[torch.Tensor], dim: int = 1) -> List[torch.Tensor]: + return [torch.cat(state, dim=dim)] + + +# pyre-ignore +_grouping_keys_state_reduction = partial(_state_reduction, dim=0) + + +class RAUCMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for RAUC, i.e. Regression AUC. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + Args: + grouped_rauc (bool): If True, computes RAUC per group and returns average RAUC across all groups. + The `grouping_keys` is provided during state updates along with predictions, labels, weights. + This feature is currently not enabled for `fused_update_limit`. + """ + + def __init__( + self, + *args: Any, + grouped_rauc: bool = False, + fused_update_limit: int = 0, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if grouped_rauc and fused_update_limit > 0: + raise RecMetricException( + "Grouped RAUC and Fused Update Limit cannot be enabled together yet." 
+ ) + + self._grouped_rauc: bool = grouped_rauc + self._num_samples: int = 0 + self._add_state( + PREDICTIONS, + [], + add_window_state=False, + dist_reduce_fx=_state_reduction, + persistent=False, + ) + self._add_state( + LABELS, + [], + add_window_state=False, + dist_reduce_fx=_state_reduction, + persistent=False, + ) + self._add_state( + WEIGHTS, + [], + add_window_state=False, + dist_reduce_fx=_state_reduction, + persistent=False, + ) + if self._grouped_rauc: + self._add_state( + GROUPING_KEYS, + [], + add_window_state=False, + dist_reduce_fx=_grouping_keys_state_reduction, + persistent=False, + ) + self._init_states() + + # The states values are set to empty lists in __init__() and reset(), and then we + # add a size (self._n_tasks, 1) tensor to each of the list as the initial values + # This is to bypass the limitation of state aggregation in TorchMetrics sync() when + # we try to checkpoint the states before update() + # The reason for using lists here is to avoid automatically stacking the tensors from + # all the trainers into one tensor in sync() + # The reason for using non-empty tensors as the first elements is to avoid the + # floating point exception thrown in sync() for aggregating empty tensors + def _init_states(self) -> None: + if len(getattr(self, PREDICTIONS)) > 0: + return + self._num_samples = 0 + getattr(self, PREDICTIONS).append( + torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device) + ) + getattr(self, LABELS).append( + torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device) + ) + getattr(self, WEIGHTS).append( + torch.zeros((self._n_tasks, 1), dtype=torch.float, device=self.device) + ) + if self._grouped_rauc: + getattr(self, GROUPING_KEYS).append(torch.tensor([-1], device=self.device)) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + """ + Args: + predictions (torch.Tensor): tensor of size (n_task, n_examples) + labels (torch.Tensor): tensor of size (n_task, n_examples) + weights (torch.Tensor): tensor of size (n_task, n_examples) + grouping_key (torch.Tensor): Optional tensor of size (1, n_examples) that specifies the groups of + predictions/labels per batch. If provided, the RAUC metric also + computes RAUC per group and returns the average RAUC across all groups. + """ + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for RAUCMetricComputation update" + ) + predictions = predictions.float() + labels = labels.float() + weights = weights.float() + batch_size = predictions.size(-1) + start_index = max(self._num_samples + batch_size - self._window_size, 0) + + # Using `self.predictions =` will cause Pyre errors. 
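+ # The windowed state is a list of per-update tensors: new batches are appended
+ # below, and the loop that follows pops or truncates the oldest entries until
+ # the number of stored examples fits within self._window_size.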
+ w_preds = getattr(self, PREDICTIONS) + w_labels = getattr(self, LABELS) + w_weights = getattr(self, WEIGHTS) + + # remove init states + if self._num_samples == 0: + for lst in [w_preds, w_labels, w_weights]: + lst.pop(0) + + w_preds.append(predictions) + w_labels.append(labels) + w_weights.append(weights) + + self._num_samples += batch_size + + while self._num_samples > self._window_size: + diff = self._num_samples - self._window_size + if diff > w_preds[0].size(-1): + self._num_samples -= w_preds[0].size(-1) + # Remove the first element from predictions, labels, and weights + for lst in [w_preds, w_labels, w_weights]: + lst.pop(0) + else: + # Update the first element of predictions, labels, and weights + # Off by one potentially - keeping legacy behaviour + for lst in [w_preds, w_labels, w_weights]: + lst[0] = lst[0][:, diff:] + # if empty tensor, remove it + if torch.numel(lst[0]) == 0: + lst.pop(0) + self._num_samples -= diff + + if self._grouped_rauc: + if REQUIRED_INPUTS not in kwargs or ( + (grouping_keys := kwargs[REQUIRED_INPUTS].get(GROUPING_KEYS)) is None + ): + raise RecMetricException( + f"Input '{GROUPING_KEYS}' are required for RAUCMetricComputation grouped update" + ) + getattr(self, GROUPING_KEYS)[0] = torch.cat( + [ + cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0])[start_index:], + grouping_keys.squeeze(), + ], + dim=0, + ) + + def _compute(self) -> List[MetricComputationReport]: + reports = [] + reports.append( + MetricComputationReport( + name=MetricName.RAUC, + metric_prefix=MetricPrefix.WINDOW, + value=compute_rauc( + self._n_tasks, + cast(List[torch.Tensor], getattr(self, PREDICTIONS)), + cast(List[torch.Tensor], getattr(self, LABELS)), + cast(List[torch.Tensor], getattr(self, WEIGHTS)), + ), + ) + ) + + if self._grouped_rauc: + reports.append( + MetricComputationReport( + name=MetricName.GROUPED_RAUC, + metric_prefix=MetricPrefix.WINDOW, + value=compute_rauc_per_group( + self._n_tasks, + cast(List[torch.Tensor], getattr(self, PREDICTIONS)), + cast(List[torch.Tensor], getattr(self, LABELS)), + cast(List[torch.Tensor], getattr(self, WEIGHTS)), + cast(torch.Tensor, getattr(self, GROUPING_KEYS))[0], + ), + ) + ) + return reports + + def _sync_dist( + self, + dist_sync_fn: Callable = gather_all_tensors, # pyre-ignore[24] + process_group: Optional[Any] = None, # pyre-ignore[2] + ) -> None: + """ + This function is overridden from torchmetric.Metric, since for RAUC we want to concat the tensors + right before the allgather collective is called. 
It directly changes the attributes/states, which + is ok because end of function sets the attributes to reduced values + """ + for attr in self._reductions: # pragma: no cover + val = getattr(self, attr) + if isinstance(val, list) and len(val) > 1: + setattr(self, attr, [torch.cat(val, dim=-1)]) + super()._sync_dist(dist_sync_fn, process_group) + + def reset(self) -> None: + super().reset() + self._init_states() + + +class RAUCMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.RAUC + _computation_class: Type[RecMetricComputation] = RAUCMetricComputation + + def __init__( + self, + world_size: int, + my_rank: int, + batch_size: int, + tasks: List[RecTaskInfo], + compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, + window_size: int = 100, + fused_update_limit: int = 0, + compute_on_all_ranks: bool = False, + should_validate_update: bool = False, + process_group: Optional[dist.ProcessGroup] = None, + **kwargs: Dict[str, Any], + ) -> None: + super().__init__( + world_size=world_size, + my_rank=my_rank, + batch_size=batch_size, + tasks=tasks, + compute_mode=compute_mode, + window_size=window_size, + fused_update_limit=fused_update_limit, + compute_on_all_ranks=compute_on_all_ranks, + should_validate_update=should_validate_update, + process_group=process_group, + **kwargs, + ) + if kwargs.get("grouped_rauc"): + self._required_inputs.add(GROUPING_KEYS) + if self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION: + logging.warning( + f"compute_mode FUSED_TASKS_AND_STATES_COMPUTATION can't support {self._namespace} yet " + "because its states are not 1D Tensors. Only FUSED_TASKS_COMPUTATION will take effect." + ) diff --git a/torchrec/metrics/rec_metric.py b/torchrec/metrics/rec_metric.py index f12d3e8c3..53fbfa3b5 100644 --- a/torchrec/metrics/rec_metric.py +++ b/torchrec/metrics/rec_metric.py @@ -5,9 +5,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + #!/usr/bin/env python3 import abc +import inspect import itertools import math from collections import defaultdict, deque @@ -24,6 +27,7 @@ Mapping, Optional, Sequence, + Set, Tuple, Type, TypeVar, @@ -33,7 +37,9 @@ import torch import torch.distributed as dist import torch.nn as nn +from torch.profiler import record_function from torchmetrics import Metric +from torchrec.distributed.types import get_tensor_size_bytes from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo from torchrec.metrics.metrics_namespace import ( compose_metric_key, @@ -41,6 +47,7 @@ MetricNamespaceBase, MetricPrefix, ) +from torchrec.pt2.utils import pt2_compile_callable RecModelOutput = Union[torch.Tensor, Dict[str, torch.Tensor]] @@ -51,11 +58,12 @@ class MetricComputationReport: name: MetricNameBase metric_prefix: MetricPrefix value: torch.Tensor + description: Optional[str] = None DefaultValueT = TypeVar("DefaultValueT") ComputeIterType = Iterator[ - Tuple[RecTaskInfo, MetricNameBase, torch.Tensor, MetricPrefix] + Tuple[RecTaskInfo, MetricNameBase, torch.Tensor, MetricPrefix, str] ] MAX_BUFFER_COUNT = 1000 @@ -116,6 +124,7 @@ class RecMetricComputation(Metric, abc.ABC): process_group (Optional[ProcessGroup]): the process group used for the communication. Will use the default process group if not specified. 
""" + _batch_window_buffers: Optional[Dict[str, WindowBuffer]] def __init__( @@ -126,11 +135,21 @@ def __init__( window_size: int, compute_on_all_ranks: bool = False, should_validate_update: bool = False, + fuse_state_tensors: bool = False, process_group: Optional[dist.ProcessGroup] = None, + fused_update_limit: int = 0, + allow_missing_label_with_zero_weight: bool = False, *args: Any, **kwargs: Any, ) -> None: - super().__init__(process_group=process_group, *args, **kwargs) + metric_init_signature = inspect.signature(Metric.__init__) + if "fuse_state_tensors" in metric_init_signature.parameters: + kwargs["fuse_state_tensors"] = fuse_state_tensors + super().__init__( + process_group=process_group, + *args, + **kwargs, + ) self._my_rank = my_rank self._n_tasks = n_tasks @@ -191,6 +210,7 @@ def _add_state( # Avoid pyre error assert isinstance(default, torch.Tensor) super().add_state(window_state_name, default.detach().clone(), **kwargs) + self._batch_window_buffers[window_state_name] = WindowBuffer( max_size=self._window_size, max_buffer_count=MAX_BUFFER_COUNT, @@ -217,6 +237,7 @@ def update( predictions: Optional[torch.Tensor], labels: torch.Tensor, weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], ) -> None: # pragma: no cover pass @@ -234,14 +255,26 @@ def pre_compute(self) -> None: return def compute(self) -> List[MetricComputationReport]: - if self._my_rank == 0 or self._compute_on_all_ranks: - return self._compute() - else: - return [] + with record_function(f"## {self.__class__.__name__}:compute ##"): + if self._my_rank == 0 or self._compute_on_all_ranks: + return self._compute() + else: + return [] def local_compute(self) -> List[MetricComputationReport]: return self._compute() + def reset(self) -> None: + super().reset() + if self._batch_window_buffers is not None: + self._batch_window_buffers = { + name: WindowBuffer( + max_size=self._window_size, + max_buffer_count=MAX_BUFFER_COUNT, + ) + for name in self._batch_window_buffers + } + class RecMetric(nn.Module, abc.ABC): r"""The main class template to implement a recommendation metric. @@ -280,6 +313,7 @@ class RecMetric(nn.Module, abc.ABC): tasks=DefaultTaskInfo, ) """ + _computation_class: Type[RecMetricComputation] _namespace: MetricNamespaceBase _metrics_computations: nn.ModuleList @@ -290,6 +324,8 @@ class RecMetric(nn.Module, abc.ABC): _update_buffers: Dict[str, List[RecModelOutput]] _default_weights: Dict[Tuple[int, ...], torch.Tensor] + _required_inputs: Set[str] + PREDICTIONS: str = "predictions" LABELS: str = "labels" WEIGHTS: str = "weights" @@ -306,12 +342,19 @@ def __init__( compute_on_all_ranks: bool = False, should_validate_update: bool = False, process_group: Optional[dist.ProcessGroup] = None, - **kwargs: Any, + **kwargs: Dict[str, Any], ) -> None: + torch._C._log_api_usage_once( + f"torchrec.metrics.rec_metric.{self.__class__.__name__}" + ) # TODO(stellaya): consider to inherit from TorchMetrics.Metric or # TorchMetrics.MetricCollection. 
if ( - compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION + compute_mode + in [ + RecComputeMode.FUSED_TASKS_COMPUTATION, + RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + ] and fused_update_limit > 0 ): raise ValueError( @@ -322,43 +365,81 @@ def __init__( self._my_rank = my_rank self._window_size = math.ceil(window_size / world_size) self._batch_size = batch_size + self._metrics_computations = nn.ModuleList() self._tasks = tasks self._compute_mode = compute_mode self._fused_update_limit = fused_update_limit self._should_validate_update = should_validate_update self._default_weights = {} + self._required_inputs = set() self._update_buffers = { self.PREDICTIONS: [], self.LABELS: [], self.WEIGHTS: [], } - if compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION: - n_metrics = 1 + # pyre-fixme[8]: Attribute has type `bool`; used as `Union[bool, + # Dict[str, Any]]`. + self.enable_pt2_compile: bool = kwargs.get("enable_pt2_compile", False) + # we need to remove the enable_pt2_compile from kwargs to avoid Metric object being initialized with it + if "enable_pt2_compile" in kwargs: + del kwargs["enable_pt2_compile"] + + if self._window_size < self._batch_size: + raise ValueError( + f"Local window size must be larger than batch size. Got local window size {self._window_size} and batch size {self._batch_size}." + ) + + if compute_mode in [ + RecComputeMode.FUSED_TASKS_COMPUTATION, + RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + ]: task_per_metric = len(self._tasks) self._tasks_iter = self._fused_tasks_iter else: - n_metrics = len(self._tasks) task_per_metric = 1 self._tasks_iter = self._unfused_tasks_iter - self._metrics_computations: nn.ModuleList = nn.ModuleList( - [ - # This Pyre error seems to be Pyre's bug as it can be inferred by mypy - # according to https://github.com/python/mypy/issues/3048. - # pyre-fixme[45]: Cannot instantiate abstract class `RecMetricCoputation`. - self._computation_class( - my_rank, - batch_size, - task_per_metric, - self._window_size, - compute_on_all_ranks, - self._should_validate_update, - process_group, - **kwargs, - ) - for _ in range(n_metrics) + for task_config in ( + [self._tasks] + if compute_mode + in [ + RecComputeMode.FUSED_TASKS_COMPUTATION, + RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, ] - ) + else self._tasks + ): + # pyre-ignore + kwargs["fused_update_limit"] = fused_update_limit + # This Pyre error seems to be Pyre's bug as it can be inferred by mypy + # according to https://github.com/python/mypy/issues/3048. + # pyre-fixme[45]: Cannot instantiate abstract class `RecMetricCoputation`. 
+ metric_computation = self._computation_class( + my_rank=my_rank, + batch_size=batch_size, + n_tasks=task_per_metric, + window_size=self._window_size, + compute_on_all_ranks=compute_on_all_ranks, + should_validate_update=self._should_validate_update, + fuse_state_tensors=( + compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION + ), + process_group=process_group, + **{**kwargs, **self._get_task_kwargs(task_config)}, + ) + required_inputs = self._get_task_required_inputs(task_config) + + self._metrics_computations.append(metric_computation) + self._required_inputs.update(required_inputs) + + def _get_task_kwargs( + self, task_config: Union[RecTaskInfo, List[RecTaskInfo]] + ) -> Dict[str, Any]: + return {} + + def _get_task_required_inputs( + self, task_config: Union[RecTaskInfo, List[RecTaskInfo]] + ) -> Set[str]: + return set() # TODO(stellaya): Refactor the _[fused, unfused]_tasks_iter methods and replace the # compute_scope str input with an enum @@ -371,10 +452,10 @@ def _fused_tasks_iter(self, compute_scope: str) -> ComputeIterType: for task, metric_value, has_valid_update in zip( self._tasks, metric_report.value, - self._metrics_computations[0].has_valid_update - if self._should_validate_update - else itertools.repeat( - 1 + ( + self._metrics_computations[0].has_valid_update + if self._should_validate_update + else itertools.repeat(1) ), # has_valid_update > 0 means the update is valid ): # The attribute has_valid_update is a tensor whose length equals to the @@ -387,7 +468,7 @@ def _fused_tasks_iter(self, compute_scope: str) -> ComputeIterType: if has_valid_update > 0 else torch.zeros_like(metric_value) ) - yield task, metric_report.name, valid_metric_value, compute_scope + metric_report.metric_prefix.value + yield task, metric_report.name, valid_metric_value, compute_scope + metric_report.metric_prefix.value, metric_report.description def _unfused_tasks_iter(self, compute_scope: str) -> ComputeIterType: for task, metric_computation in zip(self._tasks, self._metrics_computations): @@ -405,7 +486,7 @@ def _unfused_tasks_iter(self, compute_scope: str) -> ComputeIterType: or metric_computation.has_valid_update[0] > 0 else torch.zeros_like(metric_report.value) ) - yield task, metric_report.name, valid_metric_value, compute_scope + metric_report.metric_prefix.value + yield task, metric_report.name, valid_metric_value, compute_scope + metric_report.metric_prefix.value, metric_report.description def _fuse_update_buffers(self) -> Dict[str, RecModelOutput]: def fuse(outputs: List[RecModelOutput]) -> RecModelOutput: @@ -466,19 +547,46 @@ def _update( predictions: RecModelOutput, labels: RecModelOutput, weights: Optional[RecModelOutput], + **kwargs: Dict[str, Any], ) -> None: with torch.no_grad(): - if self._compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION: - assert isinstance(predictions, torch.Tensor) - # Reshape the predictions to size([len(self._tasks), self._batch_size]) - predictions = predictions.view(-1, self._batch_size) - assert isinstance(labels, torch.Tensor) - labels = labels.view(-1, self._batch_size) + if self._compute_mode in [ + RecComputeMode.FUSED_TASKS_COMPUTATION, + RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + ]: + task_names = [task.name for task in self._tasks] + + if not isinstance(predictions, torch.Tensor): + predictions = torch.stack( + [predictions[task_name] for task_name in task_names] + ) + + if not isinstance(labels, torch.Tensor): + labels = torch.stack( + [labels[task_name] for task_name in task_names] + ) + if weights is not None and 
not isinstance(weights, torch.Tensor): + weights = torch.stack( + [weights[task_name] for task_name in task_names] + ) + + assert isinstance(predictions, torch.Tensor) and isinstance( + labels, torch.Tensor + ) + + predictions = ( + # Reshape the predictions to size([len(self._tasks), self._batch_size]) + predictions.view(len(self._tasks), -1) + if predictions.dim() == labels.dim() + # predictions.dim() == labels.dim() + 1 for multiclass models + else predictions.view(len(self._tasks), -1, predictions.size()[-1]) + ) + labels = labels.view(len(self._tasks), -1) if weights is None: weights = self._create_default_weights(predictions) else: assert isinstance(weights, torch.Tensor) - weights = weights.view(-1, self._batch_size) + weights = weights.view(len(self._tasks), -1) if self._should_validate_update: # has_valid_weights is a tensor of bool whose length equals to the number # of tasks. Each value in it is corresponding to whether the weights @@ -488,29 +596,71 @@ def _update( has_valid_weights = self._check_nonempty_weights(weights) if torch.any(has_valid_weights): self._metrics_computations[0].update( - predictions=predictions, labels=labels, weights=weights + predictions=predictions, + labels=labels, + weights=weights, + **kwargs, ) self._metrics_computations[0].has_valid_update.logical_or_( has_valid_weights ) else: self._metrics_computations[0].update( - predictions=predictions, labels=labels, weights=weights + predictions=predictions, + labels=labels, + weights=weights, + **kwargs, ) else: for task, metric_ in zip(self._tasks, self._metrics_computations): if task.name not in predictions: continue + # pyre-fixme[6]: For 1st argument expected `Union[None, + # List[typing.Any], int, slice, Tensor, typing.Tuple[typing.Any, + # ...]]` but got `str`. if torch.numel(predictions[task.name]) == 0: + # pyre-fixme[6]: For 1st argument expected `Union[None, + # List[typing.Any], int, slice, Tensor, + # typing.Tuple[typing.Any, ...]]` but got `str`. assert torch.numel(labels[task.name]) == 0 + # pyre-fixme[6]: For 1st argument expected `Union[None, + # List[typing.Any], int, slice, Tensor, + # typing.Tuple[typing.Any, ...]]` but got `str`. assert weights is None or torch.numel(weights[task.name]) == 0 continue - # Reshape the predictions to size([1, self._batch_size]) - task_predictions = predictions[task.name].view(1, -1) + task_predictions = ( + # pyre-fixme[6]: For 1st argument expected `Union[None, + # List[typing.Any], int, slice, Tensor, + # typing.Tuple[typing.Any, ...]]` but got `str`. + predictions[task.name].view(1, -1) + # pyre-fixme[6]: For 1st argument expected `Union[None, + # List[typing.Any], int, slice, Tensor, + # typing.Tuple[typing.Any, ...]]` but got `str`. + if predictions[task.name].dim() == labels[task.name].dim() + # predictions[task.name].dim() == labels[task.name].dim() + 1 for multiclass models + # pyre-fixme[6]: For 1st argument expected `Union[None, + # List[typing.Any], int, slice, Tensor, + # typing.Tuple[typing.Any, ...]]` but got `str`. + else predictions[task.name].view( + 1, + -1, + predictions[ + task.name # pyre-fixme[6]: For 1st argument expected `Union[None, + # List[typing.Any], int, slice, Tensor, + # typing.Tuple[typing.Any, ...]]` but got `str`. + ].size()[-1], + ) + ) + # pyre-fixme[6]: For 1st argument expected `Union[None, + # List[typing.Any], int, slice, Tensor, typing.Tuple[typing.Any, + # ...]]` but got `str`. 
task_labels = labels[task.name].view(1, -1) if weights is None: task_weights = self._create_default_weights(task_predictions) else: + # pyre-fixme[6]: For 1st argument expected `Union[None, + # List[typing.Any], int, slice, Tensor, + # typing.Tuple[typing.Any, ...]]` but got `str`. task_weights = weights[task.name].view(1, -1) if self._should_validate_update: # has_valid_weights is a tensor with only 1 value corresponding to @@ -522,36 +672,55 @@ def _update( metric_.has_valid_update.logical_or_(has_valid_weights) else: continue + if "required_inputs" in kwargs: + # Expand scalars to match the shape of the predictions + kwargs["required_inputs"] = { + k: ( + v.view(task_labels.size()) + if v.numel() > 1 + else v.expand(task_labels.size()) + ) + for k, v in kwargs["required_inputs"].items() + } metric_.update( predictions=task_predictions, labels=task_labels, weights=task_weights, + **kwargs, ) + @pt2_compile_callable def update( self, *, predictions: RecModelOutput, labels: RecModelOutput, weights: Optional[RecModelOutput], + **kwargs: Dict[str, Any], ) -> None: - if self._fused_update_limit > 0: - self._update_buffers[self.PREDICTIONS].append(predictions) - self._update_buffers[self.LABELS].append(labels) - if weights is not None: - self._update_buffers[self.WEIGHTS].append(weights) - self._check_fused_update(force=False) - else: - self._update(predictions=predictions, labels=labels, weights=weights) + with record_function(f"## {self.__class__.__name__}:update ##"): + if self._fused_update_limit > 0: + self._update_buffers[self.PREDICTIONS].append(predictions) + self._update_buffers[self.LABELS].append(labels) + if weights is not None: + self._update_buffers[self.WEIGHTS].append(weights) + self._check_fused_update(force=False) + else: + self._update( + predictions=predictions, labels=labels, weights=weights, **kwargs + ) # The implementation of compute is very similar to local_compute, but compute overwrites # the abstract method compute in torchmetrics.Metric, which is wrapped by _wrap_compute + @pt2_compile_callable def compute(self) -> Dict[str, torch.Tensor]: self._check_fused_update(force=True) ret = {} - for task, metric_name, metric_value, prefix in self._tasks_iter(""): + for task, metric_name, metric_value, prefix, description in self._tasks_iter( + "" + ): metric_key = compose_metric_key( - self._namespace, task.name, metric_name, prefix + self._namespace, task.name, metric_name, prefix, description ) ret[metric_key] = metric_value return ret @@ -559,9 +728,11 @@ def compute(self) -> Dict[str, torch.Tensor]: def local_compute(self) -> Dict[str, torch.Tensor]: self._check_fused_update(force=True) ret = {} - for task, metric_name, metric_value, prefix in self._tasks_iter("local_"): + for task, metric_name, metric_value, prefix, description in self._tasks_iter( + "local_" + ): metric_key = compose_metric_key( - self._namespace, task.name, metric_name, prefix + self._namespace, task.name, metric_name, prefix, description ) ret[metric_key] = metric_value return ret @@ -588,9 +759,7 @@ def get_memory_usage(self) -> Dict[torch.Tensor, int]: while attributes_q: attribute = attributes_q.popleft() if isinstance(attribute, torch.Tensor): - tensor_map[attribute] = ( - attribute.size().numel() * attribute.element_size() - ) + tensor_map[attribute] = get_tensor_size_bytes(attribute) elif isinstance(attribute, WindowBuffer): attributes_q.extend(attribute.buffers) elif isinstance(attribute, Mapping): @@ -619,6 +788,9 @@ def state_dict( keep_vars=keep_vars, ) + def get_required_inputs(self) -> 
Set[str]: + return self._required_inputs + class RecMetricList(nn.Module): """ @@ -646,6 +818,7 @@ class RecMetricList(nn.Module): """ rec_metrics: nn.ModuleList + required_inputs: Optional[List[str]] def __init__(self, rec_metrics: List[RecMetric]) -> None: # TODO(stellaya): consider to inherit from TorchMetrics.MetricCollection. @@ -654,6 +827,14 @@ def __init__(self, rec_metrics: List[RecMetric]) -> None: super().__init__() self.rec_metrics = nn.ModuleList(rec_metrics) + self.required_inputs = ( + list( + set().union( + *[rec_metric.get_required_inputs() for rec_metric in rec_metrics] + ) + ) + or None + ) def __len__(self) -> int: return len(self.rec_metrics) @@ -661,15 +842,21 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> nn.Module: return self.rec_metrics[idx] + def get_required_inputs(self) -> Optional[List[str]]: + return self.required_inputs + def update( self, *, predictions: RecModelOutput, labels: RecModelOutput, weights: RecModelOutput, + **kwargs: Dict[str, Any], ) -> None: for metric in self.rec_metrics: - metric.update(predictions=predictions, labels=labels, weights=weights) + metric.update( + predictions=predictions, labels=labels, weights=weights, **kwargs + ) def compute(self) -> Dict[str, torch.Tensor]: ret = {} diff --git a/torchrec/metrics/recall.py b/torchrec/metrics/recall.py new file mode 100644 index 000000000..5031045c7 --- /dev/null +++ b/torchrec/metrics/recall.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +THRESHOLD = "threshold" + + +def compute_recall( + num_true_positives: torch.Tensor, num_false_negitives: torch.Tensor +) -> torch.Tensor: + return torch.where( + num_true_positives + num_false_negitives == 0.0, + 0.0, + num_true_positives / (num_true_positives + num_false_negitives), + ) + + +def compute_true_pos_sum( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + threshold: float = 0.5, +) -> torch.Tensor: + predictions = predictions.double() + return torch.sum(weights * ((predictions >= threshold) * labels), dim=-1) + + +def compute_false_neg_sum( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + threshold: float = 0.5, +) -> torch.Tensor: + predictions = predictions.double() + return torch.sum(weights * ((predictions <= threshold) * labels), dim=-1) + + +def get_recall_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: Optional[torch.Tensor], + threshold: float = 0.5, +) -> Dict[str, torch.Tensor]: + if weights is None: + weights = torch.ones_like(predictions) + return { + "true_pos_sum": compute_true_pos_sum(labels, predictions, weights, threshold), + "false_neg_sum": compute_false_neg_sum(labels, predictions, weights, threshold), + } + + +class RecallMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for Recall. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. 
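+    Recall is computed as true_pos_sum / (true_pos_sum + false_neg_sum), with
+    predictions binarized at the configured threshold; when both sums are zero
+    the metric reports 0. For example, 3 weighted true positives and 1 weighted
+    false negative yield a recall of 0.75.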
+ + Args: + threshold (float): If provided, computes Recall metrics cutting off at + the specified threshold. + """ + + def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._add_state( + "true_pos_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "false_neg_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._threshold: float = threshold + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None: + raise RecMetricException( + "Inputs 'predictions' should not be None for RecallMetricComputation update" + ) + states = get_recall_states(labels, predictions, weights, self._threshold) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + reports = [ + MetricComputationReport( + name=MetricName.RECALL, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_recall( + cast(torch.Tensor, self.true_pos_sum), + cast(torch.Tensor, self.false_neg_sum), + ), + ), + MetricComputationReport( + name=MetricName.RECALL, + metric_prefix=MetricPrefix.WINDOW, + value=compute_recall( + self.get_window_state("true_pos_sum"), + self.get_window_state("false_neg_sum"), + ), + ), + ] + return reports + + +class RecallMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.RECALL + _computation_class: Type[RecMetricComputation] = RecallMetricComputation diff --git a/torchrec/metrics/recall_session.py b/torchrec/metrics/recall_session.py new file mode 100644 index 000000000..3733e472d --- /dev/null +++ b/torchrec/metrics/recall_session.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
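+# Session-level recall ranks predictions within each session (see
+# ranking_within_session below), treats the top_threshold highest-ranked
+# examples in a session as positive predictions, and computes
+# recall = TP / (TP + FN) over those binarized predictions.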
+ +# pyre-strict + +import logging +from typing import Any, cast, Dict, List, Optional, Set, Type, Union + +import torch +from torch import distributed as dist +from torchrec.metrics.metrics_config import RecTaskInfo, SessionMetricDef +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecComputeMode, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +logger: logging.Logger = logging.getLogger(__name__) + +NUM_TRUE_POS = "num_true_pos" +NUM_FALSE_NEGATIVE = "num_false_neg" + + +def _validate_model_outputs( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + sessions: torch.Tensor, +) -> None: + # check if tensors are of the same shape + assert labels.dim() == 2 + assert labels.shape == predictions.shape + assert labels.shape == weights.shape + assert labels.shape == sessions.shape + + +def ranking_within_session( + predictions: torch.Tensor, + session: torch.Tensor, +) -> torch.Tensor: + # rank predictions that belong to the same session + + # Example: + # predictions = [1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8] + # sessions = [1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1] + # return = [0, 5, 3, 2, 1, 6, 4, 1, 0, 4, 3, 2] + n_tasks = predictions.size(0) + matching_session_id = session.view(-1, n_tasks) == session.view(n_tasks, -1) + predictions_relation = predictions.view(-1, n_tasks) >= predictions.view( + n_tasks, -1 + ) + relation_within_session = matching_session_id & predictions_relation + rank_within_session = torch.sum(matching_session_id, dim=-1) - torch.sum( + relation_within_session, dim=-1 + ) + return rank_within_session + + +def _calc_num_true_pos( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor +) -> torch.Tensor: + # predictions are expected to be 0 or 1 integers. + num_true_pos = torch.sum(weights * labels * (predictions == 1).double(), dim=-1) + return num_true_pos + + +def _calc_num_false_neg( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor +) -> torch.Tensor: + # predictions are expected to be 0 or 1 integers. + num_false_neg = torch.sum(weights * labels * (predictions == 0).double(), dim=-1) + return num_false_neg + + +def _calc_recall( + num_true_pos: torch.Tensor, num_false_neg: torch.Tensor +) -> torch.Tensor: + # if num_true_pos + num_false_neg == 0 then we set recall = NaN by default. + recall = torch.tensor([float("nan")]) + if (num_true_pos + num_false_neg).item() != 0: + recall = num_true_pos / (num_true_pos + num_false_neg) + else: + logger.warning( + "Recall = NaN. Likely, it means that there were no positive examples passed to the metric yet." + " Please, debug if you expect every batch to include positive examples." + ) + return recall + + +class RecallSessionMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for Recall on session level. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. 
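+    The per-example session ids must be passed through
+    kwargs["required_inputs"] under the name configured by
+    session_metric_def.session_var_name.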
+ + """ + + def __init__( + self, + *args: Any, + session_metric_def: SessionMetricDef, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._add_state( + NUM_TRUE_POS, + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + NUM_FALSE_NEGATIVE, + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self.top_threshold: Optional[int] = session_metric_def.top_threshold + self.run_ranking_of_labels: bool = session_metric_def.run_ranking_of_labels + self.session_var_name: Optional[str] = session_metric_def.session_var_name + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + """ + Args: + predictions (torch.Tensor): tensor of size (n_task, n_examples) + labels (torch.Tensor): tensor of size (n_task, n_examples) + weights (torch.Tensor): tensor of size (n_task, n_examples) + session (torch.Tensor): Optional tensor of size (n_task, n_examples) that specifies the groups of + predictions/labels per batch. + """ + + if ( + "required_inputs" not in kwargs + or self.session_var_name not in kwargs["required_inputs"] + ): + raise RecMetricException( + "Need the {} input to update the session metric".format( + self.session_var_name + ) + ) + # pyre-ignore + session = kwargs["required_inputs"][self.session_var_name] + if predictions is None or weights is None or session is None: + raise RecMetricException( + "Inputs 'predictions', 'weights' and 'session' should not be None for RecallSessionMetricComputation update" + ) + _validate_model_outputs(labels, predictions, weights, session) + + predictions = predictions.double() + labels = labels.double() + weights = weights.double() + + num_samples = predictions.shape[-1] + for state_name, state_value in self.get_recall_states( + labels=labels, predictions=predictions, weights=weights, session=session + ).items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + + return [ + MetricComputationReport( + name=MetricName.RECALL_SESSION_LEVEL, + metric_prefix=MetricPrefix.LIFETIME, + value=_calc_recall( + num_true_pos=cast(torch.Tensor, getattr(self, NUM_TRUE_POS)), + num_false_neg=cast(torch.Tensor, getattr(self, NUM_FALSE_NEGATIVE)), + ), + ), + MetricComputationReport( + name=MetricName.RECALL_SESSION_LEVEL, + metric_prefix=MetricPrefix.WINDOW, + value=_calc_recall( + num_true_pos=self.get_window_state(NUM_TRUE_POS), + num_false_neg=self.get_window_state(NUM_FALSE_NEGATIVE), + ), + ), + ] + + def get_recall_states( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + session: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + + predictions_ranked = ranking_within_session(predictions, session) + # pyre-fixme[58]: `<` is not supported for operand types `Tensor` and + # `Optional[int]`. + predictions_labels = (predictions_ranked < self.top_threshold).to(torch.int32) + if self.run_ranking_of_labels: + labels_ranked = ranking_within_session(labels, session) + # pyre-fixme[58]: `<` is not supported for operand types `Tensor` and + # `Optional[int]`. 
+ labels = (labels_ranked < self.top_threshold).to(torch.int32) + num_true_pos = _calc_num_true_pos(labels, predictions_labels, weights) + num_false_neg = _calc_num_false_neg(labels, predictions_labels, weights) + + return {NUM_TRUE_POS: num_true_pos, NUM_FALSE_NEGATIVE: num_false_neg} + + +class RecallSessionMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.RECALL_SESSION_LEVEL + _computation_class: Type[RecMetricComputation] = RecallSessionMetricComputation + + def __init__( + self, + world_size: int, + my_rank: int, + batch_size: int, + tasks: List[RecTaskInfo], + compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, + window_size: int = 100, + fused_update_limit: int = 0, + process_group: Optional[dist.ProcessGroup] = None, + **kwargs: Any, + ) -> None: + if compute_mode in [ + RecComputeMode.FUSED_TASKS_COMPUTATION, + RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + ]: + raise RecMetricException( + "Fused computation is not supported for recall session-level metrics" + ) + + if fused_update_limit > 0: + raise RecMetricException( + "Fused update is not supported for recall session-level metrics" + ) + for task in tasks: + if task.session_metric_def is None: + raise RecMetricException( + "Please, specify the session metric definition" + ) + session_metric_def = task.session_metric_def + if session_metric_def.top_threshold is None: + raise RecMetricException("Please, specify the top threshold") + + super().__init__( + world_size=world_size, + my_rank=my_rank, + batch_size=batch_size, + tasks=tasks, + compute_mode=compute_mode, + window_size=window_size, + fused_update_limit=fused_update_limit, + process_group=process_group, + **kwargs, + ) + + def _get_task_kwargs( + self, task_config: Union[RecTaskInfo, List[RecTaskInfo]] + ) -> Dict[str, Any]: + if isinstance(task_config, list): + raise RecMetricException("Session metric can only take one task at a time") + + if task_config.session_metric_def is None: + raise RecMetricException("Please, specify the session metric definition") + + return {"session_metric_def": task_config.session_metric_def} + + def _get_task_required_inputs( + self, task_config: Union[RecTaskInfo, List[RecTaskInfo]] + ) -> Set[str]: + if isinstance(task_config, list): + raise RecMetricException("Session metric can only take one task at a time") + + if task_config.session_metric_def is None: + raise RecMetricException("Please, specify the session metric definition") + + return ( + {task_config.session_metric_def.session_var_name} + if task_config.session_metric_def.session_var_name + else set() + ) diff --git a/torchrec/metrics/scalar.py b/torchrec/metrics/scalar.py new file mode 100644 index 000000000..ccfaa5bd1 --- /dev/null +++ b/torchrec/metrics/scalar.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, Dict, List, Optional, Type + +import torch + +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, +) + + +class ScalarMetricComputation(RecMetricComputation): + """ + Metric that logs whatever value is given as the label. 
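+    States are overwritten on each update rather than accumulated, so the
+    lifetime value reflects the most recently updated per-task label mean,
+    while the window value averages the per-batch label means over the window.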
+ """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._add_state( + "labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="mean", + persistent=False, + ) + self._add_state( + "window_count", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="mean", + persistent=False, + ) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + num_samples = labels.shape[0] + + states = { + "labels": labels.mean(dim=-1), + "window_count": torch.tensor([1.0]).to( + labels.device + ), # put window count on the correct device + } + for state_name, state_value in states.items(): + setattr(self, state_name, state_value) + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + return [ + MetricComputationReport( + name=MetricName.SCALAR, + metric_prefix=MetricPrefix.LIFETIME, + # pyre-fixme[6]: For 3rd argument expected `Tensor` but got + # `Union[Tensor, Module]`. + value=self.labels, + ), + MetricComputationReport( + name=MetricName.SCALAR, + metric_prefix=MetricPrefix.WINDOW, + # return the mean of the window state + value=self.get_window_state("labels") + / self.get_window_state("window_count"), + ), + ] + + +class ScalarMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.SCALAR + _computation_class: Type[RecMetricComputation] = ScalarMetricComputation diff --git a/torchrec/metrics/segmented_ne.py b/torchrec/metrics/segmented_ne.py new file mode 100644 index 000000000..2ccf6ac2f --- /dev/null +++ b/torchrec/metrics/segmented_ne.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
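+# Example (hypothetical values): with num_groups=2 and
+# grouping_keys = [0, 0, 1, 1], the NE states (cross-entropy sum, weighted
+# sample count, positive and negative label sums) are accumulated separately
+# for group 0 and group 1, and _compute() reports one NE value per group.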
+ +# pyre-strict + +import logging +from typing import Any, Dict, List, Optional, Type + +import torch +from torch import distributed as dist +from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +logger: logging.Logger = logging.getLogger(__name__) + +PREDICTIONS = "predictions" +LABELS = "labels" +WEIGHTS = "weights" + + +def compute_cross_entropy( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + eta: float, +) -> torch.Tensor: + predictions = predictions.double() + predictions.clamp_(min=eta, max=1 - eta) + cross_entropy = -weights * labels * torch.log2(predictions) - weights * ( + 1.0 - labels + ) * torch.log2(1.0 - predictions) + return cross_entropy + + +def _compute_cross_entropy_norm( + mean_label: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, +) -> torch.Tensor: + mean_label = mean_label.double() + mean_label.clamp_(min=eta, max=1 - eta) + return -pos_labels * torch.log2(mean_label) - neg_labels * torch.log2( + 1.0 - mean_label + ) + + +def compute_ne_helper( + ce_sum: torch.Tensor, + weighted_num_samples: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, +) -> torch.Tensor: + mean_label = pos_labels / weighted_num_samples + ce_norm = _compute_cross_entropy_norm(mean_label, pos_labels, neg_labels, eta) + return ce_sum / ce_norm + + +def compute_logloss( + ce_sum: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, +) -> torch.Tensor: + # we utilize tensor broadcasting for operations + labels_sum = pos_labels + neg_labels + labels_sum.clamp_(min=eta) + return ce_sum / labels_sum + + +def compute_ne( + ce_sum: torch.Tensor, + weighted_num_samples: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + num_groups: int, + eta: float, +) -> torch.Tensor: + # size should be (num_groups) + result_ne = torch.zeros(num_groups) + for group in range(num_groups): + mean_label = pos_labels[group] / weighted_num_samples[group] + ce_norm = _compute_cross_entropy_norm( + mean_label, pos_labels[group], neg_labels[group], eta + ) + ne = ce_sum[group] / ce_norm + result_ne[group] = ne + + # ne indexed by group - tensor size (num_groups) + return result_ne + + +def get_segemented_ne_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + grouping_keys: torch.Tensor, + eta: float, + num_groups: int, +) -> Dict[str, torch.Tensor]: + groups = torch.unique(grouping_keys) + cross_entropy, weighted_num_samples, pos_labels, neg_labels = ( + torch.zeros(num_groups).to(labels.device), + torch.zeros(num_groups).to(labels.device), + torch.zeros(num_groups).to(labels.device), + torch.zeros(num_groups).to(labels.device), + ) + for group in groups: + group_mask = grouping_keys == group + + group_labels = labels[group_mask] + group_predictions = predictions[group_mask] + group_weights = weights[group_mask] + + ce_sum_group = torch.sum( + compute_cross_entropy( + labels=group_labels, + predictions=group_predictions, + weights=group_weights, + eta=eta, + ), + dim=-1, + ) + + weighted_num_samples_group = torch.sum(group_weights, dim=-1) + pos_labels_group = torch.sum(group_weights * group_labels, dim=-1) + neg_labels_group = torch.sum(group_weights * (1.0 - group_labels), dim=-1) + + 
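+            # Store this group's scalar partial sums in the pre-allocated
+            # per-group state tensors, indexed by group id.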
cross_entropy[group] = ce_sum_group.item() + weighted_num_samples[group] = weighted_num_samples_group.item() + pos_labels[group] = pos_labels_group.item() + neg_labels[group] = neg_labels_group.item() + + # tensor size for each value is (num_groups) + return { + "cross_entropy_sum": cross_entropy, + "weighted_num_samples": weighted_num_samples, + "pos_labels": pos_labels, + "neg_labels": neg_labels, + } + + +def _state_reduction_sum(state: torch.Tensor) -> torch.Tensor: + return state.sum(dim=0) + + +class SegmentedNEMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for Segmented NE, i.e. Normalized Entropy - for boolean labels. + + Only binary labels are currently supported (0s, 1s), NE is computed for each label, NE across the whole model output + can be done through the normal NE metric. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + + Args: + include_logloss (bool): return vanilla logloss as one of metrics results, on top of segmented NE. + num_groups (int): number of groups to segment NE by. + grouping_keys (str): name of the tensor containing the label by which results will be segmented. This tensor should be of type torch.int64. + cast_keys_to_int (bool): whether to cast grouping_keys to torch.int64. Only works if grouping_keys is of type torch.float32. + """ + + def __init__( + self, + *args: Any, + include_logloss: bool = False, # TODO - include + num_groups: int = 1, + grouping_keys: str = "grouping_keys", + cast_keys_to_int: bool = False, + **kwargs: Any, + ) -> None: + self._include_logloss: bool = include_logloss + super().__init__(*args, **kwargs) + self._num_groups = num_groups # would there be checkpointing issues with this? 
maybe make this state + self._grouping_keys = grouping_keys + self._cast_keys_to_int = cast_keys_to_int + self._add_state( + "cross_entropy_sum", + torch.zeros((self._n_tasks, num_groups), dtype=torch.double), + add_window_state=False, + dist_reduce_fx=_state_reduction_sum, + persistent=True, + ) + self._add_state( + "weighted_num_samples", + torch.zeros((self._n_tasks, num_groups), dtype=torch.double), + add_window_state=False, + dist_reduce_fx=_state_reduction_sum, + persistent=True, + ) + self._add_state( + "pos_labels", + torch.zeros((self._n_tasks, num_groups), dtype=torch.double), + add_window_state=False, + dist_reduce_fx=_state_reduction_sum, + persistent=True, + ) + self._add_state( + "neg_labels", + torch.zeros((self._n_tasks, num_groups), dtype=torch.double), + add_window_state=False, + dist_reduce_fx=_state_reduction_sum, + persistent=True, + ) + self.eta = 1e-12 + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + f"Inputs 'predictions' and 'weights' and '{self._grouping_keys}' should not be None for NEMetricComputation update" + ) + elif ( + "required_inputs" not in kwargs + or kwargs["required_inputs"].get(self._grouping_keys) is None + ): + raise RecMetricException( + f"Required inputs for SegmentedNEMetricComputation update should contain {self._grouping_keys}, got kwargs: {kwargs}" + ) + elif kwargs["required_inputs"][self._grouping_keys].dtype != torch.int64: + if ( + self._cast_keys_to_int + and kwargs["required_inputs"][self._grouping_keys].dtype + == torch.float32 + ): + kwargs["required_inputs"][self._grouping_keys] = kwargs[ + "required_inputs" + ][self._grouping_keys].to(torch.int64) + else: + raise RecMetricException( + f"Grouping keys expected to have type torch.int64 or torch.float32 with cast_keys_to_int set to true, got {kwargs['required_inputs'][self._grouping_keys].dtype}." + ) + + grouping_keys = kwargs["required_inputs"][self._grouping_keys] + states = get_segemented_ne_states( + labels, + predictions, + weights, + grouping_keys, + eta=self.eta, + num_groups=self._num_groups, + ) + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + + def _compute(self) -> List[MetricComputationReport]: + reports = [] + computed_ne = compute_ne( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS... + self.cross_entropy_sum[0], + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS... + self.weighted_num_samples[0], + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS... + self.pos_labels[0], + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS... + self.neg_labels[0], + num_groups=self._num_groups, + eta=self.eta, + ) + + for group in range(self._num_groups): + reports.append( + MetricComputationReport( + name=MetricName.SEGMENTED_NE, + metric_prefix=MetricPrefix.LIFETIME, + value=computed_ne[group], + description="_" + str(group), + ), + ) + + if self._include_logloss: + log_loss_groups = compute_logloss( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _Nes... + self.cross_entropy_sum[0], + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _Nes... + self.pos_labels[0], + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _Nes... 
+ self.neg_labels[0], + eta=self.eta, + ) + + for group in range(self._num_groups): + reports.append( + MetricComputationReport( + name=MetricName.LOG_LOSS, + metric_prefix=MetricPrefix.LIFETIME, + value=log_loss_groups[group], + description="_" + str(group), + ) + ) + + return reports + + +class SegmentedNEMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.SEGMENTED_NE + _computation_class: Type[RecMetricComputation] = SegmentedNEMetricComputation + + def __init__( + self, + world_size: int, + my_rank: int, + batch_size: int, + tasks: List[RecTaskInfo], + compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, + window_size: int = 100, + fused_update_limit: int = 0, + compute_on_all_ranks: bool = False, + should_validate_update: bool = False, + process_group: Optional[dist.ProcessGroup] = None, + **kwargs: Dict[str, Any], + ) -> None: + super().__init__( + world_size=world_size, + my_rank=my_rank, + batch_size=batch_size, + tasks=tasks, + compute_mode=compute_mode, + window_size=window_size, + fused_update_limit=fused_update_limit, + compute_on_all_ranks=compute_on_all_ranks, + should_validate_update=should_validate_update, + process_group=process_group, + **kwargs, + ) + if "grouping_keys" not in kwargs: + self._required_inputs.add("grouping_keys") + else: + # pyre-ignore[6] + self._required_inputs.add(kwargs["grouping_keys"]) + if self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION: + logging.warning( + f"compute_mode FUSED_TASKS_AND_STATES_COMPUTATION can't support {self._namespace} yet " + "because its states are not 1D Tensors. Only FUSED_TASKS_COMPUTATION will take effect." + ) diff --git a/torchrec/metrics/serving_calibration.py b/torchrec/metrics/serving_calibration.py new file mode 100644 index 000000000..aa54b6317 --- /dev/null +++ b/torchrec/metrics/serving_calibration.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.calibration import compute_calibration, get_calibration_states +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +CALIBRATION_NUM = "calibration_num" +CALIBRATION_DENOM = "calibration_denom" +NUM_EXAMPLES = "num_examples" + + +class ServingCalibrationMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for Calibration, which is the + ratio between the prediction and the labels (conversions). + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. 
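+    In addition to lifetime and window calibration, this computation reports a
+    TOTAL_EXAMPLES counter tracking the number of examples with non-zero weight.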
+ """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._add_state( + CALIBRATION_NUM, + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + CALIBRATION_DENOM, + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + NUM_EXAMPLES, + torch.zeros(self._n_tasks, dtype=torch.long), + add_window_state=False, + dist_reduce_fx="sum", + persistent=True, + ) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for CalibrationMetricComputation update" + ) + num_samples = predictions.shape[-1] + for state_name, state_value in get_calibration_states( + labels, predictions, weights + ).items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + num_examples_delta = torch.count_nonzero(weights, dim=-1) + state_num_examples = getattr(self, NUM_EXAMPLES) + state_num_examples += num_examples_delta + + def _compute(self) -> List[MetricComputationReport]: + return [ + MetricComputationReport( + name=MetricName.CALIBRATION, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_calibration( + cast(torch.Tensor, self.calibration_num), + cast(torch.Tensor, self.calibration_denom), + ), + ), + MetricComputationReport( + name=MetricName.CALIBRATION, + metric_prefix=MetricPrefix.WINDOW, + value=compute_calibration( + self.get_window_state(CALIBRATION_NUM), + self.get_window_state(CALIBRATION_DENOM), + ), + ), + MetricComputationReport( + name=MetricName.TOTAL_EXAMPLES, + metric_prefix=MetricPrefix.DEFAULT, + value=cast(torch.Tensor, self.num_examples).detach(), + ), + ] + + +class ServingCalibrationMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.SERVING_CALIBRATION + _computation_class: Type[RecMetricComputation] = ServingCalibrationMetricComputation diff --git a/torchrec/metrics/serving_ne.py b/torchrec/metrics/serving_ne.py new file mode 100644 index 000000000..37b868828 --- /dev/null +++ b/torchrec/metrics/serving_ne.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.ne import compute_ne, get_ne_states +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +NUM_EXAMPLES = "num_examples" + + +class ServingNEMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for NE over serving data only, + i.e., excluding data with weight=0. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. 
+ """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._add_state( + "cross_entropy_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_num_samples", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "pos_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "neg_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + NUM_EXAMPLES, + torch.zeros(self._n_tasks, dtype=torch.long), + add_window_state=False, + dist_reduce_fx="sum", + persistent=True, + ) + self.eta = 1e-12 + + def _get_bucket_metric_states( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + ) -> Dict[str, torch.Tensor]: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for BucketNEMetricComputation update" + ) + + if labels.nelement() == 0: + return { + "cross_entropy_sum": torch.zeros(self._n_tasks, dtype=torch.double), + "weighted_num_samples": torch.zeros(self._n_tasks, dtype=torch.double), + "pos_labels": torch.zeros(self._n_tasks, dtype=torch.double), + "neg_labels": torch.zeros(self._n_tasks, dtype=torch.double), + } + + return get_ne_states( + labels=labels, + predictions=predictions, + weights=weights, + eta=self.eta, + ) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for ServingNEMetricComputation update" + ) + + states = get_ne_states(labels, predictions, weights, self.eta) + + num_samples = labels.shape[-1] + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + num_examples_delta = torch.count_nonzero(weights, dim=-1) + state_num_examples = getattr(self, NUM_EXAMPLES) + state_num_examples += num_examples_delta + + def _compute(self) -> List[MetricComputationReport]: + return [ + MetricComputationReport( + name=MetricName.NE, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_ne( + cast(torch.Tensor, self.cross_entropy_sum), + cast(torch.Tensor, self.weighted_num_samples), + cast(torch.Tensor, self.pos_labels), + cast(torch.Tensor, self.neg_labels), + self.eta, + ), + ), + MetricComputationReport( + name=MetricName.NE, + metric_prefix=MetricPrefix.WINDOW, + value=compute_ne( + self.get_window_state("cross_entropy_sum"), + self.get_window_state("weighted_num_samples"), + self.get_window_state("pos_labels"), + self.get_window_state("neg_labels"), + self.eta, + ), + ), + MetricComputationReport( + name=MetricName.TOTAL_EXAMPLES, + metric_prefix=MetricPrefix.DEFAULT, + value=cast(torch.Tensor, self.num_examples).detach(), + ), + ] + + +class ServingNEMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.SERVING_NE + _computation_class: Type[RecMetricComputation] = ServingNEMetricComputation diff --git a/torchrec/metrics/tensor_weighted_avg.py 
b/torchrec/metrics/tensor_weighted_avg.py new file mode 100644 index 000000000..2c582f4c0 --- /dev/null +++ b/torchrec/metrics/tensor_weighted_avg.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Set, Type, Union + +import torch +from torchrec.metrics.metrics_config import RecTaskInfo +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +def get_mean(value_sum: torch.Tensor, num_samples: torch.Tensor) -> torch.Tensor: + return value_sum / num_samples + + +class TensorWeightedAvgMetricComputation(RecMetricComputation): + def __init__( + self, + *args: Any, + tensor_name: Optional[str] = None, + weighted: bool = True, + description: Optional[str] = None, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if tensor_name is None: + raise RecMetricException( + f"TensorWeightedAvgMetricComputation expects tensor_name to not be None got {tensor_name}" + ) + self.tensor_name: str = tensor_name + self.weighted: bool = weighted + self._add_state( + "weighted_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_num_samples", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._description = description + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if ( + "required_inputs" not in kwargs + or self.tensor_name not in kwargs["required_inputs"] + ): + raise RecMetricException( + f"TensorWeightedAvgMetricComputation expects {self.tensor_name} in the required_inputs" + ) + num_samples = labels.shape[0] + target_tensor = cast(torch.Tensor, kwargs["required_inputs"][self.tensor_name]) + weights = cast(torch.Tensor, weights) + states = { + "weighted_sum": ( + target_tensor * weights if self.weighted else target_tensor + ).sum(dim=-1), + "weighted_num_samples": ( + weights.sum(dim=-1) + if self.weighted + else torch.ones(weights.shape).sum(dim=-1).to(device=weights.device) + ), + } + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + return [ + MetricComputationReport( + name=MetricName.WEIGHTED_AVG, + metric_prefix=MetricPrefix.LIFETIME, + value=get_mean( + cast(torch.Tensor, self.weighted_sum), + cast(torch.Tensor, self.weighted_num_samples), + ), + description=self._description, + ), + MetricComputationReport( + name=MetricName.WEIGHTED_AVG, + metric_prefix=MetricPrefix.WINDOW, + value=get_mean( + self.get_window_state("weighted_sum"), + self.get_window_state("weighted_num_samples"), + ), + description=self._description, + ), + ] + + +class TensorWeightedAvgMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.WEIGHTED_AVG + _computation_class: Type[RecMetricComputation] = TensorWeightedAvgMetricComputation + + def __init__( + self, + # pyre-ignore 
Missing parameter annotation [2] + *args, + **kwargs: Dict[str, Any], + ) -> None: + + super().__init__(*args, **kwargs) + + def _get_task_kwargs( + self, task_config: Union[RecTaskInfo, List[RecTaskInfo]] + ) -> Dict[str, Any]: + if not isinstance(task_config, RecTaskInfo): + raise RecMetricException( + f"TensorWeightedAvgMetric expects task_config to be RecTaskInfo not {type(task_config)}. Check the FUSED_TASKS_COMPUTATION settings." + ) + return { + "tensor_name": task_config.tensor_name, + "weighted": task_config.weighted, + } + + def _get_task_required_inputs( + self, task_config: Union[RecTaskInfo, List[RecTaskInfo]] + ) -> Set[str]: + if not isinstance(task_config, RecTaskInfo): + raise RecMetricException( + f"TensorWeightedAvgMetric expects task_config to be RecTaskInfo not {type(task_config)}. Check the FUSED_TASKS_COMPUTATION settings." + ) + required_inputs = set() + if task_config.tensor_name is not None: + required_inputs.add(task_config.tensor_name) + return required_inputs diff --git a/torchrec/metrics/test_utils/__init__.py b/torchrec/metrics/test_utils/__init__.py index c3a211401..0a1085195 100644 --- a/torchrec/metrics/test_utils/__init__.py +++ b/torchrec/metrics/test_utils/__init__.py @@ -5,18 +5,23 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import abc import os import random import tempfile import uuid -from typing import Callable, Dict, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Tuple, Type from unittest.mock import Mock, patch import torch import torch.distributed as dist import torch.distributed.launcher as pet +from torchrec.metrics.auc import AUCMetric +from torchrec.metrics.auprc import AUPRCMetric from torchrec.metrics.model_utils import parse_task_model_outputs +from torchrec.metrics.rauc import RAUCMetric from torchrec.metrics.rec_metric import RecComputeMode, RecMetric, RecTaskInfo TestRecMetricOutput = Tuple[ @@ -38,15 +43,23 @@ def gen_test_batch( prediction_value: Optional[torch.Tensor] = None, weight_value: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, + n_classes: Optional[int] = None, + seed: Optional[int] = None, ) -> Dict[str, torch.Tensor]: + if seed is not None: + torch.manual_seed(seed) if label_value is not None: label = label_value else: - label = torch.randint(0, 2, (batch_size,)).double() + label = torch.randint(0, n_classes or 2, (batch_size,)).double() if prediction_value is not None: prediction = prediction_value else: - prediction = torch.rand(batch_size, dtype=torch.double) + prediction = ( + torch.rand(batch_size, dtype=torch.double) + if n_classes is None + else torch.rand(batch_size, n_classes, dtype=torch.double) + ) if weight_value is not None: weight = weight_value else: @@ -218,16 +231,31 @@ def rec_metric_value_test_helper( batch_window_size: int = BATCH_WINDOW_SIZE, is_time_dependent: bool = False, time_dependent_metric: Optional[Dict[Type[RecMetric], str]] = None, + n_classes: Optional[int] = None, + zero_weights: bool = False, + zero_labels: bool = False, + **kwargs: Any, ) -> Tuple[Dict[str, torch.Tensor], Tuple[Dict[str, torch.Tensor], ...]]: tasks = gen_test_tasks(task_names) model_outs = [] for _ in range(nsteps): + weight_value: Optional[torch.Tensor] = None + if zero_weights: + weight_value = torch.zeros(batch_size) + + label_value: Optional[torch.Tensor] = None + if zero_labels: + label_value = torch.zeros(batch_size) + _model_outs = [ gen_test_batch( 
label_name=task.label_name, prediction_name=task.prediction_name, weight_name=task.weight_name, batch_size=batch_size, + n_classes=n_classes, + weight_value=weight_value, + label_value=label_value, ) for task in tasks ] @@ -238,8 +266,15 @@ def get_target_rec_metric_value( tasks: List[RecTaskInfo], timestamps: Optional[List[float]] = None, time_mock: Optional[Mock] = None, + **kwargs: Any, ) -> Dict[str, torch.Tensor]: + window_size = world_size * batch_size * batch_window_size + if n_classes: + kwargs["number_of_classes"] = n_classes + if zero_weights: + kwargs["allow_missing_label_with_zero_weight"] = True + target_metric_obj = target_clazz( world_size=world_size, my_rank=my_rank, @@ -250,12 +285,16 @@ def get_target_rec_metric_value( fused_update_limit=fused_update_limit, compute_on_all_ranks=compute_on_all_ranks, should_validate_update=should_validate_update, + **kwargs, ) for i in range(nsteps): - labels, predictions, weights = parse_task_model_outputs( + labels, predictions, weights, _ = parse_task_model_outputs( tasks, model_outs[i] ) - if target_compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION: + if target_compute_mode in [ + RecComputeMode.FUSED_TASKS_COMPUTATION, + RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + ]: labels = torch.stack(list(labels.values())) predictions = torch.stack(list(predictions.values())) weights = torch.stack(list(weights.values())) @@ -291,11 +330,11 @@ def get_test_rec_metric_value( time_dependent_target_clazz_path = time_dependent_metric[target_clazz] with patch(time_dependent_target_clazz_path + ".time.monotonic") as time_mock: result_metrics = get_target_rec_metric_value( - model_outs, tasks, timestamps, time_mock + model_outs, tasks, timestamps, time_mock, **kwargs ) test_metrics = get_test_rec_metric_value(model_outs, tasks, timestamps) else: - result_metrics = get_target_rec_metric_value(model_outs, tasks) + result_metrics = get_target_rec_metric_value(model_outs, tasks, **kwargs) test_metrics = get_test_rec_metric_value(model_outs, tasks) return result_metrics, test_metrics @@ -316,17 +355,190 @@ def get_launch_config(world_size: int, rdzv_endpoint: str) -> pet.LaunchConfig: ) +def rec_metric_gpu_sync_test_launcher( + target_clazz: Type[RecMetric], + target_compute_mode: RecComputeMode, + test_clazz: Optional[Type[TestMetric]], + metric_name: str, + task_names: List[str], + fused_update_limit: int, + compute_on_all_ranks: bool, + should_validate_update: bool, + world_size: int, + entry_point: Callable[..., None], + batch_size: int = BATCH_SIZE, + batch_window_size: int = BATCH_WINDOW_SIZE, + **kwargs: Dict[str, Any], +) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + lc = get_launch_config( + world_size=world_size, rdzv_endpoint=os.path.join(tmpdir, "rdzv") + ) + + # launch using torch elastic, launches for each rank + pet.elastic_launch(lc, entrypoint=entry_point)( + target_clazz, + target_compute_mode, + test_clazz, + task_names, + metric_name, + world_size, + fused_update_limit, + compute_on_all_ranks, + should_validate_update, + batch_size, + batch_window_size, + kwargs.get("n_classes", None), + ) + + +def sync_test_helper( + target_clazz: Type[RecMetric], + target_compute_mode: RecComputeMode, + test_clazz: Optional[Type[TestMetric]], + task_names: List[str], + metric_name: str, + world_size: int, + fused_update_limit: int = 0, + compute_on_all_ranks: bool = False, + should_validate_update: bool = False, + batch_size: int = BATCH_SIZE, + batch_window_size: int = BATCH_WINDOW_SIZE, + n_classes: Optional[int] = None, + 
zero_weights: bool = False, + **kwargs: Dict[str, Any], +) -> None: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + dist.init_process_group( + backend="gloo", + world_size=world_size, + rank=rank, + ) + + tasks = gen_test_tasks(task_names) + + if n_classes: + # pyre-ignore[6]: Incompatible parameter type + kwargs["number_of_classes"] = n_classes + + auc = target_clazz( + world_size=world_size, + batch_size=batch_size, + my_rank=rank, + compute_on_all_ranks=compute_on_all_ranks, + tasks=tasks, + window_size=batch_window_size * world_size, + # pyre-ignore[6]: Incompatible parameter type + **kwargs, + ) + + weight_value: Optional[torch.Tensor] = None + + _model_outs = [ + gen_test_batch( + label_name=task.label_name, + prediction_name=task.prediction_name, + weight_name=task.weight_name, + batch_size=batch_size, + n_classes=n_classes, + weight_value=weight_value, + seed=42, # we set seed because of how test metric places tensors on ranks + ) + for task in tasks + ] + model_outs = [] + model_outs.append({k: v for d in _model_outs for k, v in d.items()}) + + # we send an uneven number of tensors to each rank to test that GPU sync works + if rank == 0: + for _ in range(3): + labels, predictions, weights, _ = parse_task_model_outputs( + tasks, model_outs[0] + ) + auc.update(predictions=predictions, labels=labels, weights=weights) + elif rank == 1: + for _ in range(1): + labels, predictions, weights, _ = parse_task_model_outputs( + tasks, model_outs[0] + ) + auc.update(predictions=predictions, labels=labels, weights=weights) + + # check against test metric + test_metrics: TestRecMetricOutput = ({}, {}, {}, {}) + if test_clazz is not None: + # pyre-ignore[45]: Cannot instantiate abstract class `TestMetric`. + test_metric_obj = test_clazz(world_size, tasks) + # with how testmetric is setup we cannot do asymmertrical updates across ranks + # so we duplicate model_outs twice to match number of updates in aggregate + model_outs = model_outs * 2 + test_metrics = test_metric_obj.compute(model_outs, 2, batch_window_size, None) + + res = auc.compute() + + if rank == 0: + # Serving Calibration uses Calibration naming inconsistently + if metric_name == "serving_calibration": + assert torch.allclose( + test_metrics[1][task_names[0]], + res[f"{metric_name}-{task_names[0]}|window_calibration"], + ) + else: + assert torch.allclose( + test_metrics[1][task_names[0]], + res[f"{metric_name}-{task_names[0]}|window_{metric_name}"], + ) + + # we also test the case where other rank has more tensors than rank 0 + auc.reset() + if rank == 0: + for _ in range(1): + labels, predictions, weights, _ = parse_task_model_outputs( + tasks, model_outs[0] + ) + auc.update(predictions=predictions, labels=labels, weights=weights) + elif rank == 1: + for _ in range(3): + labels, predictions, weights, _ = parse_task_model_outputs( + tasks, model_outs[0] + ) + auc.update(predictions=predictions, labels=labels, weights=weights) + + res = auc.compute() + + if rank == 0: + # Serving Calibration uses Calibration naming inconsistently + if metric_name == "serving_calibration": + assert torch.allclose( + test_metrics[1][task_names[0]], + res[f"{metric_name}-{task_names[0]}|window_calibration"], + ) + else: + assert torch.allclose( + test_metrics[1][task_names[0]], + res[f"{metric_name}-{task_names[0]}|window_{metric_name}"], + ) + + dist.destroy_process_group() + + def rec_metric_value_test_launcher( target_clazz: Type[RecMetric], target_compute_mode: RecComputeMode, test_clazz: Type[TestMetric], + metric_name: 
str, task_names: List[str], fused_update_limit: int, compute_on_all_ranks: bool, should_validate_update: bool, world_size: int, entry_point: Callable[..., None], + batch_window_size: int = BATCH_WINDOW_SIZE, test_nsteps: int = 1, + n_classes: Optional[int] = None, + zero_weights: bool = False, + zero_labels: bool = False, + **kwargs: Any, ) -> None: with tempfile.TemporaryDirectory() as tmpdir: lc = get_launch_config( @@ -348,14 +560,25 @@ def rec_metric_value_test_launcher( batch_size=32, nsteps=test_nsteps, batch_window_size=1, + n_classes=n_classes, + zero_weights=zero_weights, + zero_labels=zero_labels, + **kwargs, ) + pet.elastic_launch(lc, entrypoint=entry_point)( target_clazz, target_compute_mode, task_names, + test_clazz, + metric_name, fused_update_limit, compute_on_all_ranks, should_validate_update, + batch_window_size, + n_classes, + test_nsteps, + zero_weights, ) @@ -367,3 +590,83 @@ def rec_metric_accuracy_test_helper( world_size=world_size, rdzv_endpoint=os.path.join(tmpdir, "rdzv") ) pet.elastic_launch(lc, entrypoint=entry_point)() + + +def metric_test_helper( + target_clazz: Type[RecMetric], + target_compute_mode: RecComputeMode, + task_names: List[str], + test_clazz: Type[TestMetric], + metric_name: str, + fused_update_limit: int = 0, + compute_on_all_ranks: bool = False, + should_validate_update: bool = False, + batch_window_size: int = BATCH_WINDOW_SIZE, + n_classes: Optional[int] = None, + nsteps: int = 1, + zero_weights: bool = False, + is_time_dependent: bool = False, + time_dependent_metric: Optional[Dict[Type[RecMetric], str]] = None, + **kwargs: Any, +) -> None: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + dist.init_process_group( + backend="gloo", + world_size=world_size, + rank=rank, + ) + + target_metrics, test_metrics = rec_metric_value_test_helper( + target_clazz=target_clazz, + target_compute_mode=target_compute_mode, + test_clazz=test_clazz, + fused_update_limit=fused_update_limit, + compute_on_all_ranks=False, + should_validate_update=should_validate_update, + world_size=world_size, + my_rank=rank, + task_names=task_names, + batch_window_size=batch_window_size, + n_classes=n_classes, + nsteps=nsteps, + is_time_dependent=is_time_dependent, + time_dependent_metric=time_dependent_metric, + zero_weights=zero_weights, + **kwargs, + ) + + if rank == 0: + for name in task_names: + # we don't have lifetime metric for AUC due to OOM. + if ( + target_clazz != AUCMetric + and target_clazz != AUPRCMetric + and target_clazz != RAUCMetric + ): + assert torch.allclose( + target_metrics[ + f"{str(target_clazz._namespace)}-{name}|lifetime_{metric_name}" + ], + test_metrics[0][name], + ) + assert torch.allclose( + target_metrics[ + f"{str(target_clazz._namespace)}-{name}|local_lifetime_{metric_name}" + ], + test_metrics[2][name], + ) + assert torch.allclose( + target_metrics[ + f"{str(target_clazz._namespace)}-{name}|window_{metric_name}" + ], + test_metrics[1][name], + ) + + assert torch.allclose( + target_metrics[ + f"{str(target_clazz._namespace)}-{name}|local_window_{metric_name}" + ], + test_metrics[3][name], + ) + dist.destroy_process_group() diff --git a/torchrec/metrics/tests/test_accuracy.py b/torchrec/metrics/tests/test_accuracy.py new file mode 100644 index 000000000..fa46b3e87 --- /dev/null +++ b/torchrec/metrics/tests/test_accuracy.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import Dict, Iterable, Type, Union + +import torch +from torch import no_grad +from torchrec.metrics.accuracy import AccuracyMetric, compute_accuracy +from torchrec.metrics.metrics_config import DefaultTaskInfo +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_gpu_sync_test_launcher, + rec_metric_value_test_launcher, + RecTaskInfo, + sync_test_helper, + TestMetric, +) + + +WORLD_SIZE = 4 + + +class TestAccuracyMetric(TestMetric): + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + predictions = predictions.double() + accuracy_sum = torch.sum(weights * ((predictions >= 0.5) == labels)) + return { + "accuracy_sum": accuracy_sum, + "weighted_num_samples": torch.sum(weights), + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + return compute_accuracy( + states["accuracy_sum"], + states["weighted_num_samples"], + ) + + +class AccuracyMetricTest(unittest.TestCase): + clazz: Type[RecMetric] = AccuracyMetric + task_name: str = "accuracy" + + def test_accuracy_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=AccuracyMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestAccuracyMetric, + metric_name=AccuracyMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_accuracy_fused_tasks(self) -> None: + rec_metric_value_test_launcher( + target_clazz=AccuracyMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestAccuracyMetric, + metric_name=AccuracyMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_accuracy_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=AccuracyMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + test_clazz=TestAccuracyMetric, + metric_name=AccuracyMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + +class AccuracyMetricValueTest(unittest.TestCase): + r"""This set of tests verify the computation logic of accuracy in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. 
+ """ + + def setUp(self) -> None: + self.predictions = {"DefaultTask": None} + self.weights = {"DefaultTask": None} + self.labels = {"DefaultTask": None} + self.batches = { + "predictions": self.predictions, + "weights": self.weights, + "labels": self.labels, + } + self.accuracy = AccuracyMetric( + world_size=1, + my_rank=0, + batch_size=100, + tasks=[DefaultTaskInfo], + ) + + def test_calc_acc_perfect(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(10000)] * 2] + ) + self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000]) + self.weights["DefaultTask"] = torch.Tensor( + [[1] * 5000 + [0] * 10000 + [1] * 5000] + ) + + expected_accuracy = torch.tensor([1], dtype=torch.double) + self.accuracy.update(**self.batches) + actual_accuracy = self.accuracy.compute()[ + "accuracy-DefaultTask|window_accuracy" + ] + torch.allclose(expected_accuracy, actual_accuracy) + + def test_calc_acc_zero(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(10000)] * 2] + ) + self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000]) + self.weights["DefaultTask"] = torch.Tensor( + [[0] * 5000 + [1] * 10000 + [0] * 5000] + ) + + expected_accuracy = torch.tensor([0], dtype=torch.double) + self.accuracy.update(**self.batches) + actual_accuracy = self.accuracy.compute()[ + "accuracy-DefaultTask|window_accuracy" + ] + torch.allclose(expected_accuracy, actual_accuracy) + + def test_calc_accuracy_balanced(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(10000)] * 2] + ) + self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000]) + self.weights["DefaultTask"] = torch.ones([1, 20000]) + + expected_accuracy = torch.tensor([0.5], dtype=torch.double) + self.accuracy.update(**self.batches) + actual_accuracy = self.accuracy.compute()[ + "accuracy-DefaultTask|window_accuracy" + ] + torch.allclose(expected_accuracy, actual_accuracy) + + +def generate_model_outputs_cases() -> Iterable[Dict[str, Union[float, torch.Tensor]]]: + return [ + # random_inputs + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.3, 0.2, 0.5, 0.8, 0.7]]), + "threshold": 0.6, + "expected_accuracy": torch.tensor([0.28]), + }, + # perfect_condition + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[1, 0, 0, 1, 1]]), + "weights": torch.tensor([[1] * 5]), + "threshold": 0.6, + "expected_accuracy": torch.tensor([1.0]), + }, + # inverse_prediction + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0, 1, 1, 0, 0]]), + "weights": torch.tensor([[1] * 5]), + "threshold": 0.1, + "expected_accuracy": torch.tensor([0.0]), + }, + ] + + +class ThresholdValueTest(unittest.TestCase): + """This set of tests verify the computation logic of accuracy with a modified threshold + in several cases that we know the computation results. 
+ """ + + @no_grad() + def _test_accuracy_helper( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + expected_accuracy: torch.Tensor, + threshold: float, + ) -> None: + num_task = labels.shape[0] + batch_size = labels.shape[0] + task_list = [] + inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = { + "predictions": {}, + "labels": {}, + "weights": {}, + } + for i in range(num_task): + task_info = RecTaskInfo( + name=f"Task:{i}", + label_name="label", + prediction_name="prediction", + weight_name="weight", + ) + task_list.append(task_info) + # pyre-ignore + inputs["predictions"][task_info.name] = predictions[i] + # pyre-ignore + inputs["labels"][task_info.name] = labels[i] + # pyre-ignore + inputs["weights"][task_info.name] = weights[i] + + accuracy = AccuracyMetric( + world_size=1, + my_rank=0, + batch_size=batch_size, + tasks=task_list, + # pyre-ignore + threshold=threshold, # threshold is one of the kwargs + ) + accuracy.update(**inputs) + actual_accuracy = accuracy.compute() + + for task_id, task in enumerate(task_list): + cur_actual_accuracy = actual_accuracy[ + f"accuracy-{task.name}|window_accuracy" + ] + cur_expected_accuracy = expected_accuracy[task_id].unsqueeze(dim=0) + + torch.testing.assert_close( + cur_actual_accuracy, + cur_expected_accuracy, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + msg=f"Actual: {cur_actual_accuracy}, Expected: {cur_expected_accuracy}", + ) + + def test_accuracy(self) -> None: + test_data = generate_model_outputs_cases() + for inputs in test_data: + try: + self._test_accuracy_helper( + **inputs # pyre-ignore, surpressing a type hint error + ) + except AssertionError: + print("Assertion error caught with data set ", inputs) + raise + + +class AccuracyGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = AccuracyMetric + task_name: str = "accuracy" + + def test_sync_accuracy(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=AccuracyMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestAccuracyMetric, + metric_name=AccuracyGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/tests/test_auc.py b/torchrec/metrics/tests/test_auc.py index b040f8d26..36f389c86 100644 --- a/torchrec/metrics/tests/test_auc.py +++ b/torchrec/metrics/tests/test_auc.py @@ -5,18 +5,26 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
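Before the AUC changes below, a note on the accuracy cases above: the weighted, thresholded accuracy they exercise can be reproduced standalone. This is a minimal sketch of the formula visible in TestAccuracyMetric (weight mass where the thresholded prediction matches the label, over the total weight mass), not AccuracyMetric's internal code path; it reproduces the 0.28 expectation from the random_inputs case with threshold 0.6.

import torch

# Minimal sketch: weighted accuracy at a configurable threshold.
# Reproduces the "random_inputs" case above (threshold=0.6, expected 0.28).
labels = torch.tensor([1.0, 0.0, 0.0, 1.0, 1.0])
predictions = torch.tensor([0.2, 0.6, 0.8, 0.4, 0.9])
weights = torch.tensor([0.3, 0.2, 0.5, 0.8, 0.7])
threshold = 0.6

accuracy_sum = torch.sum(weights * ((predictions >= threshold) == labels))
weighted_num_samples = torch.sum(weights)
print(accuracy_sum / weighted_num_samples)  # tensor(0.2800)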
-import os +# pyre-strict + import unittest -from typing import Dict, List, Type +from typing import Dict, Iterable, List, Optional, Type, Union import torch -import torch.distributed as dist +from torch import no_grad from torchrec.metrics.auc import AUCMetric from torchrec.metrics.metrics_config import DefaultTaskInfo -from torchrec.metrics.rec_metric import RecComputeMode, RecMetric, RecTaskInfo +from torchrec.metrics.rec_metric import ( + RecComputeMode, + RecMetric, + RecMetricException, + RecTaskInfo, +) from torchrec.metrics.test_utils import ( - rec_metric_value_test_helper, + metric_test_helper, + rec_metric_gpu_sync_test_launcher, rec_metric_value_test_launcher, + sync_test_helper, TestMetric, ) @@ -56,9 +64,9 @@ def _aggregate( ) -> None: for k, v in new_states.items(): if k not in states: - states[k] = v.double().detach().clone() + states[k] = v.float().detach().clone() else: - states[k] = torch.cat([states[k], v.double()]) + states[k] = torch.cat([states[k], v.float()]) @staticmethod def _get_states( @@ -82,56 +90,18 @@ class AUCMetricTest(unittest.TestCase): clazz: Type[RecMetric] = AUCMetric task_name: str = "auc" - @staticmethod - def _test_auc( - target_clazz: Type[RecMetric], - target_compute_mode: RecComputeMode, - task_names: List[str], - fused_update_limit: int = 0, - compute_on_all_ranks: bool = False, - should_validate_update: bool = False, - ) -> None: - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - dist.init_process_group( - backend="gloo", - world_size=world_size, - rank=rank, - ) - - auc_metrics, test_metrics = rec_metric_value_test_helper( - target_clazz=target_clazz, - target_compute_mode=target_compute_mode, - test_clazz=TestAUCMetric, - fused_update_limit=fused_update_limit, - compute_on_all_ranks=False, - should_validate_update=should_validate_update, - world_size=world_size, - my_rank=rank, - task_names=task_names, - ) - - if rank == 0: - for name in task_names: - assert torch.allclose( - auc_metrics[f"auc-{name}|window_auc"], test_metrics[1][name] - ), (auc_metrics[f"auc-{name}|window_auc"], test_metrics[1][name]) - assert torch.allclose( - auc_metrics[f"auc-{name}|local_window_auc"], test_metrics[3][name] - ), (auc_metrics[f"auc-{name}|local_window_auc"], test_metrics[3][name]) - dist.destroy_process_group() - def test_unfused_auc(self) -> None: rec_metric_value_test_launcher( target_clazz=AUCMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, test_clazz=TestAUCMetric, + metric_name=AUCMetricTest.task_name, task_names=["t1", "t2", "t3"], fused_update_limit=0, compute_on_all_ranks=False, should_validate_update=False, world_size=WORLD_SIZE, - entry_point=self._test_auc, + entry_point=metric_test_helper, ) def test_fused_auc(self) -> None: @@ -139,12 +109,34 @@ def test_fused_auc(self) -> None: target_clazz=AUCMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, test_clazz=TestAUCMetric, + metric_name=AUCMetricTest.task_name, task_names=["t1", "t2", "t3"], fused_update_limit=0, compute_on_all_ranks=False, should_validate_update=False, world_size=WORLD_SIZE, - entry_point=self._test_auc, + entry_point=metric_test_helper, + ) + + +class AUCGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = AUCMetric + task_name: str = "auc" + + def test_sync_auc(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=AUCMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestAUCMetric, + metric_name=AUCGPUSyncTest.task_name, + task_names=["t1"], + 
fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, ) @@ -166,7 +158,7 @@ def setUp(self) -> None: self.auc = AUCMetric( world_size=1, my_rank=0, - batch_size=20000, + batch_size=100, tasks=[DefaultTaskInfo], ) @@ -179,7 +171,7 @@ def test_calc_auc_perfect(self) -> None: [[1] * 5000 + [0] * 10000 + [1] * 5000] ) - expected_auc = torch.tensor([1], dtype=torch.double) + expected_auc = torch.tensor([1], dtype=torch.float) self.auc.update(**self.batches) actual_auc = self.auc.compute()["auc-DefaultTask|window_auc"] torch.allclose(expected_auc, actual_auc) @@ -193,7 +185,7 @@ def test_calc_auc_zero(self) -> None: [[0] * 5000 + [1] * 10000 + [0] * 5000] ) - expected_auc = torch.tensor([0], dtype=torch.double) + expected_auc = torch.tensor([0], dtype=torch.float) self.auc.update(**self.batches) actual_auc = self.auc.compute()["auc-DefaultTask|window_auc"] torch.allclose(expected_auc, actual_auc) @@ -205,7 +197,266 @@ def test_calc_auc_balanced(self) -> None: self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000]) self.weights["DefaultTask"] = torch.ones([1, 20000]) - expected_auc = torch.tensor([0.5], dtype=torch.double) + expected_auc = torch.tensor([0.5], dtype=torch.float) self.auc.update(**self.batches) actual_auc = self.auc.compute()["auc-DefaultTask|window_auc"] torch.allclose(expected_auc, actual_auc) + + def test_calc_uneven_updates(self) -> None: + auc = AUCMetric( + world_size=1, + my_rank=0, + batch_size=100, + tasks=[DefaultTaskInfo], + ) + + expected_auc = torch.tensor([0.4464], dtype=torch.float) + # first batch + self.labels["DefaultTask"] = torch.tensor([1, 0, 0]) + self.predictions["DefaultTask"] = torch.tensor([0.2, 0.6, 0.8]) + self.weights["DefaultTask"] = torch.tensor([0.13, 0.2, 0.5]) + + auc.update(**self.batches) + # second batch + self.labels["DefaultTask"] = torch.tensor([1, 1]) + self.predictions["DefaultTask"] = torch.tensor([0.4, 0.9]) + self.weights["DefaultTask"] = torch.tensor([0.8, 0.75]) + + auc.update(**self.batches) + multiple_batch = self.auc.compute()["auc-DefaultTask|window_auc"] + torch.allclose(expected_auc, multiple_batch) + + def test_window_size_auc(self) -> None: + # for determinisitc batches + torch.manual_seed(0) + + auc = AUCMetric( + world_size=1, + my_rank=0, + batch_size=5, + window_size=100, + tasks=[DefaultTaskInfo], + ) + + # init states, so we expect 3 (state tensors) * 4 bytes (float) + self.assertEqual(sum(auc.get_memory_usage().values()), 12) + + # bs = 5 + self.labels["DefaultTask"] = torch.rand(5) + self.predictions["DefaultTask"] = torch.rand(5) + self.weights["DefaultTask"] = torch.rand(5) + + for _ in range(1000): + auc.update(**self.batches) + + # check memory, window size is 100, so we have upperbound of memory to expect + # so with a 100 window size / tensors of size 5 = 20 tensors (per state) * 3 states * 20 bytes per tensor of size 5 = 1200 bytes + self.assertEqual(sum(auc.get_memory_usage().values()), 1200) + # with bs 5, we expect 20 tensors per state, so 60 tensors + self.assertEqual(len(auc.get_memory_usage().values()), 60) + + torch.allclose( + auc.compute()["auc-DefaultTask|window_auc"], + torch.tensor([0.4859], dtype=torch.float), + ) + + # test auc memory usage with window size equal to incoming batch + auc = AUCMetric( + world_size=1, + my_rank=0, + batch_size=100, + window_size=100, + tasks=[DefaultTaskInfo], + ) + + self.labels["DefaultTask"] = torch.rand(100) + 
self.predictions["DefaultTask"] = torch.rand(100) + self.weights["DefaultTask"] = torch.rand(100) + + for _ in range(10): + auc.update(**self.batches) + + # passing in batch size == window size, we expect for each state just one tensor of size 400, sum to 1200 as previous + self.assertEqual(sum(auc.get_memory_usage().values()), 1200) + self.assertEqual(len(auc.get_memory_usage().values()), 3) + + torch.allclose( + auc.compute()["auc-DefaultTask|window_auc"], + torch.tensor([0.4859], dtype=torch.float), + ) + + +def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]: + return [ + # random_inputs + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([0, 1, 0, 1, 1]), + "expected_auc": torch.tensor([0.2419]), + }, + # perfect_condition + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[1, 0, 0, 1, 1]]), + "weights": torch.tensor([[1] * 5]), + "grouping_keys": torch.tensor([1, 1, 0, 0, 1]), + "expected_auc": torch.tensor([1.0]), + }, + # inverse_prediction + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0, 1, 1, 0, 0]]), + "weights": torch.tensor([[1] * 5]), + "grouping_keys": torch.tensor([0, 1, 0, 1, 1]), + "expected_auc": torch.tensor([0.0]), + }, + # all_scores_the_same + { + "labels": torch.tensor([[1, 0, 1, 0, 1, 0]]), + "predictions": torch.tensor([[0.5] * 6]), + "weights": torch.tensor([[1] * 6]), + "grouping_keys": torch.tensor([1, 1, 1, 0, 0, 0]), + "expected_auc": torch.tensor([0.5]), + }, + # one_class_in_input + { + "labels": torch.tensor([[1, 1, 1, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[1] * 5]), + "grouping_keys": torch.tensor([1, 0, 0, 1, 0]), + "expected_auc": torch.tensor([0.5]), + }, + # one_group + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([1, 1, 1, 1, 1]), + "expected_auc": torch.tensor([0.4464]), + }, + # two tasks + { + "labels": torch.tensor([[1, 0, 0, 1, 0], [1, 1, 1, 1, 0]]), + "predictions": torch.tensor( + [ + [0.2281, 0.1051, 0.4885, 0.7740, 0.3097], + [0.4658, 0.3445, 0.6048, 0.6587, 0.5088], + ] + ), + "weights": torch.tensor( + [ + [0.6334, 0.6937, 0.6631, 0.5078, 0.3570], + [0.2637, 0.2479, 0.2697, 0.6500, 0.7583], + ] + ), + "grouping_keys": torch.tensor([0, 1, 0, 0, 1]), + "expected_auc": torch.tensor([0.4725, 0.25]), + }, + ] + + +class GroupedAUCValueTest(unittest.TestCase): + r"""This set of tests verify the computation logic of AUC in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. 
+ """ + + @no_grad() + def _test_grouped_auc_helper( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + expected_auc: torch.Tensor, + grouping_keys: Optional[torch.Tensor] = None, + ) -> None: + num_task = labels.shape[0] + batch_size = labels.shape[0] + task_list = [] + inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = { + "predictions": {}, + "labels": {}, + "weights": {}, + } + if grouping_keys is not None: + inputs["required_inputs"] = {"grouping_keys": grouping_keys} + for i in range(num_task): + task_info = RecTaskInfo( + name=f"Task:{i}", + label_name="label", + prediction_name="prediction", + weight_name="weight", + ) + task_list.append(task_info) + # pyre-ignore + inputs["predictions"][task_info.name] = predictions[i] + # pyre-ignore + inputs["labels"][task_info.name] = labels[i] + # pyre-ignore + inputs["weights"][task_info.name] = weights[i] + + auc = AUCMetric( + world_size=1, + my_rank=0, + batch_size=batch_size, + tasks=task_list, + # pyre-ignore + grouped_auc=True, + ) + auc.update(**inputs) + actual_auc = auc.compute() + + for task_id, task in enumerate(task_list): + cur_actual_auc = actual_auc[f"auc-{task.name}|window_grouped_auc"] + cur_expected_auc = expected_auc[task_id].unsqueeze(dim=0) + + torch.testing.assert_close( + cur_actual_auc, + cur_expected_auc, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + msg=f"Actual: {cur_actual_auc}, Expected: {cur_expected_auc}", + ) + + def test_grouped_auc(self) -> None: + test_data = generate_model_outputs_cases() + for inputs in test_data: + try: + self._test_grouped_auc_helper(**inputs) + except AssertionError: + print("Assertion error caught with data set ", inputs) + raise + + def test_misconfigured_grouped_auc(self) -> None: + with self.assertRaises(RecMetricException): + self._test_grouped_auc_helper( + **{ + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + # no provided grouping_keys + "expected_auc": torch.tensor([0.2419]), + }, + ) + + def test_required_input_for_grouped_auc(self) -> None: + auc = AUCMetric( + world_size=1, + my_rank=0, + batch_size=1, + tasks=[ + RecTaskInfo( + name="Task:0", + label_name="label", + prediction_name="prediction", + weight_name="weight", + ) + ], + # pyre-ignore + grouped_auc=True, + ) + + self.assertIn("grouping_keys", auc.get_required_inputs()) diff --git a/torchrec/metrics/tests/test_auprc.py b/torchrec/metrics/tests/test_auprc.py new file mode 100644 index 000000000..e7172639a --- /dev/null +++ b/torchrec/metrics/tests/test_auprc.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
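Before the AUPRC tests below, a quick cross-check of the single-group AUC expectation above: a weighted pairwise formulation (the probability that a randomly chosen positive outscores a randomly chosen negative, with ties counted as half) reproduces the 0.4464 value from the one_group case. This is a hedged sketch for verification only; AUCMetric itself integrates a weighted ROC curve rather than enumerating pairs.

import torch

# Weighted pairwise AUC, reproducing the "one_group" case above (expected 0.4464).
labels = torch.tensor([1.0, 0.0, 0.0, 1.0, 1.0])
predictions = torch.tensor([0.2, 0.6, 0.8, 0.4, 0.9])
weights = torch.tensor([0.13, 0.2, 0.5, 0.8, 0.75])

pos, neg = labels == 1, labels == 0
diff = predictions[pos].unsqueeze(1) - predictions[neg].unsqueeze(0)
wins = (diff > 0).float() + 0.5 * (diff == 0).float()  # ties count as half
pair_w = weights[pos].unsqueeze(1) * weights[neg].unsqueeze(0)
print((wins * pair_w).sum() / pair_w.sum())  # tensor(0.4464)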
+ +# pyre-strict + +import unittest +from typing import Dict, Iterable, List, Optional, Type, Union + +import torch +from torch import no_grad + +from torchrec.metrics.auprc import _compute_auprc_helper, AUPRCMetric +from torchrec.metrics.metrics_config import DefaultTaskInfo +from torchrec.metrics.rec_metric import ( + RecComputeMode, + RecMetric, + RecMetricException, + RecTaskInfo, +) +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_gpu_sync_test_launcher, + rec_metric_value_test_launcher, + sync_test_helper, + TestMetric, +) + + +def compute_auprc( + predictions: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor +) -> torch.Tensor: + return _compute_auprc_helper(predictions, labels, weights) + + +class TestAUPRCMetric(TestMetric): + def __init__( + self, + world_size: int, + rec_tasks: List[RecTaskInfo], + ) -> None: + super().__init__( + world_size, + rec_tasks, + compute_lifetime_metric=False, + local_compute_lifetime_metric=False, + ) + + @staticmethod + def _aggregate( + states: Dict[str, torch.Tensor], new_states: Dict[str, torch.Tensor] + ) -> None: + for k, v in new_states.items(): + if k not in states: + states[k] = v.float().detach().clone() + else: + states[k] = torch.cat([states[k], v.float()]) + + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + return { + "predictions": predictions, + "weights": weights, + "labels": labels, + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + return compute_auprc(states["predictions"], states["labels"], states["weights"]) + + +WORLD_SIZE = 4 + + +class AUPRCMetricTest(unittest.TestCase): + clazz: Type[RecMetric] = AUPRCMetric + task_name: str = "auprc" + + def test_unfused_auprc(self) -> None: + rec_metric_value_test_launcher( + target_clazz=AUPRCMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestAUPRCMetric, + metric_name=AUPRCMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_fused_auprc(self) -> None: + rec_metric_value_test_launcher( + target_clazz=AUPRCMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestAUPRCMetric, + metric_name=AUPRCMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + +class AUPRCMetricValueTest(unittest.TestCase): + r"""This set of tests verify the computation logic of AUPRC in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. 
+ """ + + def setUp(self) -> None: + self.predictions = {"DefaultTask": None} + self.weights = {"DefaultTask": None} + self.labels = {"DefaultTask": None} + self.batches = { + "predictions": self.predictions, + "weights": self.weights, + "labels": self.labels, + } + self.auprc = AUPRCMetric( + world_size=1, + my_rank=0, + batch_size=100, + tasks=[DefaultTaskInfo], + ) + + def test_calc_auprc_perfect(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(10000)] * 2] + ) + self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000]) + self.weights["DefaultTask"] = torch.Tensor( + [[1] * 5000 + [0] * 10000 + [1] * 5000] + ) + + expected_auprc = torch.tensor([1], dtype=torch.float) + self.auprc.update(**self.batches) + actual_auprc = self.auprc.compute()["auprc-DefaultTask|window_auprc"] + torch.allclose(expected_auprc, actual_auprc) + + def test_calc_auprc_zero(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(10000)] * 2] + ) + self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000]) + self.weights["DefaultTask"] = torch.Tensor( + [[0] * 5000 + [1] * 10000 + [0] * 5000] + ) + + expected_auprc = torch.tensor([0.3069], dtype=torch.float) + self.auprc.update(**self.batches) + actual_auprc = self.auprc.compute()["auprc-DefaultTask|window_auprc"] + torch.allclose(expected_auprc, actual_auprc) + + def test_calc_auprc_balanced(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(10000)] * 2] + ) + self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000]) + self.weights["DefaultTask"] = torch.ones([1, 20000]) + + expected_auprc = torch.tensor([0.5], dtype=torch.float) + self.auprc.update(**self.batches) + actual_auprc = self.auprc.compute()["auprc-DefaultTask|window_auprc"] + torch.allclose(expected_auprc, actual_auprc) + + +def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]: + return [ + # random_inputs + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([0, 1, 0, 1, 1]), + "expected_auprc": torch.tensor([0.5737]), + }, + # perfect_condition + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[1, 0, 0, 1, 1]]), + "weights": torch.tensor([[1] * 5]), + "grouping_keys": torch.tensor([1, 1, 0, 0, 1]), + "expected_auprc": torch.tensor([1.0]), + }, + # inverse_prediction + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0, 1, 1, 0, 0]]), + "weights": torch.tensor([[1] * 5]), + "grouping_keys": torch.tensor([0, 1, 0, 1, 1]), + "expected_auprc": torch.tensor([0.5833]), + }, + # all_scores_the_same + { + "labels": torch.tensor([[1, 0, 1, 0, 1, 0]]), + "predictions": torch.tensor([[0.5] * 6]), + "weights": torch.tensor([[1] * 6]), + "grouping_keys": torch.tensor([1, 1, 1, 0, 0, 0]), + "expected_auprc": torch.tensor([0.5]), + }, + # one_class_in_input + { + "labels": torch.tensor([[1, 1, 1, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[1] * 5]), + "grouping_keys": torch.tensor([1, 0, 0, 1, 0]), + "expected_auprc": torch.tensor([1.0]), + }, + # one_group + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([1, 1, 1, 1, 1]), 
+ "expected_auprc": torch.tensor([0.8291]), + }, + # two tasks + { + "labels": torch.tensor([[1, 0, 0, 1, 0], [1, 1, 1, 1, 0]]), + "predictions": torch.tensor( + [ + [0.2281, 0.1051, 0.4885, 0.7740, 0.3097], + [0.4658, 0.3445, 0.6048, 0.6587, 0.5088], + ] + ), + "weights": torch.tensor( + [ + [0.6334, 0.6937, 0.6631, 0.5078, 0.3570], + [0.2637, 0.2479, 0.2697, 0.6500, 0.7583], + ] + ), + "grouping_keys": torch.tensor([0, 1, 0, 0, 1]), + "expected_auprc": torch.tensor([0.3980, 0.6232]), + }, + ] + + +class GroupedAUPRCValueTest(unittest.TestCase): + r"""This set of tests verify the computation logic of AUPRC in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. + """ + + @no_grad() + def _test_grouped_auprc_helper( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + expected_auprc: torch.Tensor, + grouping_keys: Optional[torch.Tensor] = None, + ) -> None: + num_task = labels.shape[0] + batch_size = labels.shape[0] + task_list = [] + inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = { + "predictions": {}, + "labels": {}, + "weights": {}, + } + if grouping_keys is not None: + inputs["required_inputs"] = {"grouping_keys": grouping_keys} + for i in range(num_task): + task_info = RecTaskInfo( + name=f"Task:{i}", + label_name="label", + prediction_name="prediction", + weight_name="weight", + ) + task_list.append(task_info) + # pyre-ignore + inputs["predictions"][task_info.name] = predictions[i] + # pyre-ignore + inputs["labels"][task_info.name] = labels[i] + # pyre-ignore + inputs["weights"][task_info.name] = weights[i] + + auprc = AUPRCMetric( + world_size=1, + my_rank=0, + batch_size=batch_size, + tasks=task_list, + # pyre-ignore + grouped_auprc=True, + ) + auprc.update(**inputs) + actual_auprc = auprc.compute() + + for task_id, task in enumerate(task_list): + cur_actual_auprc = actual_auprc[f"auprc-{task.name}|window_grouped_auprc"] + cur_expected_auprc = expected_auprc[task_id].unsqueeze(dim=0) + + torch.testing.assert_close( + cur_actual_auprc, + cur_expected_auprc, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + msg=f"Actual: {cur_actual_auprc}, Expected: {cur_expected_auprc}", + ) + + def test_grouped_auprc(self) -> None: + test_data = generate_model_outputs_cases() + for inputs in test_data: + try: + self._test_grouped_auprc_helper(**inputs) + except AssertionError: + print("Assertion error caught with data set ", inputs) + raise + + def test_misconfigured_grouped_auprc(self) -> None: + with self.assertRaises(RecMetricException): + self._test_grouped_auprc_helper( + **{ + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + # no provided grouping_keys + "expected_auprc": torch.tensor([0.8291]), + }, + ) + + def test_required_input_for_grouped_auprc(self) -> None: + auprc = AUPRCMetric( + world_size=1, + my_rank=0, + batch_size=1, + tasks=[ + RecTaskInfo( + name="Task:0", + label_name="label", + prediction_name="prediction", + weight_name="weight", + ) + ], + # pyre-ignore + grouped_auprc=True, + ) + + self.assertIn("grouping_keys", auprc.get_required_inputs()) + + +class AUPRCGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = AUPRCMetric + task_name: str = "auprc" + + def test_sync_auprc(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=AUPRCMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + 
test_clazz=TestAUPRCMetric, + metric_name=AUPRCGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/tests/test_cali_free_ne.py b/torchrec/metrics/tests/test_cali_free_ne.py new file mode 100644 index 000000000..328dd7931 --- /dev/null +++ b/torchrec/metrics/tests/test_cali_free_ne.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import Dict, Type + +import torch +from torchrec.metrics.cali_free_ne import ( + CaliFreeNEMetric, + compute_cali_free_ne, + compute_cross_entropy, +) +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_gpu_sync_test_launcher, + rec_metric_value_test_launcher, + sync_test_helper, + TestMetric, +) + + +WORLD_SIZE = 4 + + +class TestCaliFreeNEMetric(TestMetric): + eta: float = 1e-12 + + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + cross_entropy = compute_cross_entropy( + labels, predictions, weights, TestCaliFreeNEMetric.eta + ) + cross_entropy_sum = torch.sum(cross_entropy) + weighted_num_samples = torch.sum(weights) + pos_labels = torch.sum(weights * labels) + neg_labels = torch.sum(weights * (1.0 - labels)) + weighted_sum_predictions = torch.sum(weights * predictions) + return { + "cross_entropy_sum": cross_entropy_sum, + "weighted_num_samples": weighted_num_samples, + "pos_labels": pos_labels, + "neg_labels": neg_labels, + "num_samples": torch.tensor(labels.size()).long(), + "weighted_sum_predictions": weighted_sum_predictions, + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + allow_missing_label_with_zero_weight = False + if not states["weighted_num_samples"].all(): + allow_missing_label_with_zero_weight = True + + return compute_cali_free_ne( + states["cross_entropy_sum"], + states["weighted_num_samples"], + pos_labels=states["pos_labels"], + neg_labels=states["neg_labels"], + weighted_sum_predictions=states["weighted_sum_predictions"], + eta=TestCaliFreeNEMetric.eta, + allow_missing_label_with_zero_weight=allow_missing_label_with_zero_weight, + ) + + +class CaliFreeNEMetricTest(unittest.TestCase): + target_clazz: Type[RecMetric] = CaliFreeNEMetric + target_compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION + task_name: str = "cali_free_ne" + + def test_cali_free_ne_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=CaliFreeNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestCaliFreeNEMetric, + metric_name=CaliFreeNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_cali_free_ne_fused_tasks(self) -> None: + rec_metric_value_test_launcher( + target_clazz=CaliFreeNEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestCaliFreeNEMetric, + metric_name=CaliFreeNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + 
fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_cali_free_ne_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=CaliFreeNEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + test_clazz=TestCaliFreeNEMetric, + metric_name=CaliFreeNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_cali_free_ne_update_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=CaliFreeNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestCaliFreeNEMetric, + metric_name=CaliFreeNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=5, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + rec_metric_value_test_launcher( + target_clazz=CaliFreeNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestCaliFreeNEMetric, + metric_name=CaliFreeNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=100, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + batch_window_size=10, + ) + + def test_cali_free_ne_zero_weights(self) -> None: + rec_metric_value_test_launcher( + target_clazz=CaliFreeNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestCaliFreeNEMetric, + metric_name=CaliFreeNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + zero_weights=True, + ) + + +class CaliFreeNEGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = CaliFreeNEMetric + task_name: str = "cali_free_ne" + + def test_sync_cali_free_ne(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=CaliFreeNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestCaliFreeNEMetric, + metric_name=CaliFreeNEGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/tests/test_calibration.py b/torchrec/metrics/tests/test_calibration.py index b9c271f30..6a2304485 100644 --- a/torchrec/metrics/tests/test_calibration.py +++ b/torchrec/metrics/tests/test_calibration.py @@ -5,17 +5,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
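The cali-free NE tests above only accumulate raw states (cross_entropy_sum, weighted_num_samples, pos_labels, neg_labels, weighted_sum_predictions) and defer the math to compute_cali_free_ne. As a hedged reference, the generic normalized-entropy construction that those states support is sketched below; the calibration-free adjustment itself (which also consumes weighted_sum_predictions) is intentionally left to the library function, so treat this only as background for the state names.

import torch

# Generic NE sketch built from the accumulated states: weighted cross entropy
# of the model, normalized by the cross entropy of a constant predictor at the
# positive rate. The log base cancels in the ratio.
eta = 1e-12
labels = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0])
predictions = torch.tensor([0.8, 0.3, 0.6, 0.7, 0.2])
weights = torch.ones(5)

p = predictions.clamp(eta, 1.0 - eta)
ce = -weights * (labels * torch.log(p) + (1.0 - labels) * torch.log(1.0 - p))
cross_entropy_sum, weighted_num = ce.sum(), weights.sum()
pos, neg = (weights * labels).sum(), (weights * (1.0 - labels)).sum()

ctr = pos / (pos + neg)  # base rate of the positive class
baseline_ce = -(ctr * torch.log(ctr) + (1.0 - ctr) * torch.log(1.0 - ctr))
ne = (cross_entropy_sum / weighted_num) / baseline_ce
print(ne)  # < 1.0 means better than predicting the base rate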
-import os +# pyre-strict + import unittest -from typing import Dict, List, Type +from typing import Dict, Type import torch -import torch.distributed as dist from torchrec.metrics.calibration import CalibrationMetric from torchrec.metrics.rec_metric import RecComputeMode, RecMetric from torchrec.metrics.test_utils import ( - rec_metric_value_test_helper, + metric_test_helper, + rec_metric_gpu_sync_test_launcher, rec_metric_value_test_launcher, + sync_test_helper, TestMetric, ) @@ -50,79 +52,65 @@ class CalibrationMetricTest(unittest.TestCase): clazz: Type[RecMetric] = CalibrationMetric task_name: str = "calibration" - @staticmethod - def _test_calibration( - target_clazz: Type[RecMetric], - target_compute_mode: RecComputeMode, - task_names: List[str], - fused_update_limit: int = 0, - compute_on_all_ranks: bool = False, - should_validate_update: bool = False, - ) -> None: - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - dist.init_process_group( - backend="gloo", - world_size=world_size, - rank=rank, - ) - - calibration_metrics, test_metrics = rec_metric_value_test_helper( - target_clazz=target_clazz, - target_compute_mode=target_compute_mode, + def test_calibration_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=CalibrationMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, test_clazz=TestCalibrationMetric, - fused_update_limit=fused_update_limit, + metric_name=CalibrationMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, compute_on_all_ranks=False, - should_validate_update=should_validate_update, - world_size=world_size, - my_rank=rank, - task_names=task_names, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, ) - if rank == 0: - for name in task_names: - assert torch.allclose( - calibration_metrics[f"calibration-{name}|lifetime_calibration"], - test_metrics[0][name], - ) - assert torch.allclose( - calibration_metrics[f"calibration-{name}|window_calibration"], - test_metrics[1][name], - ) - assert torch.allclose( - calibration_metrics[ - f"calibration-{name}|local_lifetime_calibration" - ], - test_metrics[2][name], - ) - assert torch.allclose( - calibration_metrics[f"calibration-{name}|local_window_calibration"], - test_metrics[3][name], - ) - dist.destroy_process_group() - - def test_unfused_calibration(self) -> None: + def test_calibration_fused_tasks(self) -> None: rec_metric_value_test_launcher( target_clazz=CalibrationMetric, - target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, test_clazz=TestCalibrationMetric, + metric_name=CalibrationMetricTest.task_name, task_names=["t1", "t2", "t3"], fused_update_limit=0, compute_on_all_ranks=False, should_validate_update=False, world_size=WORLD_SIZE, - entry_point=self._test_calibration, + entry_point=metric_test_helper, ) - def test_fused_calibration(self) -> None: + def test_calibration_fused_tasks_and_states(self) -> None: rec_metric_value_test_launcher( target_clazz=CalibrationMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, test_clazz=TestCalibrationMetric, + metric_name=CalibrationMetricTest.task_name, task_names=["t1", "t2", "t3"], fused_update_limit=0, compute_on_all_ranks=False, should_validate_update=False, world_size=WORLD_SIZE, - entry_point=self._test_calibration, + entry_point=metric_test_helper, + ) + + +class 
CalibrationGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = CalibrationMetric + task_name: str = "calibration" + + def test_sync_calibration(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=CalibrationMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestCalibrationMetric, + metric_name=CalibrationGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, ) diff --git a/torchrec/metrics/tests/test_calibration_with_recalibration.py b/torchrec/metrics/tests/test_calibration_with_recalibration.py new file mode 100644 index 000000000..693f0a8ca --- /dev/null +++ b/torchrec/metrics/tests/test_calibration_with_recalibration.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import Dict + +import torch +from torchrec.metrics.calibration_with_recalibration import ( + RecalibratedCalibrationMetric, +) +from torchrec.metrics.metrics_config import DefaultTaskInfo + + +WORLD_SIZE = 4 +BATCH_SIZE = 10 + + +def generate_model_output() -> Dict[str, torch._tensor.Tensor]: + return { + "predictions": torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]), + "labels": torch.tensor([[1.0, 0.0, 1.0, 1.0, 0.0]]), + "weights": torch.tensor([[1.0, 1.0, 1.0, 0.0, 1.0]]), + "expected_recalibrated_calibration": torch.tensor([0.0837]), + } + + +class RecalibratedCalibrationMetricMetricTest(unittest.TestCase): + def setUp(self) -> None: + self.calibration_with_recalibration = RecalibratedCalibrationMetric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + # pyre-ignore[6] + recalibration_coefficient=0.1, + ) + + def test_calibration_with_recalibration(self) -> None: + model_output = generate_model_output() + self.calibration_with_recalibration.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + ) + metric = self.calibration_with_recalibration.compute() + actual_metric = metric[ + f"recalibrated_calibration-{DefaultTaskInfo.name}|lifetime_calibration" + ] + expected_metric = model_output["expected_recalibrated_calibration"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) diff --git a/torchrec/metrics/tests/test_ctr.py b/torchrec/metrics/tests/test_ctr.py index f14bfaa90..efd45752c 100644 --- a/torchrec/metrics/tests/test_ctr.py +++ b/torchrec/metrics/tests/test_ctr.py @@ -5,17 +5,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
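The expected value 0.0837 in the recalibration test above is consistent with the following reading: rescale each prediction for a downsampling coefficient c via p' = c*p / (c*p + (1 - p)), then report weighted calibration as sum(w * p') / sum(w * label). This is a hedged sketch that makes the number reproducible, not a claim about RecalibratedCalibrationMetric's exact code path.

import torch

# Reproduces the expected_recalibrated_calibration of 0.0837 above, assuming
# the downsampling-style recalibration p' = c*p / (c*p + (1 - p)).
c = 0.1
predictions = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
labels = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0])
weights = torch.tensor([1.0, 1.0, 1.0, 0.0, 1.0])

recalibrated = c * predictions / (c * predictions + (1.0 - predictions))
calibration = (weights * recalibrated).sum() / (weights * labels).sum()
print(calibration)  # tensor(0.0837)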
-import os +# pyre-strict + import unittest -from typing import Dict, List, Type +from typing import Dict, Type import torch -import torch.distributed as dist from torchrec.metrics.ctr import CTRMetric from torchrec.metrics.rec_metric import RecComputeMode, RecMetric from torchrec.metrics.test_utils import ( - rec_metric_value_test_helper, + metric_test_helper, + rec_metric_gpu_sync_test_launcher, rec_metric_value_test_launcher, + sync_test_helper, TestMetric, ) @@ -44,73 +46,65 @@ class CTRMetricTest(unittest.TestCase): clazz: Type[RecMetric] = CTRMetric task_name: str = "ctr" - @staticmethod - def _test_ctr( - target_clazz: Type[RecMetric], - target_compute_mode: RecComputeMode, - task_names: List[str], - fused_update_limit: int = 0, - compute_on_all_ranks: bool = False, - should_validate_update: bool = False, - ) -> None: - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - dist.init_process_group( - backend="gloo", - world_size=world_size, - rank=rank, - ) - - ctr_metrics, test_metrics = rec_metric_value_test_helper( - target_clazz=target_clazz, - target_compute_mode=target_compute_mode, + def test_ctr_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=CTRMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, test_clazz=TestCTRMetric, - fused_update_limit=fused_update_limit, + metric_name=CTRMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, compute_on_all_ranks=False, - should_validate_update=should_validate_update, - world_size=world_size, - my_rank=rank, - task_names=task_names, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, ) - if rank == 0: - for name in task_names: - assert torch.allclose( - ctr_metrics[f"ctr-{name}|lifetime_ctr"], test_metrics[0][name] - ) - assert torch.allclose( - ctr_metrics[f"ctr-{name}|window_ctr"], test_metrics[1][name] - ) - assert torch.allclose( - ctr_metrics[f"ctr-{name}|local_lifetime_ctr"], test_metrics[2][name] - ) - assert torch.allclose( - ctr_metrics[f"ctr-{name}|local_window_ctr"], test_metrics[3][name] - ) - dist.destroy_process_group() - - def test_unfused_ctr(self) -> None: + def test_ctr_fused_tasks(self) -> None: rec_metric_value_test_launcher( target_clazz=CTRMetric, - target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, test_clazz=TestCTRMetric, + metric_name=CTRMetricTest.task_name, task_names=["t1", "t2", "t3"], fused_update_limit=0, compute_on_all_ranks=False, should_validate_update=False, world_size=WORLD_SIZE, - entry_point=self._test_ctr, + entry_point=metric_test_helper, ) - def test_fused_ctr(self) -> None: + def test_ctr_fused_tasks_and_states(self) -> None: rec_metric_value_test_launcher( target_clazz=CTRMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, test_clazz=TestCTRMetric, + metric_name=CTRMetricTest.task_name, task_names=["t1", "t2", "t3"], fused_update_limit=0, compute_on_all_ranks=False, should_validate_update=False, world_size=WORLD_SIZE, - entry_point=self._test_ctr, + entry_point=metric_test_helper, + ) + + +class CTRGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = CTRMetric + task_name: str = "ctr" + + def test_sync_ctr(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=CTRMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestCTRMetric, + 
metric_name=CTRGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, ) diff --git a/torchrec/metrics/tests/test_gauc.py b/torchrec/metrics/tests/test_gauc.py new file mode 100644 index 000000000..513988cff --- /dev/null +++ b/torchrec/metrics/tests/test_gauc.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +from typing import Dict + +import torch +from torchrec.metrics.gauc import compute_gauc_3d, compute_window_auc, GAUCMetric +from torchrec.metrics.metrics_config import DefaultTaskInfo +from torchrec.metrics.test_utils import TestMetric + + +class TestGAUCMetric(TestMetric): + + @staticmethod + def _get_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + gauc_res = compute_gauc_3d(predictions, labels, weights) + return { + "auc_sum": gauc_res["auc_sum"], + "num_samples": gauc_res["num_samples"], + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + return compute_window_auc( + states["auc_sum"], + states["num_samples"], + ) + + +class GAUCMetricValueTest(unittest.TestCase): + def setUp(self) -> None: + self.predictions = {"DefaultTask": None} + self.labels = {"DefaultTask": None} + self.weights = {"DefaultTask": None} + self.num_candidates = None + self.batches = { + "predictions": self.predictions, + "labels": self.labels, + "num_candidates": self.num_candidates, + "weights": self.weights, + } + self.gauc = GAUCMetric( + world_size=1, + my_rank=0, + batch_size=100, + tasks=[DefaultTaskInfo], + ) + + def test_calc_gauc_simple(self) -> None: + self.predictions["DefaultTask"] = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.5]]) + self.labels["DefaultTask"] = torch.tensor([[1, 0, 1, 1, 0]]) + self.weights["DefaultTask"] = torch.tensor([[1, 1, 1, 1, 1]]) + self.num_candidates = torch.tensor([3, 2]) + self.batches = { + "predictions": self.predictions, + "labels": self.labels, + "num_candidates": self.num_candidates, + "weights": self.weights, + } + + expected_gauc = torch.tensor([0.75], dtype=torch.double) + expected_num_samples = torch.tensor([2], dtype=torch.double) + self.gauc.update(**self.batches) + gauc_res = self.gauc.compute() + actual_gauc, num_effective_samples = ( + gauc_res["gauc-DefaultTask|window_gauc"], + gauc_res["gauc-DefaultTask|window_gauc_num_samples"], + ) + if not torch.allclose(expected_num_samples, num_effective_samples): + raise ValueError( + "actual num sample {} is not equal to expected num sample {}".format( + num_effective_samples, expected_num_samples + ) + ) + if not torch.allclose(expected_gauc, actual_gauc): + raise ValueError( + "actual auc {} is not equal to expected auc {}".format( + actual_gauc, expected_gauc + ) + ) + + def test_calc_gauc_hard(self) -> None: + self.predictions["DefaultTask"] = torch.tensor( + [[0.3, 0.9, 0.1, 0.8, 0.2, 0.8, 0.7, 0.6, 0.5, 0.5]] + ) + self.labels["DefaultTask"] = torch.tensor([[1, 1, 1, 0, 0, 1, 0, 1, 1, 0]]) + self.weights["DefaultTask"] = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) + self.num_candidates = torch.tensor([2, 3, 3, 2]) + self.batches = { + "predictions": self.predictions, + "labels": self.labels, + 
"num_candidates": self.num_candidates, + "weights": self.weights, + } + + expected_gauc = torch.tensor([0.25], dtype=torch.double) + expected_num_samples = torch.tensor([2], dtype=torch.double) + self.gauc.update(**self.batches) + gauc_res = self.gauc.compute() + actual_gauc, num_effective_samples = ( + gauc_res["gauc-DefaultTask|window_gauc"], + gauc_res["gauc-DefaultTask|window_gauc_num_samples"], + ) + if not torch.allclose(expected_num_samples, num_effective_samples): + raise ValueError( + "actual num sample {} is not equal to expected num sample {}".format( + num_effective_samples, expected_num_samples + ) + ) + if not torch.allclose(expected_gauc, actual_gauc): + raise ValueError( + "actual auc {} is not equal to expected auc {}".format( + actual_gauc, expected_gauc + ) + ) + + def test_calc_gauc_all_0_labels(self) -> None: + self.predictions["DefaultTask"] = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.5]]) + self.labels["DefaultTask"] = torch.tensor([[0, 0, 0, 0, 0]]) + self.weights["DefaultTask"] = torch.tensor([[1, 1, 1, 1, 1]]) + self.num_candidates = torch.tensor([3, 2]) + self.batches = { + "predictions": self.predictions, + "labels": self.labels, + "num_candidates": self.num_candidates, + "weights": None, + } + + expected_gauc = torch.tensor([0.5], dtype=torch.double) + expected_num_samples = torch.tensor([0], dtype=torch.double) + self.gauc.update(**self.batches) + gauc_res = self.gauc.compute() + actual_gauc, num_effective_samples = ( + gauc_res["gauc-DefaultTask|window_gauc"], + gauc_res["gauc-DefaultTask|window_gauc_num_samples"], + ) + if not torch.allclose(expected_num_samples, num_effective_samples): + raise ValueError( + "actual num sample {} is not equal to expected num sample {}".format( + num_effective_samples, expected_num_samples + ) + ) + if not torch.allclose(expected_gauc, actual_gauc): + raise ValueError( + "actual auc {} is not equal to expected auc {}".format( + actual_gauc, expected_gauc + ) + ) + + def test_calc_gauc_all_1_labels(self) -> None: + self.predictions["DefaultTask"] = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.5]]) + self.labels["DefaultTask"] = torch.tensor([[1, 1, 1, 1, 1]]) + self.weights["DefaultTask"] = torch.tensor([[1, 1, 1, 1, 1]]) + self.num_candidates = torch.tensor([3, 2]) + self.batches = { + "predictions": self.predictions, + "labels": self.labels, + "num_candidates": self.num_candidates, + "weights": None, + } + + expected_gauc = torch.tensor([0.5], dtype=torch.double) + expected_num_samples = torch.tensor([0], dtype=torch.double) + self.gauc.update(**self.batches) + gauc_res = self.gauc.compute() + actual_gauc, num_effective_samples = ( + gauc_res["gauc-DefaultTask|window_gauc"], + gauc_res["gauc-DefaultTask|window_gauc_num_samples"], + ) + if not torch.allclose(expected_num_samples, num_effective_samples): + raise ValueError( + "actual num sample {} is not equal to expected num sample {}".format( + num_effective_samples, expected_num_samples + ) + ) + if not torch.allclose(expected_gauc, actual_gauc): + raise ValueError( + "actual auc {} is not equal to expected auc {}".format( + actual_gauc, expected_gauc + ) + ) + + def test_calc_gauc_identical_predictions(self) -> None: + self.predictions["DefaultTask"] = torch.tensor([[0.8, 0.8, 0.8, 0.8, 0.8]]) + self.labels["DefaultTask"] = torch.tensor([[1, 1, 0, 1, 0]]) + self.weights["DefaultTask"] = torch.tensor([[1, 1, 1, 1, 1]]) + self.num_candidates = torch.tensor([3, 2]) + self.weights = None + self.batches = { + "predictions": self.predictions, + "labels": self.labels, + "num_candidates": 
self.num_candidates, + "weights": None, + } + + expected_gauc = torch.tensor([0.5], dtype=torch.double) + expected_num_samples = torch.tensor([0], dtype=torch.double) + self.gauc.update(**self.batches) + gauc_res = self.gauc.compute() + actual_gauc, num_effective_samples = ( + gauc_res["gauc-DefaultTask|window_gauc"], + gauc_res["gauc-DefaultTask|window_gauc_num_samples"], + ) + if not torch.allclose(expected_num_samples, num_effective_samples): + raise ValueError( + "actual num sample {} is not equal to expected num sample {}".format( + num_effective_samples, expected_num_samples + ) + ) + if not torch.allclose(expected_gauc, actual_gauc): + raise ValueError( + "actual auc {} is not equal to expected auc {}".format( + actual_gauc, expected_gauc + ) + ) + + def test_calc_gauc_weighted(self) -> None: + self.predictions["DefaultTask"] = torch.tensor( + [[0.3, 0.9, 0.1, 0.8, 0.2, 0.8, 0.7, 0.6, 0.5, 0.5]] + ) + self.labels["DefaultTask"] = torch.tensor([[1, 1, 1, 0, 0, 1, 0, 1, 1, 0]]) + self.weights["DefaultTask"] = torch.tensor([[1, 1, 1, 0, 1, 1, 1, 0, 1, 1]]) + self.num_candidates = torch.tensor([2, 3, 3, 2]) + self.batches = { + "predictions": self.predictions, + "labels": self.labels, + "num_candidates": self.num_candidates, + "weights": self.weights, + } + + expected_gauc = torch.tensor([0.5], dtype=torch.double) + expected_num_samples = torch.tensor([2], dtype=torch.double) + self.gauc.update(**self.batches) + gauc_res = self.gauc.compute() + actual_gauc, num_effective_samples = ( + gauc_res["gauc-DefaultTask|window_gauc"], + gauc_res["gauc-DefaultTask|window_gauc_num_samples"], + ) + if not torch.allclose(expected_num_samples, num_effective_samples): + raise ValueError( + "actual num sample {} is not equal to expected num sample {}".format( + num_effective_samples, expected_num_samples + ) + ) + if not torch.allclose(expected_gauc, actual_gauc): + raise ValueError( + "actual auc {} is not equal to expected auc {}".format( + actual_gauc, expected_gauc + ) + ) diff --git a/torchrec/metrics/tests/test_gpu.py b/torchrec/metrics/tests/test_gpu.py index e970c13a3..4fa7a08e3 100644 --- a/torchrec/metrics/tests/test_gpu.py +++ b/torchrec/metrics/tests/test_gpu.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import copy import unittest @@ -26,7 +28,7 @@ class TestGPU(unittest.TestCase): @unittest.skipIf(_CUDA_UNAVAILABLE, "Test needs to run on GPU") def test_auc_reset(self) -> None: - batch_size = 128 + batch_size = 64 auc = AUCMetric( world_size=1, my_rank=0, diff --git a/torchrec/metrics/tests/test_hindsight_target_pr.py b/torchrec/metrics/tests/test_hindsight_target_pr.py new file mode 100644 index 000000000..5cc9e406d --- /dev/null +++ b/torchrec/metrics/tests/test_hindsight_target_pr.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
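+# NOTE: The reference TestMetric classes below sweep THRESHOLD_GRANULARITY thresholds over [0, 1], +# accumulate weighted true-positive / false-positive (and false-negative for recall) counts per +# threshold, and then evaluate precision or recall at the index returned by compute_threshold_idx; +# the 0.5 passed to compute_threshold_idx is assumed here to be the hindsight precision target.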
+ +# pyre-strict + +import unittest +from typing import Dict, Type + +import torch +from torchrec.metrics.hindsight_target_pr import ( + compute_precision, + compute_recall, + compute_threshold_idx, + HindsightTargetPRMetric, +) +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_value_test_launcher, + TestMetric, +) + + +WORLD_SIZE = 4 +THRESHOLD_GRANULARITY = 1000 + + +class TestHindsightTargetPRMetric(TestMetric): + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + predictions = predictions.double() + tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double) + fp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double) + thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY) + for i, threshold in enumerate(thresholds): + tp_sum[i] = torch.sum(weights * ((predictions >= threshold) * labels), -1) + fp_sum[i] = torch.sum( + weights * ((predictions >= threshold) * (1 - labels)), -1 + ) + return { + "true_pos_sum": tp_sum, + "false_pos_sum": fp_sum, + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + threshold_idx = compute_threshold_idx( + states["true_pos_sum"], states["false_pos_sum"], 0.5 + ) + return torch.Tensor(threshold_idx) + + +class TestHindsightTargetPrecisionMetric(TestMetric): + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + predictions = predictions.double() + tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double) + fp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double) + thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY) + for i, threshold in enumerate(thresholds): + tp_sum[i] = torch.sum(weights * ((predictions >= threshold) * labels), -1) + fp_sum[i] = torch.sum( + weights * ((predictions >= threshold) * (1 - labels)), -1 + ) + return { + "true_pos_sum": tp_sum, + "false_pos_sum": fp_sum, + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + threshold_idx = compute_threshold_idx( + states["true_pos_sum"], states["false_pos_sum"], 0.5 + ) + return compute_precision( + states["true_pos_sum"][threshold_idx], + states["false_pos_sum"][threshold_idx], + ) + + +class TestHindsightTargetRecallMetric(TestMetric): + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + predictions = predictions.double() + tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double) + fp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double) + fn_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double) + thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY) + for i, threshold in enumerate(thresholds): + tp_sum[i] = torch.sum(weights * ((predictions >= threshold) * labels), -1) + fp_sum[i] = torch.sum( + weights * ((predictions >= threshold) * (1 - labels)), -1 + ) + fn_sum[i] = torch.sum(weights * ((predictions <= threshold) * labels), -1) + return { + "true_pos_sum": tp_sum, + "false_pos_sum": fp_sum, + "false_neg_sum": fn_sum, + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + threshold_idx = compute_threshold_idx( + states["true_pos_sum"], states["false_pos_sum"], 0.5 + ) + return compute_recall( + states["true_pos_sum"][threshold_idx], + 
states["false_neg_sum"][threshold_idx], + ) + + +# Fused tests are not supported for this metric. +class TestHindsightTargetPRMetricTest(unittest.TestCase): + target_clazz: Type[RecMetric] = HindsightTargetPRMetric + pr_task_name: str = "hindsight_target_pr" + precision_task_name: str = "hindsight_target_precision" + recall_task_name: str = "hindsight_target_recall" + + def test_hindsight_target_precision_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=HindsightTargetPRMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestHindsightTargetPrecisionMetric, + metric_name=TestHindsightTargetPRMetricTest.precision_task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_hindsight_target_recall_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=HindsightTargetPRMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestHindsightTargetRecallMetric, + metric_name=TestHindsightTargetPRMetricTest.recall_task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) diff --git a/torchrec/metrics/tests/test_mae.py b/torchrec/metrics/tests/test_mae.py new file mode 100644 index 000000000..7f7737e45 --- /dev/null +++ b/torchrec/metrics/tests/test_mae.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import unittest +from typing import Dict, Type + +import torch +from torchrec.metrics.mae import compute_mae, MAEMetric +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_gpu_sync_test_launcher, + rec_metric_value_test_launcher, + sync_test_helper, + TestMetric, +) + + +class TestMAEMetric(TestMetric): + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + predictions = predictions.double() + error_sum = torch.sum(weights * torch.abs(labels - predictions)) + return { + "error_sum": error_sum, + "weighted_num_samples": torch.sum(weights), + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + return compute_mae( + states["error_sum"], + states["weighted_num_samples"], + ) + + +WORLD_SIZE = 4 + + +class MAEMetricTest(unittest.TestCase): + clazz: Type[RecMetric] = MAEMetric + task_name: str = "mae" + + def test_mae_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=MAEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestMAEMetric, + metric_name="mae", + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_mae_fused_tasks(self) -> None: + rec_metric_value_test_launcher( + target_clazz=MAEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestMAEMetric, + metric_name="mae", + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_mae_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=MAEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + test_clazz=TestMAEMetric, + metric_name="mae", + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + +class MAEGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = MAEMetric + task_name: str = "mae" + + def test_sync_mae(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=MAEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestMAEMetric, + metric_name=MAEGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/tests/test_metric_module.py b/torchrec/metrics/tests/test_metric_module.py index 33826e9bb..ca8f67e13 100644 --- a/torchrec/metrics/tests/test_metric_module.py +++ b/torchrec/metrics/tests/test_metric_module.py @@ -5,13 +5,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import copy import dataclasses import logging import os import tempfile import unittest -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from unittest.mock import MagicMock, patch import torch @@ -27,6 +29,7 @@ ) from torchrec.metrics.metrics_config import ( _DEFAULT_WINDOW_SIZE, + BatchSizeStage, DefaultMetricsConfig, DefaultTaskInfo, EmptyMetricsConfig, @@ -79,10 +82,12 @@ def __init__( memory_usage_limit_mb=memory_usage_limit_mb, ) - def _update_rec_metrics(self, model_out: Dict[str, torch.Tensor]) -> None: + def _update_rec_metrics( + self, model_out: Dict[str, torch.Tensor], **kwargs: Any + ) -> None: if isinstance(model_out, MagicMock): return - labels, predictions, weights = parse_task_model_outputs( + labels, predictions, weights, _ = parse_task_model_outputs( self.rec_tasks, model_out ) self.rec_metrics.update(predictions=predictions, labels=labels, weights=weights) @@ -349,131 +354,6 @@ def test_initial_states_rank0_checkpointing(self) -> None: lc, entrypoint=self._run_trainer_initial_states_checkpointing )() - def test_empty_memory_usage(self) -> None: - mock_optimizer = MockOptimizer() - config = EmptyMetricsConfig - metric_module = generate_metric_module( - TestMetricModule, - metrics_config=config, - batch_size=128, - world_size=64, - my_rank=0, - state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer}, - device=torch.device("cpu"), - ) - self.assertEqual(metric_module.get_memory_usage(), 0) - - def test_ne_memory_usage(self) -> None: - mock_optimizer = MockOptimizer() - config = DefaultMetricsConfig - metric_module = generate_metric_module( - TestMetricModule, - metrics_config=config, - batch_size=128, - world_size=64, - my_rank=0, - state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer}, - device=torch.device("cpu"), - ) - # Default NEMetric's dtype is - # float64 (8 bytes) * 16 tensors of size 1 = 128 bytes - # Tensors in NeMetricComputation: - # 8 in _default, 8 specific attributes: 4 attributes, 4 window - self.assertEqual(metric_module.get_memory_usage(), 128) - metric_module.update(gen_test_batch(128)) - self.assertEqual(metric_module.get_memory_usage(), 160) - - def test_calibration_memory_usage(self) -> None: - mock_optimizer = MockOptimizer() - config = dataclasses.replace( - DefaultMetricsConfig, - rec_metrics={ - RecMetricEnum.CALIBRATION: RecMetricDef( - rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE - ) - }, - ) - metric_module = generate_metric_module( - TestMetricModule, - metrics_config=config, - batch_size=128, - world_size=64, - my_rank=0, - state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer}, - device=torch.device("cpu"), - ) - # Default calibration metric dtype is - # float64 (8 bytes) * 8 tensors, size 1 = 64 bytes - # Tensors in CalibrationMetricComputation: - # 4 in _default, 4 specific attributes: 2 attribute, 2 window - self.assertEqual(metric_module.get_memory_usage(), 64) - metric_module.update(gen_test_batch(128)) - self.assertEqual(metric_module.get_memory_usage(), 80) - - def test_auc_memory_usage(self) -> None: - mock_optimizer = MockOptimizer() - config = dataclasses.replace( - DefaultMetricsConfig, - rec_metrics={ - RecMetricEnum.AUC: RecMetricDef( - rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE - ) - }, - ) - metric_module = generate_metric_module( - TestMetricModule, - metrics_config=config, - batch_size=128, - world_size=64, - my_rank=0, - state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer}, - 
device=torch.device("cpu"), - ) - # 3 (tensors) * 8 (double) - self.assertEqual(metric_module.get_memory_usage(), 24) - metric_module.update(gen_test_batch(128)) - # 24 (initial states) + 3 (tensors) * 128 (batch_size) * 8 (double) - self.assertEqual(metric_module.get_memory_usage(), 3096) - - def test_check_memory_usage(self) -> None: - mock_optimizer = MockOptimizer() - config = DefaultMetricsConfig - metric_module = generate_metric_module( - TestMetricModule, - metrics_config=config, - batch_size=128, - world_size=64, - my_rank=0, - state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer}, - device=torch.device("cpu"), - ) - metric_module.update(gen_test_batch(128)) - with patch("torchrec.metrics.metric_module.logger") as logger_mock: - # Memory usage is fine. - metric_module.memory_usage_mb_avg = 160 / (10**6) - metric_module.check_memory_usage(1000) - self.assertEqual(metric_module.oom_count, 0) - self.assertEqual(logger_mock.warning.call_count, 0) - - # OOM but memory usage does not exceed avg. - metric_module.memory_usage_limit_mb = 0.000001 - metric_module.memory_usage_mb_avg = 160 / (10**6) - metric_module.check_memory_usage(1000) - self.assertEqual(metric_module.oom_count, 1) - self.assertEqual(logger_mock.warning.call_count, 1) - - # OOM and memory usage exceed avg but warmup is not over. - metric_module.memory_usage_mb_avg = 160 / (10**6) / 10 - metric_module.check_memory_usage(2) - self.assertEqual(metric_module.oom_count, 2) - self.assertEqual(logger_mock.warning.call_count, 2) - - # OOM and memory usage exceed avg and warmup is over. - metric_module.memory_usage_mb_avg = 160 / (10**6) / 1.25 - metric_module.check_memory_usage(1002) - self.assertEqual(metric_module.oom_count, 3) - self.assertEqual(logger_mock.warning.call_count, 4) - def test_should_compute(self) -> None: metric_module = generate_metric_module( TestMetricModule, @@ -513,6 +393,7 @@ def _test_adjust_compute_interval( ) mock_time.time = MagicMock(return_value=0.0) + # pyre-fixme[53]: Captured variable `batch` is not annotated. 
def _train(metric_module: RecMetricModule) -> float: for _ in range(metric_module.compute_interval_steps): metric_module.update(batch) @@ -656,3 +537,69 @@ def test_adjust_compute_interval_1_30(self) -> None: min_interval=1.0, max_interval=30.0, ) + + def test_save_and_load_state_dict(self) -> None: + # Test without batch_size_stages + metric_module = generate_metric_module( + TestMetricModule, + metrics_config=DefaultMetricsConfig, + batch_size=128, + world_size=1, + my_rank=0, + state_metrics_mapping={}, + device=torch.device("cpu"), + ) + metric_module.update(gen_test_batch(128)) + + state_dict_without_bss = metric_module.state_dict() + # Make sure state loading works and doesn't throw an error + metric_module.load_state_dict(state_dict_without_bss) + # Make sure num_batch in the throughput module is not in state_dict + self.assertFalse("throughput_metric.num_batch" in state_dict_without_bss) + + # Test with batch_size_stages + metric_module = generate_metric_module( + TestMetricModule, + metrics_config=DefaultMetricsConfig, + batch_size=128, + world_size=1, + my_rank=0, + state_metrics_mapping={}, + device=torch.device("cpu"), + batch_size_stages=[BatchSizeStage(256, 100), BatchSizeStage(512, None)], + ) + + # Update metric 100 times + for _ in range(100): + metric_module.update(gen_test_batch(128)) + + # Simulate a checkpoint save + state_dict = metric_module.state_dict() + # Make sure num_batch is updated correctly to 100 + self.assertEqual(state_dict["throughput_metric.num_batch"], 100) + + # Simulate a checkpoint load + metric_module.load_state_dict(state_dict) + # Make sure num_batch is correctly restored + throughput_metric = metric_module.throughput_metric + self.assertIsNotNone(throughput_metric) + self.assertEqual(throughput_metric._num_batch, 100) + # Make sure num_batch is correctly synchronized + self.assertEqual(throughput_metric._num_batch, 100) + + # Load the same checkpoint into a module that doesn't use BSS + + no_bss_metric_module = generate_metric_module( + TestMetricModule, + metrics_config=DefaultMetricsConfig, + batch_size=128, + world_size=1, + my_rank=0, + state_metrics_mapping={}, + device=torch.device("cpu"), + batch_size_stages=None, + ) + + no_bss_metric_module.load_state_dict(state_dict) + # Make sure num_batch wasn't created on the throughput module (and no exception was thrown above) + self.assertFalse(hasattr(no_bss_metric_module.throughput_metric, "_num_batch")) diff --git a/torchrec/metrics/tests/test_metrics_namespace.py b/torchrec/metrics/tests/test_metrics_namespace.py index b85850524..b7788a306 100644 --- a/torchrec/metrics/tests/test_metrics_namespace.py +++ b/torchrec/metrics/tests/test_metrics_namespace.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import re import unittest diff --git a/torchrec/metrics/tests/test_mse.py b/torchrec/metrics/tests/test_mse.py index 24a1e11ce..2e4dec541 100644 --- a/torchrec/metrics/tests/test_mse.py +++ b/torchrec/metrics/tests/test_mse.py @@ -5,17 +5,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import os +# pyre-strict + import unittest -from typing import Dict, List, Type +from typing import Dict, Type import torch -import torch.distributed as dist from torchrec.metrics.mse import compute_mse, compute_rmse, MSEMetric from torchrec.metrics.rec_metric import RecComputeMode, RecMetric from torchrec.metrics.test_utils import ( - rec_metric_value_test_helper, + metric_test_helper, + rec_metric_gpu_sync_test_launcher, rec_metric_value_test_launcher, + sync_test_helper, TestMetric, ) @@ -66,136 +68,109 @@ def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: class MSEMetricTest(unittest.TestCase): clazz: Type[RecMetric] = MSEMetric task_name: str = "mse" + rmse_task_name: str = "rmse" - @staticmethod - def _test_mse( - target_clazz: Type[RecMetric], - target_compute_mode: RecComputeMode, - task_names: List[str], - fused_update_limit: int = 0, - compute_on_all_ranks: bool = False, - should_validate_update: bool = False, - ) -> None: - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - dist.init_process_group( - backend="gloo", - world_size=world_size, - rank=rank, - ) - - mse_metrics, test_metrics = rec_metric_value_test_helper( - target_clazz=target_clazz, - target_compute_mode=target_compute_mode, - test_clazz=TestMSEMetric, - fused_update_limit=fused_update_limit, - compute_on_all_ranks=False, - should_validate_update=should_validate_update, - world_size=world_size, - my_rank=rank, - task_names=task_names, - ) - - if rank == 0: - for name in task_names: - assert torch.allclose( - mse_metrics[f"mse-{name}|lifetime_mse"], test_metrics[0][name] - ) - assert torch.allclose( - mse_metrics[f"mse-{name}|window_mse"], test_metrics[1][name] - ) - assert torch.allclose( - mse_metrics[f"mse-{name}|local_lifetime_mse"], test_metrics[2][name] - ) - assert torch.allclose( - mse_metrics[f"mse-{name}|local_window_mse"], test_metrics[3][name] - ) - dist.destroy_process_group() - - def test_unfused_mse(self) -> None: + def test_mse_unfused(self) -> None: rec_metric_value_test_launcher( target_clazz=MSEMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, test_clazz=TestMSEMetric, + metric_name=MSEMetricTest.task_name, task_names=["t1", "t2", "t3"], fused_update_limit=0, compute_on_all_ranks=False, should_validate_update=False, world_size=WORLD_SIZE, - entry_point=self._test_mse, + entry_point=metric_test_helper, ) - def test_fused_mse(self) -> None: + def test_mse_fused_tasks(self) -> None: rec_metric_value_test_launcher( target_clazz=MSEMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, test_clazz=TestMSEMetric, + metric_name=MSEMetricTest.task_name, task_names=["t1", "t2", "t3"], fused_update_limit=0, compute_on_all_ranks=False, should_validate_update=False, world_size=WORLD_SIZE, - entry_point=self._test_mse, + entry_point=metric_test_helper, ) - @staticmethod - def _test_rmse( - target_clazz: Type[RecMetric], - target_compute_mode: RecComputeMode, - task_names: List[str], - fused_update_limit: int = 0, - ) -> None: - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - dist.init_process_group( - backend="gloo", - world_size=world_size, - rank=rank, + def test_mse_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=MSEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + test_clazz=TestMSEMetric, + metric_name=MSEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + 
should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, ) - mse_metrics, test_metrics = rec_metric_value_test_helper( - target_clazz=target_clazz, - target_compute_mode=target_compute_mode, + def test_rmse_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=MSEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, test_clazz=TestRMSEMetric, - fused_update_limit=fused_update_limit, + metric_name=MSEMetricTest.rmse_task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, compute_on_all_ranks=False, should_validate_update=False, - world_size=world_size, - my_rank=rank, - task_names=task_names, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, ) - if rank == 0: - for name in task_names: - assert torch.allclose( - mse_metrics[f"mse-{name}|lifetime_rmse"], test_metrics[0][name] - ) - assert torch.allclose( - mse_metrics[f"mse-{name}|window_rmse"], test_metrics[1][name] - ) - - def test_unfused_rmse(self) -> None: + def test_rmse_fused_tasks(self) -> None: rec_metric_value_test_launcher( target_clazz=MSEMetric, - target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, test_clazz=TestRMSEMetric, + metric_name=MSEMetricTest.rmse_task_name, task_names=["t1", "t2", "t3"], fused_update_limit=0, compute_on_all_ranks=False, should_validate_update=False, world_size=WORLD_SIZE, - entry_point=self._test_mse, + entry_point=metric_test_helper, ) - def test_fused_rmse(self) -> None: + def test_rmse_fused_tasks_and_states(self) -> None: rec_metric_value_test_launcher( target_clazz=MSEMetric, - target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, test_clazz=TestRMSEMetric, + metric_name=MSEMetricTest.rmse_task_name, task_names=["t1", "t2", "t3"], fused_update_limit=0, compute_on_all_ranks=False, should_validate_update=False, world_size=WORLD_SIZE, - entry_point=self._test_mse, + entry_point=metric_test_helper, + ) + + +class MSEGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = MSEMetric + task_name: str = "mse" + + def test_sync_mse(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=MSEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestMSEMetric, + metric_name=MSEGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, ) diff --git a/torchrec/metrics/tests/test_multiclass_recall.py b/torchrec/metrics/tests/test_multiclass_recall.py new file mode 100644 index 000000000..d0c736b69 --- /dev/null +++ b/torchrec/metrics/tests/test_multiclass_recall.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
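+# NOTE: The reference TestMulticlassRecallMetric below reuses the same get_multiclass_recall_states +# and compute_multiclass_recall_at_k helpers as MulticlassRecallMetric (with N_CLASSES = 4), so these +# launcher-based tests primarily exercise distributed aggregation and the fused vs. unfused update +# paths rather than re-deriving the recall math independently.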
+ +# pyre-strict + +import unittest +from typing import Dict, Type + +import torch +from torchrec.metrics.multiclass_recall import ( + compute_multiclass_recall_at_k, + get_multiclass_recall_states, + MulticlassRecallMetric, +) +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_gpu_sync_test_launcher, + rec_metric_value_test_launcher, + sync_test_helper, + TestMetric, +) + +N_CLASSES = 4 +WORLD_SIZE = 4 + + +class TestMulticlassRecallMetric(TestMetric): + n_classes: int = N_CLASSES + + @staticmethod + def _get_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + states = get_multiclass_recall_states( + predictions, labels, weights, TestMulticlassRecallMetric.n_classes + ) + return states + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + return compute_multiclass_recall_at_k( + states["tp_at_k"], + states["total_weights"], + ) + + +class MulticlassRecallMetricTest(unittest.TestCase): + target_clazz: Type[RecMetric] = MulticlassRecallMetric + target_compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION + task_name: str = "multiclass_recall" + + def test_multiclass_recall_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=MulticlassRecallMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestMulticlassRecallMetric, + metric_name=MulticlassRecallMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + n_classes=N_CLASSES, + ) + + def test_multiclass_recall_fused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=MulticlassRecallMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestMulticlassRecallMetric, + metric_name=MulticlassRecallMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + n_classes=N_CLASSES, + ) + + def test_multiclass_recall_update_fused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=MulticlassRecallMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestMulticlassRecallMetric, + metric_name=MulticlassRecallMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=5, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + n_classes=N_CLASSES, + ) + + rec_metric_value_test_launcher( + target_clazz=MulticlassRecallMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestMulticlassRecallMetric, + metric_name=MulticlassRecallMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=100, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + batch_window_size=10, + n_classes=N_CLASSES, + ) + + +class MulticlassRecallGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = MulticlassRecallMetric + task_name: str = "multiclass_recall" + + def test_sync_multiclass_recall(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=MulticlassRecallMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + 
test_clazz=TestMulticlassRecallMetric, + metric_name=MulticlassRecallGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + # pyre-ignore[6] Incompatible parameter type + n_classes=N_CLASSES, + ) diff --git a/torchrec/metrics/tests/test_ndcg.py b/torchrec/metrics/tests/test_ndcg.py new file mode 100644 index 000000000..9948e9a02 --- /dev/null +++ b/torchrec/metrics/tests/test_ndcg.py @@ -0,0 +1,745 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +from dataclasses import replace +from typing import Any, Dict, List + +import torch +from torchrec.metrics.metrics_config import DefaultTaskInfo, RecComputeMode + +from torchrec.metrics.ndcg import NDCGMetric, SESSION_KEY +from torchrec.metrics.test_utils import RecTaskInfo + + +WORLD_SIZE = 4 +BATCH_SIZE = 10 + +DefaultTaskInfo0 = RecTaskInfo( + name="DefaultTask0", + label_name="label", + prediction_name="prediction", + weight_name="weight", +) + +DefaultTaskInfo1 = RecTaskInfo( + name="DefaultTask1", + label_name="label", + prediction_name="prediction", + weight_name="weight", +) + +DefaultTaskInfo2 = RecTaskInfo( + name="DefaultTask2", + label_name="label", + prediction_name="prediction", + weight_name="weight", +) + + +def get_test_case_single_session_within_batch() -> Dict[str, torch.Tensor]: + return { + "predictions": torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]), + "session_ids": torch.tensor([[1, 1, 1, 1, 1]]), + "labels": torch.tensor([[0.0, 1.0, 0.0, 0.0, 2.0]]), + "weights": torch.tensor([[1.0, 1.0, 1.0, 1.0, 2.0]]), + "expected_ndcg_exp": torch.tensor([0.1103]), + "expected_ndcg_non_exp": torch.tensor([0.1522]), + } + + +def get_test_case_multiple_sessions_within_batch() -> Dict[str, torch.Tensor]: + return { + "predictions": torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3]]), + "session_ids": torch.tensor([[1, 1, 1, 1, 1, 2, 2, 2]]), + "labels": torch.tensor([[0.0, 1.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0]]), + "weights": torch.tensor([[1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 3.0]]), + "expected_ndcg_exp": torch.tensor([0.6748]), + "expected_ndcg_non_exp": torch.tensor([0.6463]), + } + + +def get_test_case_all_labels_zero() -> Dict[str, torch.Tensor]: + return { + "predictions": torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3]]), + "session_ids": torch.tensor([[1, 1, 1, 1, 1, 2, 2, 2]]), + "labels": torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]), + "weights": torch.tensor([[1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 3.0]]), + "expected_ndcg_exp": torch.tensor([2.5]), + "expected_ndcg_non_exp": torch.tensor([2.5]), + } + + +def get_test_case_another_multiple_sessions_within_batch() -> Dict[str, torch.Tensor]: + return { + "predictions": torch.tensor([[0.1, 0.5, 0.3, 0.4, 0.2, 0.1]]), + "session_ids": torch.tensor([[1, 1, 1, 2, 2, 2]]), + "labels": torch.tensor([[1.0, 0.0, 1.0, 1.0, 0.0, 1.0]]), + "weights": torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]), + "expected_ndcg_exp": torch.tensor([(0.3066 + 0.0803) / 2]), + "expected_ndcg_non_exp": torch.tensor([(0.3066 + 0.0803) / 2]), + } + + +def get_test_case_at_k() -> Dict[str, torch.Tensor]: + return { + "predictions": torch.tensor([[0.1, 0.5, 0.3, 0.4, 0.2, 0.1]]), + "session_ids": 
torch.tensor([[1, 1, 1, 2, 2, 2]]), + "labels": torch.tensor([[1.0, 0.0, 1.0, 1.0, 0.0, 1.0]]), + "weights": torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]), + "expected_ndcg_exp": torch.tensor([(0.6131 + 0.3869) / 2]), + "expected_ndcg_non_exp": torch.tensor([(0.6131 + 0.3869) / 2]), + } + + +def get_test_case_remove_single_length_sessions() -> Dict[str, torch.Tensor]: + return { + "predictions": torch.tensor([[0.1, 0.5, 0.3, 0.4, 0.2, 0.1]]), + "session_ids": torch.tensor([[1, 1, 1, 2, 3, 4]]), + "labels": torch.tensor([[1.0, 0.0, 1.0, 1.0, 0.0, 1.0]]), + "weights": torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]), + "expected_ndcg_exp": torch.tensor([0.3066]), + "expected_ndcg_non_exp": torch.tensor([0.3066]), + } + + +def get_test_case_negative_task() -> Dict[str, torch.Tensor]: + return { + "predictions": torch.tensor([[0.9, 0.5, 0.7, 0.6, 0.8, 0.9]]), + "session_ids": torch.tensor([[1, 1, 1, 2, 2, 2]]), + "labels": torch.tensor([[0.0, 1.0, 0.0, 0.0, 1.0, 0.0]]), + "weights": torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]), + "expected_ndcg_exp": torch.tensor([(0.3066 + 0.0803) / 2]), + "expected_ndcg_non_exp": torch.tensor([(0.3066 + 0.0803) / 2]), + } + + +def get_test_case_scale_by_weights_tensor() -> Dict[str, torch.Tensor]: + """ + For this test case, + predictions * weights = [0.1, 0, 0, 0.4, 0.0, 0.0] + labels * weights = [1, 0, 0, 1, 0, 0] + So NDCG going to be perfect for both sessions. + """ + return { + "predictions": torch.tensor([[0.1, 0.5, 0.3, 0.4, 0.2, 0.1]]), + "session_ids": torch.tensor([[1, 1, 1, 2, 2, 2]]), + "labels": torch.tensor([[1.0, 0.0, 1.0, 1.0, 0.0, 1.0]]), + "weights": torch.tensor([[1.0, 0.0, 0.0, 1.0, 0.0, 0.0]]), + "expected_ndcg_exp": torch.tensor([(1.0 + 1.0) / 2]), + "expected_ndcg_non_exp": torch.tensor([(1.0 + 1.0) / 2]), + } + + +def get_test_case_multitask() -> Dict[str, torch.Tensor]: + return { + "predictions": torch.tensor( + [ + [0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3], + [0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3], + [0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3], + ] + ), + "session_ids": torch.tensor( + [ + [1, 1, 1, 1, 1, 2, 2, 2], + [1, 1, 1, 1, 1, 2, 2, 2], + [1, 1, 1, 1, 1, 2, 2, 2], + ] + ), + "labels": torch.tensor( + [ + [0.0, 1.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0], + ] + ), + "weights": torch.tensor( + [ + [1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 3.0], + [1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 3.0], + [1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 3.0], + ] + ), + "expected_ndcg_exp": torch.tensor([0.6748, 0.6748, 0.6748]), + "expected_ndcg_non_exp": torch.tensor([0.6463, 0.6463, 0.6463]), + } + + +class NDCGMetricValueTest(unittest.TestCase): + def generate_metric( + self, + world_size: int, + my_rank: int, + batch_size: int, + tasks: List[RecTaskInfo] = [DefaultTaskInfo], + exponential_gain: bool = False, + session_key: str = SESSION_KEY, + k: int = -1, + remove_single_length_sessions: bool = False, + scale_by_weights_tensor: bool = False, + report_ndcg_as_decreasing_curve: bool = True, + compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, + **kwargs: Dict[str, Any], + ) -> NDCGMetric: + return NDCGMetric( + world_size=world_size, + my_rank=my_rank, + batch_size=batch_size, + tasks=tasks, + # pyre-ignore[6] + session_key=session_key, + # pyre-ignore[6] + exponential_gain=exponential_gain, + # pyre-ignore[6] + remove_single_length_sessions=remove_single_length_sessions, + # pyre-ignore[6] + scale_by_weights_tensor=scale_by_weights_tensor, + # pyre-ignore[6] 
+ report_ndcg_as_decreasing_curve=report_ndcg_as_decreasing_curve, + # pyre-ignore[6] + k=k, + compute_mode=compute_mode, + # pyre-ignore[6] + **kwargs, + ) + + def test_single_session_non_exp(self) -> None: + """ + Test a single session in an update. + """ + model_output = get_test_case_multiple_sessions_within_batch() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + exponential_gain=False, + session_key=SESSION_KEY, + ) + + metric.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + required_inputs={SESSION_KEY: model_output["session_ids"][0]}, + ) + output = metric.compute() + actual_metric = output[f"ndcg-{DefaultTaskInfo.name}|lifetime_ndcg"] + expected_metric = model_output["expected_ndcg_non_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_single_session_exp(self) -> None: + """ + Test a single session in an update for the exponential gain metric. + """ + model_output = get_test_case_multiple_sessions_within_batch() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + exponential_gain=True, + session_key=SESSION_KEY, + ) + + metric.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + required_inputs={SESSION_KEY: model_output["session_ids"][0]}, + ) + output = metric.compute() + actual_metric = output[f"ndcg-{DefaultTaskInfo.name}|lifetime_ndcg"] + expected_metric = model_output["expected_ndcg_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_multiple_sessions_non_exp(self) -> None: + """ + Test multiple sessions in a single update.
+ """ + model_output = get_test_case_multiple_sessions_within_batch() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + exponential_gain=False, + session_key=SESSION_KEY, + ) + + metric.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + required_inputs={SESSION_KEY: model_output["session_ids"][0]}, + ) + output = metric.compute() + actual_metric = output[f"ndcg-{DefaultTaskInfo.name}|lifetime_ndcg"] + expected_metric = model_output["expected_ndcg_non_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_multiple_sessions_exp(self) -> None: + model_output = get_test_case_multiple_sessions_within_batch() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + exponential_gain=True, + session_key=SESSION_KEY, + ) + + metric.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + required_inputs={SESSION_KEY: model_output["session_ids"][0]}, + ) + output = metric.compute() + actual_metric = output[f"ndcg-{DefaultTaskInfo.name}|lifetime_ndcg"] + expected_metric = model_output["expected_ndcg_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_negative_sessions(self) -> None: + """ + Test sessions where all labels are 0. + """ + model_output = get_test_case_all_labels_zero() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + exponential_gain=False, + session_key=SESSION_KEY, + ) + + metric.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + required_inputs={SESSION_KEY: model_output["session_ids"][0]}, + ) + output = metric.compute() + actual_metric = output[f"ndcg-{DefaultTaskInfo.name}|lifetime_ndcg"] + expected_metric = model_output["expected_ndcg_non_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_negative_sessions_exp(self) -> None: + """ + Test sessions where all labels are 0, for exponential gain. 
+ """ + model_output = get_test_case_all_labels_zero() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + exponential_gain=True, + session_key=SESSION_KEY, + ) + + metric.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + required_inputs={SESSION_KEY: model_output["session_ids"][0]}, + ) + output = metric.compute() + actual_metric = output[f"ndcg-{DefaultTaskInfo.name}|lifetime_ndcg"] + expected_metric = model_output["expected_ndcg_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_another_multiple_sessions(self) -> None: + """ + Test another multiple sessions in a single update. + """ + model_output = get_test_case_another_multiple_sessions_within_batch() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + exponential_gain=False, + session_key=SESSION_KEY, + ) + + metric.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + required_inputs={SESSION_KEY: model_output["session_ids"][0]}, + ) + output = metric.compute() + actual_metric = output[f"ndcg-{DefaultTaskInfo.name}|lifetime_ndcg"] + expected_metric = model_output["expected_ndcg_non_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_another_multiple_sessions_exp(self) -> None: + """ + Test another multiple sessions in a single update, for exponential gain. + """ + model_output = get_test_case_another_multiple_sessions_within_batch() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + exponential_gain=True, + session_key=SESSION_KEY, + ) + + metric.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + required_inputs={SESSION_KEY: model_output["session_ids"][0]}, + ) + output = metric.compute() + actual_metric = output[f"ndcg-{DefaultTaskInfo.name}|lifetime_ndcg"] + expected_metric = model_output["expected_ndcg_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_at_k(self) -> None: + """ + Test NDCG @ K. 
+ """ + model_output = get_test_case_at_k() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + exponential_gain=False, + session_key=SESSION_KEY, + k=2, + ) + metric.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + required_inputs={SESSION_KEY: model_output["session_ids"][0]}, + ) + output = metric.compute() + actual_metric = output[f"ndcg-{DefaultTaskInfo.name}|lifetime_ndcg"] + expected_metric = model_output["expected_ndcg_non_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_remove_single_length_sessions(self) -> None: + """ + Test NDCG with removing single length sessions. + """ + model_output = get_test_case_remove_single_length_sessions() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + exponential_gain=False, + session_key=SESSION_KEY, + remove_single_length_sessions=True, + ) + + metric.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + required_inputs={SESSION_KEY: model_output["session_ids"][0]}, + ) + output = metric.compute() + actual_metric = output[f"ndcg-{DefaultTaskInfo.name}|lifetime_ndcg"] + expected_metric = model_output["expected_ndcg_non_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_apply_negative_task_mask(self) -> None: + """ + Test NDCG with apply negative task mask. + """ + model_output = get_test_case_negative_task() + TempTaskInfo = replace(DefaultTaskInfo, is_negative_task=True) + + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[TempTaskInfo], + exponential_gain=False, + session_key=SESSION_KEY, + ) + + metric.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + required_inputs={SESSION_KEY: model_output["session_ids"][0]}, + ) + + output = metric.compute() + actual_metric = output[f"ndcg-{DefaultTaskInfo.name}|lifetime_ndcg"] + expected_metric = model_output["expected_ndcg_non_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_case_report_as_increasing_ndcg_and_scale_by_weights_tensor(self) -> None: + """ + Test NDCG with reporting as increasing NDCG and scaling by weights tensor correctly. 
+ """ + model_output = get_test_case_scale_by_weights_tensor() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + exponential_gain=False, + session_key=SESSION_KEY, + remove_single_length_sessions=True, + scale_by_weights_tensor=True, + report_ndcg_as_decreasing_curve=False, + ) + + metric.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + required_inputs={SESSION_KEY: model_output["session_ids"][0]}, + ) + + output = metric.compute() + actual_metric = output[f"ndcg-{DefaultTaskInfo.name}|lifetime_ndcg"] + expected_metric = model_output["expected_ndcg_non_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_multitask_non_exp(self) -> None: + """ + Test NDCG with multiple tasks. + """ + model_output = get_test_case_multitask() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2], + exponential_gain=False, + session_key=SESSION_KEY, + compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + ) + + metric.update( + predictions=model_output["predictions"], + labels=model_output["labels"], + weights=model_output["weights"], + required_inputs={SESSION_KEY: model_output["session_ids"]}, + ) + output = metric.compute() + actual_metric = torch.stack( + [ + output[f"ndcg-{task.name}|lifetime_ndcg"] + for task in [DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2] + ] + ) + expected_metric = model_output["expected_ndcg_non_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_multitask_exp(self) -> None: + """ + Test NDCG with multiple tasks. + """ + model_output = get_test_case_multitask() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2], + exponential_gain=True, + session_key=SESSION_KEY, + compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + ) + + metric.update( + predictions=model_output["predictions"], + labels=model_output["labels"], + weights=model_output["weights"], + required_inputs={SESSION_KEY: model_output["session_ids"]}, + ) + output = metric.compute() + actual_metric = torch.stack( + [ + output[f"ndcg-{task.name}|lifetime_ndcg"] + for task in [DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2] + ] + ) + expected_metric = model_output["expected_ndcg_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_multitask_exp_fused_tasks_and_states(self) -> None: + """ + Test NDCG with multiple tasks. 
+ """ + model_output = get_test_case_multitask() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2], + exponential_gain=True, + session_key=SESSION_KEY, + compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + ) + + metric.update( + predictions=model_output["predictions"], + labels=model_output["labels"], + weights=model_output["weights"], + required_inputs={SESSION_KEY: model_output["session_ids"]}, + ) + output = metric.compute() + actual_metric = torch.stack( + [ + output[f"ndcg-{task.name}|lifetime_ndcg"] + for task in [DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2] + ] + ) + expected_metric = model_output["expected_ndcg_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) diff --git a/torchrec/metrics/tests/test_ne.py b/torchrec/metrics/tests/test_ne.py index 2264ea5e7..4a5a5359d 100644 --- a/torchrec/metrics/tests/test_ne.py +++ b/torchrec/metrics/tests/test_ne.py @@ -5,18 +5,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import os +# pyre-strict + import unittest from functools import partial, update_wrapper -from typing import Callable, Dict, List, Type +from typing import Callable, Dict, Type import torch -import torch.distributed as dist -from torchrec.metrics.ne import compute_cross_entropy, compute_ne, NEMetric +from torchrec.metrics.ne import ( + compute_cross_entropy, + compute_logloss, + compute_ne, + NEMetric, +) from torchrec.metrics.rec_metric import RecComputeMode, RecMetric from torchrec.metrics.test_utils import ( - rec_metric_value_test_helper, + metric_test_helper, + rec_metric_gpu_sync_test_launcher, rec_metric_value_test_launcher, + sync_test_helper, TestMetric, ) @@ -48,12 +55,46 @@ def _get_states( @staticmethod def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + allow_missing_label_with_zero_weight = False + if not states["weighted_num_samples"].all(): + allow_missing_label_with_zero_weight = True + return compute_ne( states["cross_entropy_sum"], states["weighted_num_samples"], pos_labels=states["pos_labels"], neg_labels=states["neg_labels"], eta=TestNEMetric.eta, + allow_missing_label_with_zero_weight=allow_missing_label_with_zero_weight, + ) + + +class TestLoglossMetric(TestMetric): + eta: float = 1e-12 + + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + cross_entropy = compute_cross_entropy( + labels, predictions, weights, TestNEMetric.eta + ) + cross_entropy_sum = torch.sum(cross_entropy) + pos_labels = torch.sum(weights * labels, dim=-1) + neg_labels = torch.sum(weights * (1.0 - labels), dim=-1) + return { + "cross_entropy_sum": cross_entropy_sum, + "pos_labels": pos_labels, + "neg_labels": neg_labels, + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + return compute_logloss( + states["cross_entropy_sum"], + pos_labels=states["pos_labels"], + neg_labels=states["neg_labels"], + eta=TestLoglossMetric.eta, ) @@ -62,84 +103,46 @@ class NEMetricTest(unittest.TestCase): target_compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION task_name: str = "ne" - @staticmethod - def _test_ne( - target_clazz: Type[RecMetric], - target_compute_mode: 
RecComputeMode, - task_names: List[str], - fused_update_limit: int = 0, - compute_on_all_ranks: bool = False, - should_validate_update: bool = False, - batch_window_size: int = 5, - ) -> None: - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - dist.init_process_group( - backend="gloo", - world_size=world_size, - rank=rank, - ) - - ne_metrics, test_metrics = rec_metric_value_test_helper( - target_clazz=target_clazz, - target_compute_mode=target_compute_mode, - test_clazz=TestNEMetric, - fused_update_limit=fused_update_limit, - compute_on_all_ranks=False, - should_validate_update=should_validate_update, - world_size=world_size, - my_rank=rank, - task_names=task_names, - batch_window_size=batch_window_size, - ) - - if rank == 0: - for name in task_names: - assert torch.allclose( - ne_metrics[f"ne-{name}|lifetime_ne"], test_metrics[0][name] - ) - assert torch.allclose( - ne_metrics[f"ne-{name}|window_ne"], test_metrics[1][name] - ) - assert torch.allclose( - ne_metrics[f"ne-{name}|local_lifetime_ne"], test_metrics[2][name] - ) - assert torch.allclose( - ne_metrics[f"ne-{name}|local_window_ne"], test_metrics[3][name] - ) - dist.destroy_process_group() - - _test_ne_large_window_size: Callable[..., None] = partial( - # pyre-fixme[16]: `Callable` has no attribute `__func__`. - _test_ne.__func__, - batch_window_size=10, - ) - update_wrapper(_test_ne_large_window_size, _test_ne.__func__) - def test_ne_unfused(self) -> None: rec_metric_value_test_launcher( target_clazz=NEMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, test_clazz=TestNEMetric, + metric_name=NEMetricTest.task_name, task_names=["t1", "t2", "t3"], fused_update_limit=0, compute_on_all_ranks=False, should_validate_update=False, world_size=WORLD_SIZE, - entry_point=self._test_ne, + entry_point=metric_test_helper, ) - def test_ne_fused(self) -> None: + def test_ne_fused_tasks(self) -> None: rec_metric_value_test_launcher( target_clazz=NEMetric, target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, test_clazz=TestNEMetric, + metric_name=NEMetricTest.task_name, task_names=["t1", "t2", "t3"], fused_update_limit=0, compute_on_all_ranks=False, should_validate_update=False, world_size=WORLD_SIZE, - entry_point=self._test_ne, + entry_point=metric_test_helper, + ) + + def test_ne_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=NEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + test_clazz=TestNEMetric, + metric_name=NEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, ) def test_ne_update_fused(self) -> None: @@ -147,48 +150,136 @@ def test_ne_update_fused(self) -> None: target_clazz=NEMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, test_clazz=TestNEMetric, + metric_name=NEMetricTest.task_name, task_names=["t1", "t2", "t3"], fused_update_limit=5, compute_on_all_ranks=False, should_validate_update=False, world_size=WORLD_SIZE, - entry_point=self._test_ne, + entry_point=metric_test_helper, ) rec_metric_value_test_launcher( target_clazz=NEMetric, target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, test_clazz=TestNEMetric, + metric_name=NEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=100, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + 
batch_window_size=10, + ) + + def test_ne_zero_weights(self) -> None: + rec_metric_value_test_launcher( + target_clazz=NEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestNEMetric, + metric_name=NEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + zero_weights=True, + ) + + _logloss_metric_test_helper: Callable[..., None] = partial( + metric_test_helper, include_logloss=True + ) + update_wrapper(_logloss_metric_test_helper, metric_test_helper) + + def test_logloss_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=NEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + metric_name="logloss", + test_clazz=TestLoglossMetric, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=self._logloss_metric_test_helper, + ) + + def test_logloss_fused_tasks(self) -> None: + rec_metric_value_test_launcher( + target_clazz=NEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + metric_name="logloss", + test_clazz=TestLoglossMetric, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=self._logloss_metric_test_helper, + ) + + def test_logloss_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=NEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + metric_name="logloss", + test_clazz=TestLoglossMetric, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=self._logloss_metric_test_helper, + ) + + def test_logloss_update_fused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=NEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + metric_name="logloss", + test_clazz=TestLoglossMetric, + task_names=["t1", "t2", "t3"], + fused_update_limit=5, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=self._logloss_metric_test_helper, + ) + + rec_metric_value_test_launcher( + target_clazz=NEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + metric_name="logloss", + test_clazz=TestLoglossMetric, task_names=["t1", "t2", "t3"], fused_update_limit=100, compute_on_all_ranks=False, should_validate_update=False, world_size=WORLD_SIZE, - entry_point=self._test_ne_large_window_size, - ) - - # TODO(stellaya): support the usage of fused_tasks_computation and - # fused_update for the same RecMetric - # rec_metric_value_test_launcher( - # target_clazz=NEMetric, - # target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, - # test_clazz=TestNEMetric, - # task_names=["t1", "t2", "t3"], - # fused_update_limit=5, - # compute_on_all_ranks=False, - # should_validate_update=False, - # world_size=WORLD_SIZE, - # entry_point=self._test_ne, - # ) - - # rec_metric_value_test_launcher( - # target_clazz=NEMetric, - # target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, - # test_clazz=TestNEMetric, - # task_names=["t1", "t2", "t3"], - # fused_update_limit=100, - # compute_on_all_ranks=False, - # should_validate_update=False, - # world_size=WORLD_SIZE, - # entry_point=self._test_ne_large_window_size, - # ) + 
entry_point=self._logloss_metric_test_helper,
+            batch_window_size=10,
+        )
+
+
+class NEGPUSyncTest(unittest.TestCase):
+    clazz: Type[RecMetric] = NEMetric
+    task_name: str = "ne"
+
+    def test_sync_ne(self) -> None:
+        rec_metric_gpu_sync_test_launcher(
+            target_clazz=NEMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestNEMetric,
+            metric_name=NEGPUSyncTest.task_name,
+            task_names=["t1"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=2,
+            batch_size=5,
+            batch_window_size=20,
+            entry_point=sync_test_helper,
+        )
diff --git a/torchrec/metrics/tests/test_ne_positive.py b/torchrec/metrics/tests/test_ne_positive.py
new file mode 100644
index 000000000..d4487cae7
--- /dev/null
+++ b/torchrec/metrics/tests/test_ne_positive.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+from typing import Dict
+
+import torch
+from torchrec.metrics.metrics_config import DefaultTaskInfo
+from torchrec.metrics.ne_positive import NEPositiveMetric
+
+
+WORLD_SIZE = 4
+BATCH_SIZE = 10
+
+
+def generate_model_output() -> Dict[str, torch._tensor.Tensor]:
+    return {
+        "predictions": torch.tensor([[0.8, 0.2, 0.3, 0.6, 0.5]]),
+        "labels": torch.tensor([[1, 0, 0, 1, 1]]),
+        "weights": torch.tensor([[1, 2, 1, 2, 1]]),
+        "expected_ne_positive": torch.tensor([0.4054]),
+    }
+
+
+class NEPositiveValueTest(unittest.TestCase):
+    """
+    This set of tests verifies the computation logic of NE positive in several
+    corner cases for which we know the expected results. The goal is to
+    provide some confidence in the correctness of the math formula.
+    """
+
+    def setUp(self) -> None:
+        self.ne_positive = NEPositiveMetric(
+            world_size=WORLD_SIZE,
+            my_rank=0,
+            batch_size=BATCH_SIZE,
+            tasks=[DefaultTaskInfo],
+        )
+
+    def test_ne_positive(self) -> None:
+        model_output = generate_model_output()
+        self.ne_positive.update(
+            predictions={DefaultTaskInfo.name: model_output["predictions"][0]},
+            labels={DefaultTaskInfo.name: model_output["labels"][0]},
+            weights={DefaultTaskInfo.name: model_output["weights"][0]},
+        )
+        metric = self.ne_positive.compute()
+        print(metric)
+        actual_metric = metric[
+            f"ne_positive-{DefaultTaskInfo.name}|lifetime_ne_positive"
+        ]
+        expected_metric = model_output["expected_ne_positive"]
+
+        torch.testing.assert_close(
+            actual_metric,
+            expected_metric,
+            atol=1e-4,
+            rtol=1e-4,
+            check_dtype=False,
+            equal_nan=True,
+            msg=f"Actual: {actual_metric}, Expected: {expected_metric}",
+        )
diff --git a/torchrec/metrics/tests/test_ne_with_recalibration.py b/torchrec/metrics/tests/test_ne_with_recalibration.py
new file mode 100644
index 000000000..66275df78
--- /dev/null
+++ b/torchrec/metrics/tests/test_ne_with_recalibration.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
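+# Background for the test below: RecalibratedNEMetric layers a
+# recalibration_coefficient on top of normalized entropy (NE). As a rough,
+# non-authoritative reference, plain NE is the weighted binary cross-entropy of the
+# predictions divided by the cross-entropy of a constant predictor that always
+# outputs the weighted base rate; the recalibration step is assumed here to rescale
+# the predicted probabilities with the coefficient before that ratio is taken (see
+# torchrec.metrics.ne_with_recalibration for the exact transform). The helper name
+# below is illustrative only and is not used by the test.
+import torch
+
+
+def _reference_ne(
+    predictions: torch.Tensor,
+    labels: torch.Tensor,
+    weights: torch.Tensor,
+    eta: float = 1e-12,
+) -> torch.Tensor:
+    # Weighted binary cross-entropy of the model predictions.
+    ce = -weights * (
+        labels * torch.log(predictions + eta)
+        + (1.0 - labels) * torch.log(1.0 - predictions + eta)
+    )
+    # Cross-entropy of always predicting the weighted base rate; the log base
+    # cancels in the ratio, so natural log is fine for a sketch.
+    pos = torch.sum(weights * labels)
+    neg = torch.sum(weights * (1.0 - labels))
+    base_rate = pos / (pos + neg)
+    baseline = -(
+        pos * torch.log(base_rate + eta) + neg * torch.log(1.0 - base_rate + eta)
+    )
+    return torch.sum(ce) / baseline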
+ +# pyre-strict + +import unittest +from typing import Dict + +import torch +from torchrec.metrics.metrics_config import DefaultTaskInfo +from torchrec.metrics.ne_with_recalibration import RecalibratedNEMetric + + +WORLD_SIZE = 4 +BATCH_SIZE = 10 + + +def generate_model_output() -> Dict[str, torch._tensor.Tensor]: + return { + "predictions": torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]), + "labels": torch.tensor([[1.0, 0.0, 1.0, 1.0, 0.0]]), + "weights": torch.tensor([[1.0, 1.0, 1.0, 0.0, 1.0]]), + "expected_recalibrated_ne": torch.tensor([2.8214]), + } + + +class RecalibratedNEMetricMetricTest(unittest.TestCase): + def setUp(self) -> None: + self.ne_with_recalibration = RecalibratedNEMetric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + # pyre-ignore[6] + recalibration_coefficient=0.1, + ) + + def test_ne_with_recalibration(self) -> None: + model_output = generate_model_output() + self.ne_with_recalibration.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + ) + metric = self.ne_with_recalibration.compute() + actual_metric = metric[f"recalibrated_ne-{DefaultTaskInfo.name}|lifetime_ne"] + expected_metric = model_output["expected_recalibrated_ne"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) diff --git a/torchrec/metrics/tests/test_precision.py b/torchrec/metrics/tests/test_precision.py new file mode 100644 index 000000000..8a58485f6 --- /dev/null +++ b/torchrec/metrics/tests/test_precision.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
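+# Background for the value tests below: TestPrecisionMetric and the hand-computed
+# cases reduce to the standard weighted precision at a decision threshold. A
+# minimal, non-authoritative sketch (compute_precision in torchrec.metrics.precision
+# is the source of truth); the helper name is illustrative only and is not used by
+# the tests.
+import torch
+
+
+def _reference_precision(
+    predictions: torch.Tensor,
+    labels: torch.Tensor,
+    weights: torch.Tensor,
+    threshold: float = 0.5,
+) -> torch.Tensor:
+    # Weighted true and false positives among the examples predicted positive.
+    predicted_pos = (predictions >= threshold).double()
+    true_pos = torch.sum(weights * predicted_pos * labels)
+    false_pos = torch.sum(weights * predicted_pos * (1.0 - labels))
+    return true_pos / (true_pos + false_pos)
+
+
+# Worked example matching the "random_inputs" case further below (threshold 0.6):
+# the predictions at or above 0.6 are 0.6, 0.8, 0.9 with weights 0.2, 0.5, 0.7 and
+# labels 0, 0, 1, so true_pos = 0.7, false_pos = 0.7, and precision = 0.7 / 1.4 = 0.5.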
+ +# pyre-strict + +import unittest +from typing import Dict, Iterable, Type, Union + +import torch +from torch import no_grad +from torchrec.metrics.metrics_config import DefaultTaskInfo +from torchrec.metrics.precision import compute_precision, PrecisionMetric +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_gpu_sync_test_launcher, + rec_metric_value_test_launcher, + RecTaskInfo, + sync_test_helper, + TestMetric, +) + + +WORLD_SIZE = 4 + + +class TestPrecisionMetric(TestMetric): + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + predictions = predictions.double() + true_pos_sum = torch.sum(weights * ((predictions >= 0.5) * labels)) + false_pos_sum = torch.sum(weights * ((predictions >= 0.5) * (1 - labels))) + return { + "true_pos_sum": true_pos_sum, + "false_pos_sum": false_pos_sum, + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + return compute_precision( + states["true_pos_sum"], + states["false_pos_sum"], + ) + + +class PrecisionMetricTest(unittest.TestCase): + target_clazz: Type[RecMetric] = PrecisionMetric + task_name: str = "precision" + + def test_precision_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=PrecisionMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestPrecisionMetric, + metric_name=PrecisionMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_precision_fused_tasks(self) -> None: + rec_metric_value_test_launcher( + target_clazz=PrecisionMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestPrecisionMetric, + metric_name=PrecisionMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_precision_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=PrecisionMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + test_clazz=TestPrecisionMetric, + metric_name=PrecisionMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + +class PrecisionMetricValueTest(unittest.TestCase): + r"""This set of tests verify the computation logic of precision in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. 
+ """ + + def setUp(self) -> None: + self.predictions = {"DefaultTask": None} + self.weights = {"DefaultTask": None} + self.labels = {"DefaultTask": None} + self.batches = { + "predictions": self.predictions, + "weights": self.weights, + "labels": self.labels, + } + self.precision = PrecisionMetric( + world_size=1, + my_rank=0, + batch_size=100, + tasks=[DefaultTaskInfo], + ) + + def test_calc_acc_perfect(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(10000)] * 2] + ) + self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000]) + self.weights["DefaultTask"] = torch.Tensor( + [[1] * 5000 + [0] * 10000 + [1] * 5000] + ) + + expected_precision = torch.tensor([1], dtype=torch.double) + self.precision.update(**self.batches) + actual_precision = self.precision.compute()[ + "precision-DefaultTask|window_precision" + ] + torch.allclose(expected_precision, actual_precision) + + def test_calc_acc_zero(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(10000)] * 2] + ) + self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000]) + self.weights["DefaultTask"] = torch.Tensor( + [[0] * 5000 + [1] * 10000 + [0] * 5000] + ) + + expected_precision = torch.tensor([0], dtype=torch.double) + self.precision.update(**self.batches) + actual_precision = self.precision.compute()[ + "precision-DefaultTask|window_precision" + ] + torch.allclose(expected_precision, actual_precision) + + def test_calc_precision_balanced(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(10000)] * 2] + ) + self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000]) + self.weights["DefaultTask"] = torch.ones([1, 20000]) + + expected_precision = torch.tensor([0.5], dtype=torch.double) + self.precision.update(**self.batches) + actual_precision = self.precision.compute()[ + "precision-DefaultTask|window_precision" + ] + torch.allclose(expected_precision, actual_precision) + + +def generate_model_outputs_cases() -> Iterable[Dict[str, Union[float, torch.Tensor]]]: + return [ + # random_inputs + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.3, 0.2, 0.5, 0.8, 0.7]]), + "threshold": 0.6, + "expected_precision": torch.tensor([0.5]), + }, + # perfect_condition + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[1, 0, 0, 1, 1]]), + "weights": torch.tensor([[1] * 5]), + "threshold": 0.6, + "expected_precision": torch.tensor([1.0]), + }, + # inverse_prediction + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0, 1, 1, 0, 0]]), + "weights": torch.tensor([[1] * 5]), + "threshold": 0.1, + "expected_precision": torch.tensor([0.0]), + }, + ] + + +class ThresholdValueTest(unittest.TestCase): + """This set of tests verify the computation logic of precision with a modified threshold + in several cases that we know the computation results. 
+ """ + + @no_grad() + def _test_precision_helper( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + expected_precision: torch.Tensor, + threshold: float, + ) -> None: + num_task = labels.shape[0] + batch_size = labels.shape[0] + task_list = [] + inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = { + "predictions": {}, + "labels": {}, + "weights": {}, + } + for i in range(num_task): + task_info = RecTaskInfo( + name=f"Task:{i}", + label_name="label", + prediction_name="prediction", + weight_name="weight", + ) + task_list.append(task_info) + # pyre-ignore + inputs["predictions"][task_info.name] = predictions[i] + # pyre-ignore + inputs["labels"][task_info.name] = labels[i] + # pyre-ignore + inputs["weights"][task_info.name] = weights[i] + + precision = PrecisionMetric( + world_size=1, + my_rank=0, + batch_size=batch_size, + tasks=task_list, + # pyre-ignore + threshold=threshold, # threshold is one of the kwargs + ) + precision.update(**inputs) + actual_precision = precision.compute() + + for task_id, task in enumerate(task_list): + cur_actual_precision = actual_precision[ + f"precision-{task.name}|window_precision" + ] + cur_expected_precision = expected_precision[task_id].unsqueeze(dim=0) + + torch.testing.assert_close( + cur_actual_precision, + cur_expected_precision, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + msg=f"Actual: {cur_actual_precision}, Expected: {cur_expected_precision}", + ) + + def test_precision(self) -> None: + test_data = generate_model_outputs_cases() + for inputs in test_data: + try: + self._test_precision_helper( + **inputs # pyre-ignore, surpressing a type hint error + ) + except AssertionError: + print("Assertion error caught with data set ", inputs) + raise + + +class PrecisionGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = PrecisionMetric + task_name: str = "precision" + + def test_sync_precision(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=PrecisionMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestPrecisionMetric, + metric_name=PrecisionGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/tests/test_precision_session.py b/torchrec/metrics/tests/test_precision_session.py new file mode 100644 index 000000000..41b8874d3 --- /dev/null +++ b/torchrec/metrics/tests/test_precision_session.py @@ -0,0 +1,329 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import unittest +from typing import Dict, Optional, Union + +import torch +from torch import no_grad + +from torchrec.metrics.metrics_config import ( + RecComputeMode, + RecTaskInfo, + SessionMetricDef, +) +from torchrec.metrics.precision_session import PrecisionSessionMetric +from torchrec.metrics.rec_metric import RecMetricException + + +def generate_model_output_test1() -> Dict[str, torch._tensor.Tensor]: + return { + "predictions": torch.tensor( + [[1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8]] + ), + "session": torch.tensor([[1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1]]), + "labels": torch.tensor( + [[0.9, 0.1, 0.2, 0.3, 0.9, 0.9, 0.0, 0.9, 0.1, 0.4, 0.9, 0.1]] + ), + "weights": torch.tensor( + [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]] + ), + "expected_precision": torch.tensor([0.5]), + } + + +def generate_model_output_test2() -> Dict[str, torch._tensor.Tensor]: + return { + "predictions": torch.tensor( + [[1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8]] + ), + "session": torch.tensor([[1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1]]), + "labels": torch.tensor( + [[1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0]] + ), + "weights": torch.tensor( + [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]] + ), + "expected_precision": torch.tensor([0.5]), + } + + +def generate_model_output_with_no_positive_examples() -> ( + Dict[str, torch._tensor.Tensor] +): + return { + "predictions": torch.tensor( + [[1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8]] + ), + "session": torch.tensor([[1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1]]), + "labels": torch.tensor([[0.0] * 12]), + "weights": torch.tensor( + [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]] + ), + "expected_precision": torch.tensor([0.0]), + } + + +def generate_model_output_with_no_positive_predictions() -> ( + Dict[str, torch._tensor.Tensor] +): + return { + "predictions": torch.tensor([[float("nan")] * 12]), + "session": torch.tensor([[1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1]]), + "labels": torch.tensor( + [[1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0]] + ), + "weights": torch.tensor( + [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]] + ), + "expected_precision": torch.tensor([float("nan")]), + } + + +class PrecisionSessionValueTest(unittest.TestCase): + r"""This set of tests verify the computation logic of Precision in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. 
+ """ + + @no_grad() + def _test_precision_session_helper( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + session: torch.Tensor, + expected_precision: torch.Tensor, + run_ranking_of_labels: bool = False, + precision_metric: Optional[PrecisionSessionMetric] = None, + ) -> PrecisionSessionMetric: + num_task = labels.shape[0] + batch_size = labels.shape[0] + task_list = [] + inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = { + "predictions": {}, + "labels": {}, + "weights": {}, + } + for i in range(num_task): + task_info = RecTaskInfo( + name=f"Task:{i}", + label_name="label", + prediction_name="prediction", + weight_name="weight", + session_metric_def=SessionMetricDef( + session_var_name="session", + top_threshold=3, + run_ranking_of_labels=run_ranking_of_labels, + ), + ) + task_list.append(task_info) + # pyre-ignore + inputs["predictions"][task_info.name] = predictions[i] + # pyre-ignore + inputs["labels"][task_info.name] = labels[i] + # pyre-ignore + inputs["weights"][task_info.name] = weights[i] + + kwargs = {"required_inputs": {"session": session}} + + if precision_metric is None: + precision_metric = PrecisionSessionMetric( + world_size=1, + my_rank=0, + batch_size=batch_size, + tasks=task_list, + ) + precision_metric.update( + predictions=inputs["predictions"], + labels=inputs["labels"], + weights=inputs["weights"], + **kwargs, + ) + actual_precision = precision_metric.compute() + + for task_id, task in enumerate(task_list): + cur_actual_precision = actual_precision[ + f"precision_session_level-{task.name}|lifetime_precision_session_level" + ] + cur_expected_precision = expected_precision[task_id].unsqueeze(dim=0) + + torch.testing.assert_close( + cur_actual_precision, + cur_expected_precision, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {cur_actual_precision}, Expected: {cur_expected_precision}", + ) + return precision_metric + + def test_precision_session_with_ranked_labels(self) -> None: + test_data = generate_model_output_test1() + try: + self._test_precision_session_helper( + run_ranking_of_labels=True, precision_metric=None, **test_data + ) + except AssertionError: + print("Assertion error caught with data set ", test_data) + raise + + def test_precision_session_with_bool_labels(self) -> None: + test_data = generate_model_output_test2() + try: + self._test_precision_session_helper( + run_ranking_of_labels=False, precision_metric=None, **test_data + ) + except AssertionError: + print("Assertion error caught with data set ", test_data) + raise + + def test_precision_session_with_no_positive_examples(self) -> None: + test_data = generate_model_output_with_no_positive_examples() + try: + self._test_precision_session_helper( + run_ranking_of_labels=False, precision_metric=None, **test_data + ) + except AssertionError: + print("Assertion error caught with data set ", test_data) + raise + + def test_precision_session_with_no_positive_predictions(self) -> None: + test_data = generate_model_output_with_no_positive_predictions() + try: + self._test_precision_session_helper( + run_ranking_of_labels=False, precision_metric=None, **test_data + ) + except AssertionError: + print("Assertion error caught with data set ", test_data) + raise + + def test_error_messages(self) -> None: + task_info1 = RecTaskInfo( + name="Task1", + label_name="label1", + prediction_name="prediction1", + weight_name="weight1", + ) + + task_info2 = RecTaskInfo( + name="Task2", + label_name="label2", + 
prediction_name="prediction2", + weight_name="weight2", + session_metric_def=SessionMetricDef(session_var_name="session"), + ) + + error_message1 = "Please, specify the session metric definition" + with self.assertRaisesRegex(RecMetricException, error_message1): + _ = PrecisionSessionMetric( + world_size=1, + my_rank=5, + batch_size=100, + tasks=[task_info1], + ) + error_message2 = "Please, specify the top threshold" + with self.assertRaisesRegex(RecMetricException, error_message2): + _ = PrecisionSessionMetric( + world_size=1, + my_rank=5, + batch_size=100, + tasks=[task_info2], + ) + + def test_compute_mode_exception(self) -> None: + task_info = RecTaskInfo( + name="Task1", + label_name="label1", + prediction_name="prediction1", + weight_name="weight1", + ) + with self.assertRaisesRegex( + RecMetricException, + "Fused computation is not supported for precision session-level metrics", + ): + PrecisionSessionMetric( + world_size=1, + my_rank=0, + batch_size=100, + tasks=[task_info], + compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + ) + + with self.assertRaisesRegex( + RecMetricException, + "Fused computation is not supported for precision session-level metrics", + ): + PrecisionSessionMetric( + world_size=1, + my_rank=5, + batch_size=100, + tasks=[task_info], + compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + ) + + def test_tasks_input_propagation(self) -> None: + task_info1 = RecTaskInfo( + name="Task1", + label_name="label1", + prediction_name="prediction1", + weight_name="weight1", + session_metric_def=SessionMetricDef( + session_var_name="session1", + top_threshold=1, + run_ranking_of_labels=True, + ), + ) + + task_info2 = RecTaskInfo( + name="Task2", + label_name="label2", + prediction_name="prediction2", + weight_name="weight2", + session_metric_def=SessionMetricDef( + session_var_name="session2", + top_threshold=2, + run_ranking_of_labels=False, + ), + ) + + precision_metric = PrecisionSessionMetric( + world_size=1, + my_rank=5, + batch_size=100, + tasks=[task_info1, task_info2], + ) + + # metrics checks + self.assertSetEqual( + precision_metric.get_required_inputs(), {"session1", "session2"} + ) + self.assertTrue(len(precision_metric._tasks) == 2) + self.assertTrue(precision_metric._tasks[0] == task_info1) + self.assertTrue(precision_metric._tasks[1] == task_info2) + + # metrics_computations checks + self.assertTrue(precision_metric._metrics_computations[0]._my_rank == 5) + self.assertTrue(precision_metric._metrics_computations[1]._my_rank == 5) + self.assertTrue(precision_metric._metrics_computations[0]._batch_size == 100) + self.assertTrue(precision_metric._metrics_computations[1]._batch_size == 100) + + self.assertTrue(precision_metric._metrics_computations[0].top_threshold == 1) + self.assertTrue(precision_metric._metrics_computations[1].top_threshold == 2) + self.assertTrue( + precision_metric._metrics_computations[0].session_var_name == "session1" + ) + self.assertTrue( + precision_metric._metrics_computations[1].session_var_name == "session2" + ) + self.assertTrue(precision_metric._metrics_computations[0].run_ranking_of_labels) + self.assertTrue( + precision_metric._metrics_computations[1].run_ranking_of_labels is False + ) diff --git a/torchrec/metrics/tests/test_rauc.py b/torchrec/metrics/tests/test_rauc.py new file mode 100644 index 000000000..be9a7dd4b --- /dev/null +++ b/torchrec/metrics/tests/test_rauc.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import Dict, Iterable, List, Optional, Type, Union + +import torch +from torch import no_grad +from torchrec.metrics.metrics_config import DefaultTaskInfo +from torchrec.metrics.rauc import RAUCMetric +from torchrec.metrics.rec_metric import ( + RecComputeMode, + RecMetric, + RecMetricException, + RecTaskInfo, +) +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_gpu_sync_test_launcher, + rec_metric_value_test_launcher, + sync_test_helper, + TestMetric, +) + + +def compute_rauc( + predictions: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor +) -> torch.Tensor: + + n = len(predictions) + cnt = 0.0 + for i in range(n - 1): + for j in range(i + 1, n): + if (labels[i] - labels[j]) * (predictions[i] - predictions[j]) >= 0: + cnt += 1 + + return torch.tensor(cnt / (n * (n - 1) / 2)) + + +class TestRAUCMetric(TestMetric): + def __init__( + self, + world_size: int, + rec_tasks: List[RecTaskInfo], + ) -> None: + super().__init__( + world_size, + rec_tasks, + compute_lifetime_metric=False, + local_compute_lifetime_metric=False, + ) + + @staticmethod + def _aggregate( + states: Dict[str, torch.Tensor], new_states: Dict[str, torch.Tensor] + ) -> None: + for k, v in new_states.items(): + if k not in states: + states[k] = v.float().detach().clone() + else: + states[k] = torch.cat([states[k], v.float()]) + + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + return { + "predictions": predictions, + "weights": weights, + "labels": labels, + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + return compute_rauc(states["predictions"], states["labels"], states["weights"]) + + +WORLD_SIZE = 4 + + +class RAUCMetricTest(unittest.TestCase): + clazz: Type[RecMetric] = RAUCMetric + task_name: str = "rauc" + + def test_unfused_rauc(self) -> None: + rec_metric_value_test_launcher( + target_clazz=RAUCMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestRAUCMetric, + metric_name=RAUCMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_fused_rauc(self) -> None: + rec_metric_value_test_launcher( + target_clazz=RAUCMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestRAUCMetric, + metric_name=RAUCMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + +class RAUCGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = RAUCMetric + task_name: str = "rauc" + + def test_sync_rauc(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=RAUCMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestRAUCMetric, + metric_name=RAUCGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) + + +class RAUCMetricValueTest(unittest.TestCase): + r"""This set of tests verify the computation logic of RAUC in several 
+ corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. + """ + + def setUp(self) -> None: + self.predictions = {"DefaultTask": None} + self.weights = {"DefaultTask": None} + self.labels = {"DefaultTask": None} + self.batches = { + "predictions": self.predictions, + "weights": self.weights, + "labels": self.labels, + } + self.rauc = RAUCMetric( + world_size=1, + my_rank=0, + batch_size=100, + tasks=[DefaultTaskInfo], + ) + + def test_calc_rauc_perfect(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(100)] * 2] + ) + self.labels["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(100)] * 2] + ) + self.weights["DefaultTask"] = torch.Tensor([[1] * 50 + [0] * 100 + [1] * 50]) + + expected_rauc = torch.tensor([1], dtype=torch.float) + self.rauc.update(**self.batches) + actual_rauc = self.rauc.compute()["rauc-DefaultTask|window_rauc"] + assert torch.allclose(expected_rauc, actual_rauc) + + def test_calc_rauc_zero(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(100)] * 2] + ) + self.labels["DefaultTask"] = torch.Tensor( + [[-0.0001 * x for x in range(100)] * 2] + ) + self.weights["DefaultTask"] = torch.Tensor([[0] * 50 + [1] * 100 + [0] * 50]) + + expected_rauc = torch.tensor([0], dtype=torch.float) + self.rauc.update(**self.batches) + actual_rauc = self.rauc.compute()["rauc-DefaultTask|window_rauc"] + assert torch.allclose(expected_rauc, actual_rauc) + + def test_calc_rauc_random(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor([[1, 2, 3, 4]]) + self.labels["DefaultTask"] = torch.Tensor([[2, 1, 4, 3]]) + self.weights["DefaultTask"] = torch.Tensor([[1, 1, 1, 1]]) + + expected_rauc = torch.tensor([2.0 / 3], dtype=torch.float) + self.rauc.update(**self.batches) + actual_rauc = self.rauc.compute()["rauc-DefaultTask|window_rauc"] + assert torch.allclose(expected_rauc, actual_rauc) + + def test_window_size_rauc(self) -> None: + # for determinisitc batches + torch.manual_seed(0) + + rauc = RAUCMetric( + world_size=1, + my_rank=0, + batch_size=5, + window_size=100, + tasks=[DefaultTaskInfo], + ) + + # init states, so we expect 3 (state tensors) * 4 bytes (float) + self.assertEqual(sum(rauc.get_memory_usage().values()), 12) + + # bs = 5 + self.labels["DefaultTask"] = torch.rand(5) + self.predictions["DefaultTask"] = torch.rand(5) + self.weights["DefaultTask"] = torch.rand(5) + + for _ in range(1000): + rauc.update(**self.batches) + + # check memory, window size is 100, so we have upperbound of memory to expect + # so with a 100 window size / tensors of size 5 = 20 tensors (per state) * 3 states * 20 bytes per tensor of size 5 = 1200 bytes + self.assertEqual(sum(rauc.get_memory_usage().values()), 1200) + # with bs 5, we expect 20 tensors per state, so 60 tensors + self.assertEqual(len(rauc.get_memory_usage().values()), 60) + + assert torch.allclose( + rauc.compute()["rauc-DefaultTask|window_rauc"], + torch.tensor([0.5152], dtype=torch.float), + atol=1e-4, + ) + + # test rauc memory usage with window size equal to incoming batch + rauc = RAUCMetric( + world_size=1, + my_rank=0, + batch_size=100, + window_size=100, + tasks=[DefaultTaskInfo], + ) + + self.labels["DefaultTask"] = torch.rand(100) + self.predictions["DefaultTask"] = torch.rand(100) + self.weights["DefaultTask"] = torch.rand(100) + + for _ in range(10): + rauc.update(**self.batches) + + # passing in batch size == window size, we expect for each state 
just one tensor of size 400, sum to 1200 as previous + self.assertEqual(sum(rauc.get_memory_usage().values()), 1200) + self.assertEqual(len(rauc.get_memory_usage().values()), 3) + + assert torch.allclose( + rauc.compute()["rauc-DefaultTask|window_rauc"], + torch.tensor([0.5508], dtype=torch.float), + atol=1e-4, + ) + + +def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]: + return [ + # random_inputs + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([0, 1, 0, 1, 1]), + "expected_rauc": torch.tensor([0.3333]), + }, + # perfect_condition + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[1, 0, 0, 1, 1]]), + "weights": torch.tensor([[1] * 5]), + "grouping_keys": torch.tensor([1, 1, 0, 0, 1]), + "expected_rauc": torch.tensor([1.0]), + }, + # inverse_prediction + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0, 1, 1, 0, 0]]), + "weights": torch.tensor([[1] * 5]), + "grouping_keys": torch.tensor([0, 1, 0, 1, 1]), + "expected_rauc": torch.tensor([0.1667]), + }, + # all_scores_the_same + { + "labels": torch.tensor([[1, 0, 1, 0, 1, 0]]), + "predictions": torch.tensor([[0.5] * 6]), + "weights": torch.tensor([[1] * 6]), + "grouping_keys": torch.tensor([1, 1, 1, 0, 0, 0]), + "expected_rauc": torch.tensor([1.0]), + }, + # one_class_in_input + { + "labels": torch.tensor([[1, 1, 1, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[1] * 5]), + "grouping_keys": torch.tensor([1, 0, 0, 1, 0]), + "expected_rauc": torch.tensor([1.0]), + }, + # one_group + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([1, 1, 1, 1, 1]), + "expected_rauc": torch.tensor([0.6]), + }, + # two tasks + { + "labels": torch.tensor([[1, 0, 0, 1, 0], [1, 1, 1, 1, 0]]), + "predictions": torch.tensor( + [ + [0.2281, 0.1051, 0.4885, 0.7740, 0.3097], + [0.4658, 0.3445, 0.6048, 0.6587, 0.5088], + ] + ), + "weights": torch.tensor( + [ + [0.6334, 0.6937, 0.6631, 0.5078, 0.3570], + [0.2637, 0.2479, 0.2697, 0.6500, 0.7583], + ] + ), + "grouping_keys": torch.tensor([0, 1, 0, 0, 1]), + "expected_rauc": torch.tensor([0.8333, 0.5]), + }, + ] + + +class GroupedRAUCValueTest(unittest.TestCase): + r"""This set of tests verify the computation logic of RAUC in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. 
+ """ + + @no_grad() + def _test_grouped_rauc_helper( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + expected_rauc: torch.Tensor, + grouping_keys: Optional[torch.Tensor] = None, + ) -> None: + num_task = labels.shape[0] + batch_size = labels.shape[0] + task_list = [] + inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = { + "predictions": {}, + "labels": {}, + "weights": {}, + } + if grouping_keys is not None: + inputs["required_inputs"] = {"grouping_keys": grouping_keys} + for i in range(num_task): + task_info = RecTaskInfo( + name=f"Task:{i}", + label_name="label", + prediction_name="prediction", + weight_name="weight", + ) + task_list.append(task_info) + # pyre-ignore + inputs["predictions"][task_info.name] = predictions[i] + # pyre-ignore + inputs["labels"][task_info.name] = labels[i] + # pyre-ignore + inputs["weights"][task_info.name] = weights[i] + + rauc = RAUCMetric( + world_size=1, + my_rank=0, + batch_size=batch_size, + tasks=task_list, + # pyre-ignore + grouped_rauc=True, + ) + rauc.update(**inputs) + actual_rauc = rauc.compute() + + for task_id, task in enumerate(task_list): + cur_actual_rauc = actual_rauc[f"rauc-{task.name}|window_grouped_rauc"] + cur_expected_rauc = expected_rauc[task_id].unsqueeze(dim=0) + + torch.testing.assert_close( + cur_actual_rauc, + cur_expected_rauc, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + msg=f"Actual: {cur_actual_rauc}, Expected: {cur_expected_rauc}", + ) + + def test_grouped_rauc(self) -> None: + test_data = generate_model_outputs_cases() + for inputs in test_data: + try: + self._test_grouped_rauc_helper(**inputs) + except AssertionError: + print("Assertion error caught with data set ", inputs) + raise + + def test_misconfigured_grouped_rauc(self) -> None: + with self.assertRaises(RecMetricException): + self._test_grouped_rauc_helper( + **{ + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + # no provided grouping_keys + "expected_rauc": torch.tensor([0.2419]), + }, + ) + + def test_required_input_for_grouped_Rauc(self) -> None: + rauc = RAUCMetric( + world_size=1, + my_rank=0, + batch_size=1, + tasks=[ + RecTaskInfo( + name="Task:0", + label_name="label", + prediction_name="prediction", + weight_name="weight", + ) + ], + # pyre-ignore + grouped_rauc=True, + ) + + self.assertIn("grouping_keys", rauc.get_required_inputs()) diff --git a/torchrec/metrics/tests/test_recall.py b/torchrec/metrics/tests/test_recall.py new file mode 100644 index 000000000..d09faf464 --- /dev/null +++ b/torchrec/metrics/tests/test_recall.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import unittest +from typing import Dict, Iterable, Type, Union + +import torch +from torch import no_grad +from torchrec.metrics.metrics_config import DefaultTaskInfo +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric +from torchrec.metrics.recall import compute_recall, RecallMetric +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_gpu_sync_test_launcher, + rec_metric_value_test_launcher, + RecTaskInfo, + sync_test_helper, + TestMetric, +) + + +WORLD_SIZE = 4 + + +class TestRecallMetric(TestMetric): + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + predictions = predictions.double() + true_pos_sum = torch.sum(weights * ((predictions >= 0.5) * labels)) + false_neg_sum = torch.sum(weights * ((predictions <= 0.5) * (labels))) + return { + "true_pos_sum": true_pos_sum, + "false_neg_sum": false_neg_sum, + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + return compute_recall( + states["true_pos_sum"], + states["false_neg_sum"], + ) + + +class RecallMetricTest(unittest.TestCase): + clazz: Type[RecMetric] = RecallMetric + task_name: str = "recall" + + def test_recall_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=RecallMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestRecallMetric, + metric_name=RecallMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_recall_fused_tasks(self) -> None: + rec_metric_value_test_launcher( + target_clazz=RecallMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestRecallMetric, + metric_name=RecallMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_recall_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=RecallMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + test_clazz=TestRecallMetric, + metric_name=RecallMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + +class RecallMetricValueTest(unittest.TestCase): + r"""This set of tests verify the computation logic of recall in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. 
+ """ + + def setUp(self) -> None: + self.predictions = {"DefaultTask": None} + self.weights = {"DefaultTask": None} + self.labels = {"DefaultTask": None} + self.batches = { + "predictions": self.predictions, + "weights": self.weights, + "labels": self.labels, + } + self.recall = RecallMetric( + world_size=1, + my_rank=0, + batch_size=100, + tasks=[DefaultTaskInfo], + ) + + def test_calc_acc_perfect(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(10000)] * 2] + ) + self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000]) + self.weights["DefaultTask"] = torch.Tensor( + [[1] * 5000 + [0] * 10000 + [1] * 5000] + ) + + expected_recall = torch.tensor([1], dtype=torch.double) + self.recall.update(**self.batches) + actual_recall = self.recall.compute()["recall-DefaultTask|window_recall"] + torch.allclose(expected_recall, actual_recall) + + def test_calc_acc_zero(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(10000)] * 2] + ) + self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000]) + self.weights["DefaultTask"] = torch.Tensor( + [[0] * 5000 + [1] * 10000 + [0] * 5000] + ) + + expected_recall = torch.tensor([0], dtype=torch.double) + self.recall.update(**self.batches) + actual_recall = self.recall.compute()["recall-DefaultTask|window_recall"] + torch.allclose(expected_recall, actual_recall) + + def test_calc_recall_balanced(self) -> None: + self.predictions["DefaultTask"] = torch.Tensor( + [[0.0001 * x for x in range(10000)] * 2] + ) + self.labels["DefaultTask"] = torch.Tensor([[0] * 10000 + [1] * 10000]) + self.weights["DefaultTask"] = torch.ones([1, 20000]) + + expected_recall = torch.tensor([0.5], dtype=torch.double) + self.recall.update(**self.batches) + actual_recall = self.recall.compute()["recall-DefaultTask|window_recall"] + torch.allclose(expected_recall, actual_recall) + + +def generate_model_outputs_cases() -> Iterable[Dict[str, Union[float, torch.Tensor]]]: + return [ + # random_inputs + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.3, 0.2, 0.5, 0.8, 0.7]]), + "threshold": 0.6, + "expected_recall": torch.tensor([0.7 / 1.8]), + }, + # perfect_condition + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[1, 0, 0, 1, 1]]), + "weights": torch.tensor([[1] * 5]), + "threshold": 0.6, + "expected_recall": torch.tensor([1.0]), + }, + # inverse_prediction + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0, 1, 1, 0, 0]]), + "weights": torch.tensor([[1] * 5]), + "threshold": 0.1, + "expected_recall": torch.tensor([0.0]), + }, + ] + + +class ThresholdValueTest(unittest.TestCase): + """This set of tests verify the computation logic of recall with a modified threshold + in several cases that we know the computation results. 
+ """ + + @no_grad() + def _test_recall_helper( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + expected_recall: torch.Tensor, + threshold: float, + ) -> None: + num_task = labels.shape[0] + batch_size = labels.shape[0] + task_list = [] + inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = { + "predictions": {}, + "labels": {}, + "weights": {}, + } + for i in range(num_task): + task_info = RecTaskInfo( + name=f"Task:{i}", + label_name="label", + prediction_name="prediction", + weight_name="weight", + ) + task_list.append(task_info) + # pyre-ignore + inputs["predictions"][task_info.name] = predictions[i] + # pyre-ignore + inputs["labels"][task_info.name] = labels[i] + # pyre-ignore + inputs["weights"][task_info.name] = weights[i] + + recall = RecallMetric( + world_size=1, + my_rank=0, + batch_size=batch_size, + tasks=task_list, + # pyre-ignore + threshold=threshold, # threshold is one of the kwargs + ) + recall.update(**inputs) + actual_recall = recall.compute() + + for task_id, task in enumerate(task_list): + cur_actual_recall = actual_recall[f"recall-{task.name}|window_recall"] + cur_expected_recall = expected_recall[task_id].unsqueeze(dim=0) + + torch.testing.assert_close( + cur_actual_recall, + cur_expected_recall, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + msg=f"Actual: {cur_actual_recall}, Expected: {cur_expected_recall}", + ) + + def test_recall(self) -> None: + test_data = generate_model_outputs_cases() + for inputs in test_data: + try: + self._test_recall_helper( + **inputs # pyre-ignore, surpressing a type hint error + ) + except AssertionError: + print("Assertion error caught with data set ", inputs) + raise + + +class RecallGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = RecallMetric + task_name: str = "recall" + + def test_sync_recall(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=RecallMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestRecallMetric, + metric_name=RecallGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/tests/test_recall_session.py b/torchrec/metrics/tests/test_recall_session.py new file mode 100644 index 000000000..a21a180de --- /dev/null +++ b/torchrec/metrics/tests/test_recall_session.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import unittest +from typing import Dict, Optional, Union + +import torch +from torch import no_grad +from torchrec.metrics.metrics_config import ( + RecComputeMode, + RecTaskInfo, + SessionMetricDef, +) +from torchrec.metrics.rec_metric import RecMetricException + +from torchrec.metrics.recall_session import RecallSessionMetric + + +def generate_model_output_test1() -> Dict[str, torch._tensor.Tensor]: + return { + "predictions": torch.tensor( + [[1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8]] + ), + "session": torch.tensor([[1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1]]), + "labels": torch.tensor( + [[0.9, 0.1, 0.2, 0.3, 0.9, 0.9, 0.0, 0.9, 0.1, 0.4, 0.9, 0.1]] + ), + "weights": torch.tensor( + [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]] + ), + "expected_recall": torch.tensor([0.5]), + } + + +def generate_model_output_test2() -> Dict[str, torch._tensor.Tensor]: + return { + "predictions": torch.tensor( + [[1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8]] + ), + "session": torch.tensor([[1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1]]), + "labels": torch.tensor( + [[1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0]] + ), + "weights": torch.tensor( + [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]] + ), + "expected_recall": torch.tensor([0.5]), + } + + +def generate_model_output_with_no_positive_examples() -> ( + Dict[str, torch._tensor.Tensor] +): + return { + "predictions": torch.tensor( + [[1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8, 1.0, 0.0, 0.51, 0.8]] + ), + "session": torch.tensor([[1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1]]), + "labels": torch.tensor([[0.0] * 12]), + "weights": torch.tensor( + [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]] + ), + "expected_recall": torch.tensor([float("nan")]), + } + + +class RecallSessionValueTest(unittest.TestCase): + r"""This set of tests verify the computation logic of Recall in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. 
+ """ + + @no_grad() + def _test_recall_session_helper( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + session: torch.Tensor, + expected_recall: torch.Tensor, + run_ranking_of_labels: bool = False, + recall_metric: Optional[RecallSessionMetric] = None, + ) -> RecallSessionMetric: + num_task = labels.shape[0] + batch_size = labels.shape[0] + task_list = [] + inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = { + "predictions": {}, + "labels": {}, + "weights": {}, + } + for i in range(num_task): + task_info = RecTaskInfo( + name=f"Task:{i}", + label_name="label", + prediction_name="prediction", + weight_name="weight", + session_metric_def=SessionMetricDef( + session_var_name="session", + top_threshold=3, + run_ranking_of_labels=run_ranking_of_labels, + ), + ) + task_list.append(task_info) + # pyre-ignore + inputs["predictions"][task_info.name] = predictions[i] + # pyre-ignore + inputs["labels"][task_info.name] = labels[i] + # pyre-ignore + inputs["weights"][task_info.name] = weights[i] + + kwargs = {"required_inputs": {"session": session}} + + if recall_metric is None: + recall_metric = RecallSessionMetric( + world_size=1, + my_rank=0, + batch_size=batch_size, + tasks=task_list, + ) + recall_metric.update( + predictions=inputs["predictions"], + labels=inputs["labels"], + weights=inputs["weights"], + **kwargs, + ) + actual_recall = recall_metric.compute() + + for task_id, task in enumerate(task_list): + cur_actual_recall = actual_recall[ + f"recall_session_level-{task.name}|lifetime_recall_session_level" + ] + cur_expected_recall = expected_recall[task_id].unsqueeze(dim=0) + + torch.testing.assert_close( + cur_actual_recall, + cur_expected_recall, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {cur_actual_recall}, Expected: {cur_expected_recall}", + ) + return recall_metric + + def test_recall_session_with_ranked_labels(self) -> None: + test_data = generate_model_output_test1() + try: + self._test_recall_session_helper( + run_ranking_of_labels=True, recall_metric=None, **test_data + ) + except AssertionError: + print("Assertion error caught with data set ", test_data) + raise + + def test_recall_session_with_bool_labels(self) -> None: + test_data = generate_model_output_test2() + try: + self._test_recall_session_helper( + run_ranking_of_labels=False, recall_metric=None, **test_data + ) + except AssertionError: + print("Assertion error caught with data set ", test_data) + raise + + def test_recall_session_with_no_positive_examples(self) -> None: + + # if we pass a batch with no positive examples, we should get NaN + test_data_with_no_positive_examples = ( + generate_model_output_with_no_positive_examples() + ) + try: + recall_metric = self._test_recall_session_helper( + run_ranking_of_labels=False, + recall_metric=None, + **test_data_with_no_positive_examples, + ) + except AssertionError: + print( + "Assertion error caught with data set ", + test_data_with_no_positive_examples, + ) + raise + # once we get a batch with positive examples, we should NOT get NaN + test_data_with_pos_examples = generate_model_output_test2() + try: + recall_metric = self._test_recall_session_helper( + run_ranking_of_labels=False, + recall_metric=recall_metric, + **test_data_with_pos_examples, + ) + except AssertionError: + print("Assertion error caught with data set ", test_data_with_pos_examples) + raise + # if we get a batch with no positive examples and metric states are non-zero, they should not change + 
test_data_with_no_positive_examples["expected_recall"] = torch.tensor([0.5]) + try: + self._test_recall_session_helper( + run_ranking_of_labels=False, + recall_metric=recall_metric, + **test_data_with_no_positive_examples, + ) + except AssertionError: + print( + "Assertion error caught with data set ", + test_data_with_no_positive_examples, + ) + raise + + def test_error_messages(self) -> None: + + task_info1 = RecTaskInfo( + name="Task1", + label_name="label1", + prediction_name="prediction1", + weight_name="weight1", + ) + + task_info2 = RecTaskInfo( + name="Task2", + label_name="label2", + prediction_name="prediction2", + weight_name="weight2", + session_metric_def=SessionMetricDef(session_var_name="session"), + ) + + error_message1 = "Please, specify the session metric definition" + with self.assertRaisesRegex(RecMetricException, error_message1): + _ = RecallSessionMetric( + world_size=1, + my_rank=5, + batch_size=100, + tasks=[task_info1], + ) + error_message2 = "Please, specify the top threshold" + with self.assertRaisesRegex(RecMetricException, error_message2): + _ = RecallSessionMetric( + world_size=1, + my_rank=5, + batch_size=100, + tasks=[task_info2], + ) + + def test_compute_mode_exception(self) -> None: + task_info = RecTaskInfo( + name="Task1", + label_name="label1", + prediction_name="prediction1", + weight_name="weight1", + ) + with self.assertRaisesRegex( + RecMetricException, + "Fused computation is not supported for recall session-level metrics", + ): + RecallSessionMetric( + world_size=1, + my_rank=0, + batch_size=100, + tasks=[task_info], + compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + ) + + with self.assertRaisesRegex( + RecMetricException, + "Fused computation is not supported for recall session-level metrics", + ): + RecallSessionMetric( + world_size=1, + my_rank=5, + batch_size=100, + tasks=[task_info], + compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + ) + + def test_tasks_input_propagation(self) -> None: + task_info1 = RecTaskInfo( + name="Task1", + label_name="label1", + prediction_name="prediction1", + weight_name="weight1", + session_metric_def=SessionMetricDef( + session_var_name="session1", + top_threshold=1, + run_ranking_of_labels=True, + ), + ) + + task_info2 = RecTaskInfo( + name="Task2", + label_name="label2", + prediction_name="prediction2", + weight_name="weight2", + session_metric_def=SessionMetricDef( + session_var_name="session2", + top_threshold=2, + run_ranking_of_labels=False, + ), + ) + + recall_metric = RecallSessionMetric( + world_size=1, + my_rank=5, + batch_size=100, + tasks=[task_info1, task_info2], + ) + + # metrics checks + self.assertSetEqual( + recall_metric.get_required_inputs(), {"session1", "session2"} + ) + self.assertTrue(len(recall_metric._tasks) == 2) + self.assertTrue(recall_metric._tasks[0] == task_info1) + self.assertTrue(recall_metric._tasks[1] == task_info2) + + # metrics_computations checks + self.assertTrue(recall_metric._metrics_computations[0]._my_rank == 5) + self.assertTrue(recall_metric._metrics_computations[1]._my_rank == 5) + self.assertTrue(recall_metric._metrics_computations[0]._batch_size == 100) + self.assertTrue(recall_metric._metrics_computations[1]._batch_size == 100) + + self.assertTrue(recall_metric._metrics_computations[0].top_threshold == 1) + self.assertTrue(recall_metric._metrics_computations[1].top_threshold == 2) + self.assertTrue( + recall_metric._metrics_computations[0].session_var_name == "session1" + ) + self.assertTrue( + 
recall_metric._metrics_computations[1].session_var_name == "session2" + ) + self.assertTrue(recall_metric._metrics_computations[0].run_ranking_of_labels) + self.assertTrue( + recall_metric._metrics_computations[1].run_ranking_of_labels is False + ) diff --git a/torchrec/metrics/tests/test_recmetric.py b/torchrec/metrics/tests/test_recmetric.py index 2a3ddd4d3..9d423d0af 100644 --- a/torchrec/metrics/tests/test_recmetric.py +++ b/torchrec/metrics/tests/test_recmetric.py @@ -5,22 +5,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest import torch -from torchrec.metrics.metrics_config import DefaultTaskInfo +from torchrec.metrics.metrics_config import DefaultTaskInfo, RecTaskInfo from torchrec.metrics.model_utils import parse_task_model_outputs from torchrec.metrics.mse import MSEMetric from torchrec.metrics.ne import NEMetric -from torchrec.metrics.rec_metric import RecComputeMode +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric from torchrec.metrics.test_utils import gen_test_batch, gen_test_tasks +_CUDA_UNAVAILABLE: bool = not torch.cuda.is_available() + + class RecMetricTest(unittest.TestCase): def setUp(self) -> None: # Create testing labels, predictions and weights model_output = gen_test_batch(128) - self.labels, self.predictions, self.weights = parse_task_model_outputs( + self.labels, self.predictions, self.weights, _ = parse_task_model_outputs( [DefaultTaskInfo], model_output ) @@ -28,7 +33,7 @@ def test_optional_weights(self) -> None: ne1 = NEMetric( world_size=1, my_rank=0, - batch_size=128, + batch_size=64, tasks=[DefaultTaskInfo], compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size=100, @@ -37,7 +42,7 @@ def test_optional_weights(self) -> None: ne2 = NEMetric( world_size=1, my_rank=1, - batch_size=128, + batch_size=64, tasks=[DefaultTaskInfo], compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size=100, @@ -69,7 +74,7 @@ def test_zero_weights(self) -> None: mse = MSEMetric( world_size=1, my_rank=0, - batch_size=128, + batch_size=64, tasks=[DefaultTaskInfo], compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size=100, @@ -106,11 +111,7 @@ def test_zero_weights(self) -> None: self.assertGreater(mse_computation.weighted_num_samples, torch.tensor(0.0)) res = mse.compute() - # pyre-fixme[6]: For 2nd param expected `SupportsDunderLT[Variable[_T]]` but - # got `Tensor`. self.assertGreater(res["mse-DefaultTask|lifetime_mse"], torch.tensor(0.0)) - # pyre-fixme[6]: For 2nd param expected `SupportsDunderLT[Variable[_T]]` but - # got `Tensor`. 
self.assertGreater(res["mse-DefaultTask|lifetime_rmse"], torch.tensor(0.0)) # Test if weights = 0 for one task of an update @@ -121,12 +122,12 @@ def test_zero_weights(self) -> None: label_name=task.label_name, prediction_name=task.prediction_name, weight_name=task.weight_name, - batch_size=128, + batch_size=64, ) for task in tasks ] model_output = {k: v for d in _model_output for k, v in d.items()} - labels, predictions, weights = parse_task_model_outputs(tasks, model_output) + labels, predictions, weights, _ = parse_task_model_outputs(tasks, model_output) partial_zero_weights = { "t1": torch.zeros_like(weights["t1"]), "t2": weights["t2"], @@ -135,7 +136,7 @@ def test_zero_weights(self) -> None: ne = NEMetric( world_size=1, my_rank=0, - batch_size=128, + batch_size=64, tasks=tasks, compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size=100, @@ -160,8 +161,6 @@ def test_zero_weights(self) -> None: res = ne.compute() self.assertEqual(res["ne-t1|lifetime_ne"], torch.tensor(0.0)) - # pyre-fixme[6]: For 2nd param expected `SupportsDunderLT[Variable[_T]]` but - # got `Tensor`. self.assertGreater(res["ne-t2|lifetime_ne"], torch.tensor(0.0)) ne.update( @@ -177,8 +176,6 @@ def test_zero_weights(self) -> None: self.assertGreater(ne_computation[0].weighted_num_samples, torch.tensor(0.0)) res = ne.compute() - # pyre-fixme[6]: For 2nd param expected `SupportsDunderLT[Variable[_T]]` but - # got `Tensor`. self.assertGreater(res["ne-t1|lifetime_ne"], torch.tensor(0.0)) def test_compute(self) -> None: @@ -186,7 +183,7 @@ def test_compute(self) -> None: ne = NEMetric( world_size=1, my_rank=0, - batch_size=128, + batch_size=64, tasks=[DefaultTaskInfo], compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size=100, @@ -205,7 +202,7 @@ def test_compute(self) -> None: ne = NEMetric( world_size=1, my_rank=1, - batch_size=128, + batch_size=64, tasks=[DefaultTaskInfo], compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size=100, @@ -223,7 +220,7 @@ def test_compute(self) -> None: ne = NEMetric( world_size=1, my_rank=1, - batch_size=128, + batch_size=64, tasks=[DefaultTaskInfo], compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, window_size=100, @@ -238,3 +235,61 @@ def test_compute(self) -> None: res = ne.compute() self.assertIn("ne-DefaultTask|lifetime_ne", res) self.assertIn("ne-DefaultTask|window_ne", res) + + def test_invalid_window_size(self) -> None: + with self.assertRaises(ValueError): + RecMetric( + world_size=8, + my_rank=0, + window_size=50, + batch_size=10, + tasks=[DefaultTaskInfo], + ) + + def test_reset(self) -> None: + ne = NEMetric( + world_size=1, + my_rank=0, + batch_size=64, + tasks=[DefaultTaskInfo], + compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + window_size=1000, + fused_update_limit=0, + ) + ne.update( + predictions=self.predictions, + labels=self.labels, + weights=self.weights, + ) + ne = ne._metrics_computations[0] + window_buffer = ne._batch_window_buffers["window_cross_entropy_sum"].buffers + self.assertTrue(len(window_buffer) > 0) + ne.reset() + window_buffer = ne._batch_window_buffers["window_cross_entropy_sum"].buffers + self.assertEqual(len(window_buffer), 0) + + @unittest.skipIf(_CUDA_UNAVAILABLE, "Test needs to run on GPU") + def test_parse_task_model_outputs_ndcg(self) -> None: + _, _, _, required_inputs = parse_task_model_outputs( + tasks=[ + RecTaskInfo( + name="ndcg_example", + ), + ], + # pyre-fixme[6]: for argument model_out, expected Dict[str, Tensor] but + # got Dict[str, Union[List[str], Tensor]] + model_out={ + "label": 
torch.tensor( + [0.0, 1.0, 0.0, 1.0], device=torch.device("cuda:0") + ), + "weight": torch.tensor( + [1.0, 1.0, 1.0, 1.0], device=torch.device("cuda:0") + ), + "prediction": torch.tensor( + [0.0, 1.0, 0.0, 1.0], device=torch.device("cuda:0") + ), + "session_id": ["1", "1", "2", "2"], + }, + required_inputs_list=["session_id"], + ) + self.assertEqual(required_inputs["session_id"].device, torch.device("cuda:0")) diff --git a/torchrec/metrics/tests/test_scalar.py b/torchrec/metrics/tests/test_scalar.py new file mode 100644 index 000000000..4aa3bf568 --- /dev/null +++ b/torchrec/metrics/tests/test_scalar.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +import torch +from torchrec.metrics.metrics_config import DefaultTaskInfo +from torchrec.metrics.scalar import ScalarMetric + + +WORLD_SIZE = 4 +BATCH_SIZE = 10 + + +class ScalarMetricTest(unittest.TestCase): + def setUp(self) -> None: + self.scalar = ScalarMetric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + ) + + def test_scalar(self) -> None: + """ + Test scalar metric passes through each tensor as is + """ + metric_to_log = torch.tensor([0.1]) + + self.scalar.update( + labels={DefaultTaskInfo.name: metric_to_log}, + predictions={DefaultTaskInfo.name: metric_to_log}, + weights={DefaultTaskInfo.name: metric_to_log}, + ) + metric = self.scalar.compute() + actual_metric = metric[f"scalar-{DefaultTaskInfo.name}|lifetime_scalar"] + + torch.testing.assert_close( + actual_metric, + metric_to_log, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {metric_to_log}", + ) + + # Pass through second tensor with different value + # check we get the value back with no averaging or any differences + + metric_to_log = torch.tensor([0.9]) + + self.scalar.update( + labels={DefaultTaskInfo.name: metric_to_log}, + predictions={DefaultTaskInfo.name: metric_to_log}, + weights={DefaultTaskInfo.name: metric_to_log}, + ) + metric = self.scalar.compute() + actual_metric = metric[f"scalar-{DefaultTaskInfo.name}|lifetime_scalar"] + + torch.testing.assert_close( + actual_metric, + metric_to_log, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {metric_to_log}", + ) + + def test_scalar_window(self) -> None: + """ + Test windowing of scalar metric gives average of previously reported values. 
+ """ + metric_to_log = torch.tensor([0.1]) + + self.scalar.update( + labels={DefaultTaskInfo.name: metric_to_log}, + predictions={DefaultTaskInfo.name: metric_to_log}, + weights={DefaultTaskInfo.name: metric_to_log}, + ) + + metric_to_log = torch.tensor([0.9]) + + self.scalar.update( + labels={DefaultTaskInfo.name: metric_to_log}, + predictions={DefaultTaskInfo.name: metric_to_log}, + weights={DefaultTaskInfo.name: metric_to_log}, + ) + + metric = self.scalar.compute() + + actual_window_metric = metric[f"scalar-{DefaultTaskInfo.name}|window_scalar"] + + expected_window_metric = torch.tensor([0.5]) + + torch.testing.assert_close( + actual_window_metric, + expected_window_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_window_metric}, Expected: {expected_window_metric}", + ) diff --git a/torchrec/metrics/tests/test_segmented_ne.py b/torchrec/metrics/tests/test_segmented_ne.py new file mode 100644 index 000000000..507a7cc8f --- /dev/null +++ b/torchrec/metrics/tests/test_segmented_ne.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import Any, Dict, Iterable, Union + +import torch +from torch import no_grad +from torchrec.metrics.rec_metric import RecTaskInfo +from torchrec.metrics.segmented_ne import SegmentedNEMetric + + +class SegementedNEValueTest(unittest.TestCase): + """ + This set of tests verify the computation logic of AUC in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. 
+ """ + + @no_grad() + def _test_segemented_ne_helper( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + expected_ne: torch.Tensor, + grouping_keys: torch.Tensor, + grouping_key_tensor_name: str = "grouping_keys", + cast_keys_to_int: bool = False, + ) -> None: + num_task = labels.shape[0] + batch_size = labels.shape[0] + task_list = [] + inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = { + "predictions": {}, + "labels": {}, + "weights": {}, + } + if grouping_keys is not None: + inputs["required_inputs"] = {grouping_key_tensor_name: grouping_keys} + for i in range(num_task): + task_info = RecTaskInfo( + name=f"Task:{i}", + label_name="label", + prediction_name="prediction", + weight_name="weight", + ) + task_list.append(task_info) + # pyre-ignore + inputs["predictions"][task_info.name] = predictions[i] + # pyre-ignore + inputs["labels"][task_info.name] = labels[i] + # pyre-ignore + inputs["weights"][task_info.name] = weights[i] + + ne = SegmentedNEMetric( + world_size=1, + my_rank=0, + batch_size=batch_size, + tasks=task_list, + # pyre-ignore + num_groups=max(2, torch.unique(grouping_keys)[-1].item() + 1), + # pyre-ignore + grouping_keys=grouping_key_tensor_name, + # pyre-ignore + cast_keys_to_int=cast_keys_to_int, + ) + ne.update(**inputs) + actual_ne = ne.compute() + + for task_id, task in enumerate(task_list): + for label in [0, 1]: + cur_actual_ne = actual_ne[ + f"segmented_ne-{task.name}|lifetime_segmented_ne_{label}" + ] + cur_expected_ne = expected_ne[task_id][label] + + torch.testing.assert_close( + cur_actual_ne, + cur_expected_ne, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {cur_actual_ne}, Expected: {cur_expected_ne}", + ) + + def test_grouped_ne(self) -> None: + test_data = generate_model_outputs_cases() + for inputs in test_data: + try: + self._test_segemented_ne_helper(**inputs) + except AssertionError: + print("Assertion error caught with data set ", inputs) + raise + + +def generate_model_outputs_cases() -> Iterable[Dict[str, Any]]: + return [ + # base condition + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([0, 1, 0, 1, 1]), + "expected_ne": torch.tensor([[3.1615, 1.6004]]), + }, + # one sided, edge case 1s + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([1, 1, 1, 1, 1]), + "expected_ne": torch.tensor([[torch.nan, 1.3936]]), + }, + # one sided, edge case 0s + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([0, 0, 0, 0, 0]), + "expected_ne": torch.tensor([[1.3936, torch.nan]]), + }, + # three labels, + { + "labels": torch.tensor([[1, 0, 0, 1, 1, 0]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9, 0.4]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75, 0.4]]), + "grouping_keys": torch.tensor([0, 1, 0, 1, 2, 2]), + "expected_ne": torch.tensor([[3.1615, 1.8311, 0.3814]]), + }, + # two tasks + { + "labels": torch.tensor([[1, 0, 0, 1, 1], [1, 0, 0, 1, 1]]), + "predictions": torch.tensor( + [ + [0.2, 0.6, 0.8, 0.4, 0.9], + [0.6, 0.2, 0.4, 0.8, 0.9], + ] + ), + "weights": torch.tensor( + [ + [0.13, 0.2, 0.5, 0.8, 0.75], + [0.13, 
0.2, 0.5, 0.8, 0.75], + ] + ), + "grouping_keys": torch.tensor( + [0, 1, 0, 1, 1] + ), # for this case, both tasks have same groupings + "expected_ne": torch.tensor([[3.1615, 1.6004], [1.0034, 0.4859]]), + }, + # Custom grouping key tensor name + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([0, 1, 0, 1, 1]), + "expected_ne": torch.tensor([[3.1615, 1.6004]]), + "grouping_key_tensor_name": "custom_key", + }, + # Cast grouping keys to int32 + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0]), + "expected_ne": torch.tensor([[3.1615, 1.6004]]), + "grouping_key_tensor_name": "custom_key", + "cast_keys_to_int": True, + }, + ] diff --git a/torchrec/metrics/tests/test_serving_calibration.py b/torchrec/metrics/tests/test_serving_calibration.py new file mode 100644 index 000000000..810a69bfb --- /dev/null +++ b/torchrec/metrics/tests/test_serving_calibration.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +# pyre-strict + +import unittest +from typing import Dict, Type + +import torch +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric +from torchrec.metrics.serving_calibration import ServingCalibrationMetric +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_gpu_sync_test_launcher, + rec_metric_value_test_launcher, + sync_test_helper, + TestMetric, +) + +WORLD_SIZE = 4 + + +class TestServingCalibrationMetric(TestMetric): + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + calibration_num = torch.sum(predictions * weights) + calibration_denom = torch.sum(labels * weights) + num_samples = torch.count_nonzero(weights) + return { + "calibration_num": calibration_num, + "calibration_denom": calibration_denom, + "num_samples": num_samples, + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + return torch.where( + states["calibration_denom"] <= 0.0, + 0.0, + states["calibration_num"] / states["calibration_denom"], + ).double() + + +class ServingCalibrationMetricTest(unittest.TestCase): + clazz: Type[RecMetric] = ServingCalibrationMetric + task_name: str = "calibration" + + def test_calibration_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=ServingCalibrationMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestServingCalibrationMetric, + metric_name=ServingCalibrationMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_calibration_fused_tasks(self) -> None: + rec_metric_value_test_launcher( + target_clazz=ServingCalibrationMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestServingCalibrationMetric, + metric_name=ServingCalibrationMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + 
should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_calibration_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=ServingCalibrationMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + test_clazz=TestServingCalibrationMetric, + metric_name=ServingCalibrationMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + +# TODO - Serving Calibration uses Calibration naming inconsistently +class ServingCalibrationGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = ServingCalibrationMetric + task_name: str = "serving_calibration" + + def test_sync_serving_calibration(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=ServingCalibrationMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestServingCalibrationMetric, + metric_name=ServingCalibrationGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/tests/test_serving_ne.py b/torchrec/metrics/tests/test_serving_ne.py new file mode 100644 index 000000000..888bc1278 --- /dev/null +++ b/torchrec/metrics/tests/test_serving_ne.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +from typing import Dict, Type + +import torch + +from torchrec.metrics.ne import compute_cross_entropy, compute_ne +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric +from torchrec.metrics.serving_ne import ServingNEMetric +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_value_test_launcher, + TestMetric, +) + + +WORLD_SIZE = 2 + + +class TestNEMetric(TestMetric): + eta: float = 1e-12 + + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + cross_entropy = compute_cross_entropy( + labels, predictions, weights, TestNEMetric.eta + ) + cross_entropy_sum = torch.sum(cross_entropy) + weighted_num_samples = torch.sum(weights) + pos_labels = torch.sum(weights * labels) + neg_labels = torch.sum(weights * (1.0 - labels)) + return { + "cross_entropy_sum": cross_entropy_sum, + "weighted_num_samples": weighted_num_samples, + "pos_labels": pos_labels, + "neg_labels": neg_labels, + "num_samples": torch.tensor(labels.size()).long(), + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + return compute_ne( + states["cross_entropy_sum"], + states["weighted_num_samples"], + pos_labels=states["pos_labels"], + neg_labels=states["neg_labels"], + eta=TestNEMetric.eta, + ) + + +class ServingNEMetricTest(unittest.TestCase): + target_clazz: Type[RecMetric] = ServingNEMetric + target_compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION + task_name: str = "ne" + + def test_ne_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=ServingNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestNEMetric, + 
metric_name=ServingNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_ne_fused_tasks(self) -> None: + rec_metric_value_test_launcher( + target_clazz=ServingNEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestNEMetric, + metric_name=ServingNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_ne_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=ServingNEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + test_clazz=TestNEMetric, + metric_name=ServingNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_ne_update_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=ServingNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestNEMetric, + metric_name=ServingNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=5, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + rec_metric_value_test_launcher( + target_clazz=ServingNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestNEMetric, + metric_name=ServingNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=100, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + batch_window_size=10, + ) diff --git a/torchrec/metrics/tests/test_throughput.py b/torchrec/metrics/tests/test_throughput.py index 56fe8a33d..e8acaa537 100644 --- a/torchrec/metrics/tests/test_throughput.py +++ b/torchrec/metrics/tests/test_throughput.py @@ -5,11 +5,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + # pyre-ignore-all-errors[56] import unittest +from collections import OrderedDict +from typing import Any, Dict from unittest.mock import Mock, patch +import torch + +from torchrec.metrics.metrics_config import BatchSizeStage + from torchrec.metrics.throughput import ThroughputMetric @@ -28,7 +36,11 @@ def test_no_batches(self, time_mock: Mock) -> None: batch_size=self.batch_size, world_size=self.world_size, window_seconds=100 ) self.assertEqual( - throughput_metric.compute(), {"throughput-throughput|total_examples": 0} + throughput_metric.compute(), + { + "throughput-throughput|total_examples": 0, + "throughput-throughput|attempt_examples": 0, + }, ) @patch(THROUGHPUT_PATH + ".time.monotonic") @@ -40,7 +52,12 @@ def test_one_batch(self, time_mock: Mock) -> None: throughput_metric.update() self.assertEqual( throughput_metric.compute(), - {"throughput-throughput|total_examples": self.batch_size * self.world_size}, + { + "throughput-throughput|total_examples": self.batch_size + * self.world_size, + "throughput-throughput|attempt_examples": self.batch_size + * self.world_size, + }, ) @patch(THROUGHPUT_PATH + ".time.monotonic") @@ -69,7 +86,11 @@ def _test_throughput(self, time_mock: Mock, warmup_steps: int) -> None: total_examples = self.world_size * self.batch_size * (i + 1) if i < warmup_steps: self.assertEqual( - ret, {"throughput-throughput|total_examples": total_examples} + ret, + { + "throughput-throughput|total_examples": total_examples, + "throughput-throughput|attempt_examples": total_examples, + }, ) continue @@ -98,6 +119,13 @@ def _test_throughput(self, time_mock: Mock, warmup_steps: int) -> None: self.assertEqual( ret["throughput-throughput|total_examples"], total_examples ) + # only one attempt so attempt examples and throughput are the same as total/lifetime + self.assertEqual( + ret["throughput-throughput|attempt_examples"], total_examples + ) + self.assertEqual( + ret["throughput-throughput|attempt_throughput"], lifetime_throughput + ) def test_throughput_warmup_steps_0(self) -> None: with self.assertRaises(ValueError): @@ -136,5 +164,267 @@ def test_warmup_checkpointing(self) -> None: * self.world_size, ) + self.assertEqual( + throughput_metric.attempt_warmup_examples.item(), + warmup_steps * self.batch_size * self.world_size, + ) + self.assertEqual( + throughput_metric.attempt_examples.item(), + (warmup_steps + extra_steps) * self.batch_size * self.world_size, + ) # Mimic trainer crashing and loading a checkpoint throughput_metric._steps = 0 + throughput_metric.attempt_examples = torch.tensor(0, dtype=torch.long) + throughput_metric.attempt_warmup_examples = torch.tensor( + 0, dtype=torch.long + ) + throughput_metric.attempt_time_lapse_after_warmup = torch.tensor( + 0, dtype=torch.double + ) + + @patch(THROUGHPUT_PATH + ".time.monotonic") + def test_batch_size_schedule(self, time_mock: Mock) -> None: + batch_size_stages = [BatchSizeStage(256, 1), BatchSizeStage(512, None)] + time_mock.return_value = 1 + throughput_metric = ThroughputMetric( + batch_size=self.batch_size, + world_size=self.world_size, + window_seconds=100, + batch_size_stages=batch_size_stages, + ) + + total_examples = 0 + throughput_metric.update() + total_examples += batch_size_stages[0].batch_size * self.world_size + self.assertEqual( + throughput_metric.compute(), + { + "throughput-throughput|total_examples": total_examples, + "throughput-throughput|attempt_examples": total_examples, + "throughput-throughput|batch_size": 256, + }, + ) + + throughput_metric.update() + total_examples += 
batch_size_stages[1].batch_size * self.world_size + self.assertEqual( + throughput_metric.compute(), + { + "throughput-throughput|total_examples": total_examples, + "throughput-throughput|attempt_examples": total_examples, + "throughput-throughput|batch_size": 512, + }, + ) + + def test_num_batch_without_batch_size_stages(self) -> None: + # Create the module without the batch_size_stages + throughput_metric = ThroughputMetric( + batch_size=self.batch_size, + world_size=self.world_size, + window_seconds=100, + batch_size_stages=None, + ) + + # Make sure num_batch is not present as an argument of the class + self.assertFalse(hasattr(throughput_metric, "num_batch")) + + throughput_metric.update() + state_dict: Dict[str, Any] = throughput_metric.state_dict() + # Ensure num_batch is not included in the state_dict for the module without batch_size_stages + self.assertNotIn("num_batch", state_dict) + + def test_state_dict_load_module_lifecycle(self) -> None: + """ + A test to ensure that the load_state_dict and state_dict hooks correctly handle the num_batch attribute + through the module lifecycle. + """ + + throughput_metric = ThroughputMetric( + batch_size=32, + world_size=4, + window_seconds=100, + batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)], + ) + + self.assertTrue(hasattr(throughput_metric, "_num_batch")) + + # Stage 1: create metric and update the state_dict before persisting it + # Update metric, expecting num_batch to be incremented to 1 + throughput_metric.update() + # Ensure num_batch is 1 + self.assertEqual(throughput_metric._num_batch, 1) + # Ensure num_batch is included in the state_dict and has the correct value + state_dict: Dict[str, Any] = throughput_metric.state_dict() + self.assertIn("num_batch", state_dict) + # Ensure num_batch was saved to state_dict with the correct value + self.assertEqual(state_dict["num_batch"].item(), throughput_metric._num_batch) + + # Stage 2: load the state_dict and ensure num_batch is loaded correctly + + # Create a new metric instance + new_throughput_metric = ThroughputMetric( + batch_size=32, + world_size=4, + window_seconds=100, + batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)], + ) + # Ensure num_batch is 0 + self.assertEqual(new_throughput_metric._num_batch, 0) + # Load the state_dict + new_throughput_metric.load_state_dict(state_dict) + # Ensure num_batch is loaded from the state_dict with the correct value + self.assertEqual(new_throughput_metric._num_batch, 1) + + # Stage 3: update the metric after loading the state and resave the state_dict + + # Save the state_dict + state_dict = new_throughput_metric.state_dict() + # Ensure num_batch is included in the state_dict + self.assertIn("num_batch", state_dict) + # Ensure num_batch was saved to state_dict with the correct value + self.assertEqual( + state_dict["num_batch"].item(), new_throughput_metric._num_batch + ) + + def test_state_dict_hook_adds_key(self) -> None: + """ + Ensures that the state_dict_hook adds the 'num_batch' key to the state_dict + when batch_size_stages is True. 
+ """ + throughput_metric = ThroughputMetric( + batch_size=32, + world_size=4, + window_seconds=100, + batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)], + ) + for _ in range(5): + throughput_metric.update() + state_dict: OrderedDict[str, torch.Tensor] = OrderedDict() + prefix: str = "test_prefix_" + ThroughputMetric.state_dict_hook(throughput_metric, state_dict, prefix, {}) + self.assertIn(f"{prefix}num_batch", state_dict) + self.assertEqual(state_dict[f"{prefix}num_batch"].item(), 5) + + def test_state_dict_hook_no_batch_size_stages(self) -> None: + """ + Verifies that the state_dict_hook does not add the 'num_batch' key when + batch_size_stages is None. + """ + # Hook-only test + throughput_metric = ThroughputMetric( + batch_size=32, + world_size=4, + window_seconds=100, + batch_size_stages=None, + ) + state_dict: OrderedDict[str, torch.Tensor] = OrderedDict() + prefix: str = "test_prefix_" + ThroughputMetric.state_dict_hook(throughput_metric, state_dict, prefix, {}) + self.assertNotIn(f"{prefix}num_batch", state_dict) + + # Lifecycle test + + num_updates = 10 + prev_job_throughput_metric = ThroughputMetric( + batch_size=32, + world_size=4, + window_seconds=100, + batch_size_stages=None, + ) + for _ in range(num_updates): + prev_job_throughput_metric.update() + prev_state_dict = prev_job_throughput_metric.state_dict() + + curr_job_throughput_metric = ThroughputMetric( + batch_size=32, + world_size=4, + window_seconds=100, + batch_size_stages=None, + ) + + curr_job_throughput_metric.load_state_dict(prev_state_dict) + # Make sure _num_batch is not present as an argument of the class + self.assertFalse(hasattr(curr_job_throughput_metric, "_num_batch")) + + def test_load_state_dict_hook_resumes_from_checkpoint_with_bss_from_bss( + self, + ) -> None: + """ + Checks that the load_state_dict_hook correctly restores the 'num_batch' value + from the state_dict. + """ + num_updates = 10 + prev_job_throughput_metric = ThroughputMetric( + batch_size=32, + world_size=4, + window_seconds=100, + batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)], + ) + for _ in range(num_updates): + prev_job_throughput_metric.update() + prev_state_dict = prev_job_throughput_metric.state_dict() + + curr_job_throughput_metric = ThroughputMetric( + batch_size=32, + world_size=4, + window_seconds=100, + batch_size_stages=[BatchSizeStage(1024, 1), BatchSizeStage(2048, None)], + ) + + curr_job_throughput_metric.load_state_dict(prev_state_dict) + self.assertEqual(curr_job_throughput_metric._num_batch, num_updates) + + def test_load_state_dict_hook_resumes_from_checkpoint_without_bss(self) -> None: + """ + Verifies that the load_state_dict_hook correctly handles the case where a + previously checkpointed job used the batch_size_stages, but a subsequent job, + restored from a checkpoint, isn't using them. 
+ """ + + prev_job_throughput_metric = ThroughputMetric( + batch_size=32, + world_size=4, + window_seconds=100, + batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)], + ) + + prev_state_dict = prev_job_throughput_metric.state_dict() + + curr_job_throughput_metric = ThroughputMetric( + batch_size=32, + world_size=4, + window_seconds=100, + batch_size_stages=None, # No batch_size_stages + ) + + curr_job_throughput_metric.load_state_dict(prev_state_dict) + + self.assertFalse(hasattr(curr_job_throughput_metric, "_num_batch")) + + def test_load_state_dict_hook_resumes_from_checkpoint_with_bss_without_key( + self, + ) -> None: + """ + Verifies that the load_state_dict_hook correctly handles the case where a + previously checkpointed job didn't use batch_size_stages, but a subsequent job, + restored from a checkpoint, is using them. + """ + prev_job_throughput_metric = ThroughputMetric( + batch_size=32, + world_size=4, + window_seconds=100, + batch_size_stages=None, # No batch_size_stages + ) + prev_state_dict = prev_job_throughput_metric.state_dict() + + curr_job_throughput_metric = ThroughputMetric( + batch_size=32, + world_size=4, + window_seconds=100, + batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)], + ) + + curr_job_throughput_metric.load_state_dict(prev_state_dict) + + # Expecting 0 + self.assertEqual(curr_job_throughput_metric._num_batch, 0) diff --git a/torchrec/metrics/tests/test_tower_qps.py b/torchrec/metrics/tests/test_tower_qps.py new file mode 100644 index 000000000..7dd91010d --- /dev/null +++ b/torchrec/metrics/tests/test_tower_qps.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + + +import unittest +from functools import partial, update_wrapper +from typing import Callable, Dict, List, Optional, Tuple, Type + +import torch +import torch.distributed as dist +from torchrec.metrics.metrics_config import DefaultTaskInfo +from torchrec.metrics.model_utils import parse_task_model_outputs +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric, RecTaskInfo +from torchrec.metrics.test_utils import ( + gen_test_batch, + gen_test_tasks, + metric_test_helper, + rec_metric_value_test_launcher, + TestMetric, +) +from torchrec.metrics.tower_qps import TowerQPSMetric + +WORLD_SIZE = 4 +WARMUP_STEPS = 100 +DURING_WARMUP_NSTEPS = 10 +AFTER_WARMUP_NSTEPS = 120 + + +TestRecMetricOutput = Tuple[ + Dict[str, torch.Tensor], + Dict[str, torch.Tensor], + Dict[str, torch.Tensor], + Dict[str, torch.Tensor], +] + + +class TestTowerQPSMetric(TestMetric): + def __init__( + self, + world_size: int, + rec_tasks: List[RecTaskInfo], + ) -> None: + super().__init__(world_size, rec_tasks) + + # The abstract _get_states method in TestMetric has to be overwritten + # For tower qps the time_lapse state is not generated from labels, predictions + # or weights + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + return {} + + @staticmethod + def _reduce(states: Dict[str, List[torch.Tensor]]) -> Dict[str, torch.Tensor]: + reduced_states: Dict[str, torch.Tensor] = {} + # Need to check if states is empty, because we only update the states after warmup + if states: + reduced_states["num_samples"] = torch.sum( + torch.stack(states["num_samples"]), dim=0 + ) + reduced_states["time_lapse"] = torch.max( + torch.stack(states["time_lapse"]), dim=0 + ).values + return reduced_states + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + if "num_samples" not in states or "time_lapse" not in states: + # This is to match the default 0.0 output from TowerQPSMetric if warmup is not done + return torch.tensor(float("0.0"), dtype=torch.double) + + return torch.where( + states["time_lapse"] <= 0.0, + 0.0, + states["num_samples"] / states["time_lapse"], + ).double() + + def compute( + self, + model_outs: List[Dict[str, torch.Tensor]], + nsteps: int, + batch_window_size: int, + timestamps: Optional[List[float]], + ) -> TestRecMetricOutput: + assert timestamps is not None + lifetime_states, window_states, local_lifetime_states, local_window_states = ( + {task_info.name: {} for task_info in self._rec_tasks} for _ in range(4) + ) + for i in range(WARMUP_STEPS, nsteps): + for task_info in self._rec_tasks: + local_states = { + "num_samples": torch.tensor( + model_outs[i][task_info.label_name].shape[-1], + dtype=torch.long, + ), + "time_lapse": torch.tensor( + timestamps[i] - timestamps[i - 1], dtype=torch.double + ), + } + self._aggregate(local_lifetime_states[task_info.name], local_states) + if nsteps - batch_window_size <= i: + self._aggregate(local_window_states[task_info.name], local_states) + + for task_info in self._rec_tasks: + aggregated_lifetime_state = {} + for k, v in local_lifetime_states[task_info.name].items(): + aggregated_lifetime_state[k] = [ + torch.zeros_like(v) for _ in range(self.world_size) + ] + dist.all_gather(aggregated_lifetime_state[k], v) + lifetime_states[task_info.name] = self._reduce(aggregated_lifetime_state) + + aggregated_window_state = {} + for k, v in local_window_states[task_info.name].items(): + aggregated_window_state[k] = [ + 
torch.zeros_like(v) for _ in range(self.world_size) + ] + dist.all_gather(aggregated_window_state[k], v) + window_states[task_info.name] = self._reduce(aggregated_window_state) + + lifetime_metrics = {} + window_metrics = {} + local_lifetime_metrics = {} + local_window_metrics = {} + for task_info in self._rec_tasks: + lifetime_metrics[task_info.name] = self._compute( + lifetime_states[task_info.name] + ) + window_metrics[task_info.name] = self._compute( + window_states[task_info.name] + ) + local_lifetime_metrics[task_info.name] = self._compute( + local_lifetime_states[task_info.name] + ) + local_window_metrics[task_info.name] = self._compute( + local_window_states[task_info.name] + ) + return ( + lifetime_metrics, + window_metrics, + local_lifetime_metrics, + local_window_metrics, + ) + + +class TowerQPSMetricTest(unittest.TestCase): + target_clazz: Type[RecMetric] = TowerQPSMetric + task_names: str = "qps" + + _test_tower_qps: Callable[..., None] = partial( + metric_test_helper, + is_time_dependent=True, + time_dependent_metric={TowerQPSMetric: "torchrec.metrics.tower_qps"}, + ) + update_wrapper(_test_tower_qps, metric_test_helper) + + def test_tower_qps_during_warmup_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=TowerQPSMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestTowerQPSMetric, + metric_name=TowerQPSMetricTest.task_names, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=self._test_tower_qps, + ) + + def test_tower_qps_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=TowerQPSMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestTowerQPSMetric, + metric_name=TowerQPSMetricTest.task_names, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=self._test_tower_qps, + test_nsteps=DURING_WARMUP_NSTEPS, + ) + + def test_tower_qps_fused_tasks(self) -> None: + rec_metric_value_test_launcher( + target_clazz=TowerQPSMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestTowerQPSMetric, + metric_name=TowerQPSMetricTest.task_names, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=self._test_tower_qps, + test_nsteps=AFTER_WARMUP_NSTEPS, + ) + + def test_tower_qps_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=TowerQPSMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + test_clazz=TestTowerQPSMetric, + metric_name=TowerQPSMetricTest.task_names, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=self._test_tower_qps, + test_nsteps=AFTER_WARMUP_NSTEPS, + ) + + def test_check_update_tower_qps_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=TowerQPSMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestTowerQPSMetric, + metric_name=TowerQPSMetricTest.task_names, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=True, + world_size=WORLD_SIZE, + entry_point=self._test_tower_qps, + test_nsteps=AFTER_WARMUP_NSTEPS, + ) + + def 
test_check_update_tower_qps_fused_tasks(self) -> None: + rec_metric_value_test_launcher( + target_clazz=TowerQPSMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestTowerQPSMetric, + metric_name=TowerQPSMetricTest.task_names, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=True, + world_size=WORLD_SIZE, + entry_point=self._test_tower_qps, + test_nsteps=AFTER_WARMUP_NSTEPS, + ) + + def test_warmup_checkpointing(self) -> None: + warmup_steps = 5 + extra_steps = 2 + batch_size = 128 + qps = TowerQPSMetric( + world_size=1, + my_rank=0, + batch_size=batch_size, + tasks=[DefaultTaskInfo], + warmup_steps=warmup_steps, + compute_on_all_ranks=False, + should_validate_update=False, + window_size=200, + ) + model_output = gen_test_batch(batch_size) + for i in range(5): + for _ in range(warmup_steps + extra_steps): + qps.update( + predictions={"DefaultTask": model_output["prediction"]}, + labels={"DefaultTask": model_output["label"]}, + weights={"DefaultTask": model_output["weight"]}, + ) + self.assertEqual( + qps._metrics_computations[0].warmup_examples, + batch_size * warmup_steps * (i + 1), + ) + self.assertEqual( + qps._metrics_computations[0].num_examples, + batch_size * (warmup_steps + extra_steps) * (i + 1), + ) + # Mimic trainer crashing and loading a checkpoint. + qps._metrics_computations[0]._steps = 0 + + def test_mtml_empty_update(self) -> None: + warmup_steps = 2 + extra_steps = 2 + batch_size = 128 + task_names = ["t1", "t2"] + tasks = gen_test_tasks(task_names) + qps = TowerQPSMetric( + world_size=1, + my_rank=0, + batch_size=batch_size, + tasks=tasks, + warmup_steps=warmup_steps, + compute_on_all_ranks=False, + should_validate_update=False, + window_size=200, + ) + for step in range(warmup_steps + extra_steps): + _model_output = [ + gen_test_batch( + label_name=task.label_name, + prediction_name=task.prediction_name, + weight_name=task.weight_name, + batch_size=batch_size, + ) + for task in tasks + ] + model_output = {k: v for d in _model_output for k, v in d.items()} + labels, predictions, weights, _ = parse_task_model_outputs( + tasks, model_output + ) + if step % 2 == 0: + del labels["t1"] + else: + del labels["t2"] + qps.update(predictions=predictions, labels=labels, weights=weights) + self.assertEqual( + qps._metrics_computations[0].num_examples, (step + 1) // 2 * batch_size + ) + self.assertEqual( + qps._metrics_computations[1].num_examples, (step + 2) // 2 * batch_size + ) diff --git a/torchrec/metrics/tests/test_unweighted_ne.py b/torchrec/metrics/tests/test_unweighted_ne.py new file mode 100644 index 000000000..5a18178d0 --- /dev/null +++ b/torchrec/metrics/tests/test_unweighted_ne.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
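+
+# Unweighted NE tests: TestUnweightedNEMetric overrides the incoming weights
+# with torch.ones_like(labels) before building the usual NE states, so the
+# expected values equal NE computed as if every example had weight 1. The cases
+# below cover unfused/fused modes, fused update limits, zero weights, and GPU sync.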
+ +# pyre-strict + +import unittest +from typing import Dict, Type + +import torch +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_gpu_sync_test_launcher, + rec_metric_value_test_launcher, + sync_test_helper, + TestMetric, +) +from torchrec.metrics.unweighted_ne import ( + compute_cross_entropy, + compute_ne, + UnweightedNEMetric, +) + + +WORLD_SIZE = 4 + + +class TestUnweightedNEMetric(TestMetric): + eta: float = 1e-12 + + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + # Override the weights to be all ones + weights = torch.ones_like(labels) + cross_entropy = compute_cross_entropy( + labels, predictions, weights, TestUnweightedNEMetric.eta + ) + cross_entropy_sum = torch.sum(cross_entropy) + weighted_num_samples = torch.sum(weights) + pos_labels = torch.sum(weights * labels) + neg_labels = torch.sum(weights * (1.0 - labels)) + return { + "cross_entropy_sum": cross_entropy_sum, + "weighted_num_samples": weighted_num_samples, + "pos_labels": pos_labels, + "neg_labels": neg_labels, + "num_samples": torch.tensor(labels.size()).long(), + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + allow_missing_label_with_zero_weight = False + if not states["weighted_num_samples"].all(): + allow_missing_label_with_zero_weight = True + + return compute_ne( + states["cross_entropy_sum"], + states["weighted_num_samples"], + pos_labels=states["pos_labels"], + neg_labels=states["neg_labels"], + eta=TestUnweightedNEMetric.eta, + allow_missing_label_with_zero_weight=allow_missing_label_with_zero_weight, + ) + + +class UnweightedNEMetricTest(unittest.TestCase): + target_clazz: Type[RecMetric] = UnweightedNEMetric + target_compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION + task_name: str = "unweighted_ne" + + def test_unweighted_ne_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=UnweightedNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestUnweightedNEMetric, + metric_name=UnweightedNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_unweighted_ne_fused_tasks(self) -> None: + rec_metric_value_test_launcher( + target_clazz=UnweightedNEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestUnweightedNEMetric, + metric_name=UnweightedNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_unweighted_ne_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=UnweightedNEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + test_clazz=TestUnweightedNEMetric, + metric_name=UnweightedNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_unweighted_ne_update_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=UnweightedNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + 
test_clazz=TestUnweightedNEMetric, + metric_name=UnweightedNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=5, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + rec_metric_value_test_launcher( + target_clazz=UnweightedNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestUnweightedNEMetric, + metric_name=UnweightedNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=100, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + batch_window_size=10, + ) + + def test_unweighted_ne_zero_weights(self) -> None: + rec_metric_value_test_launcher( + target_clazz=UnweightedNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestUnweightedNEMetric, + metric_name=UnweightedNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + zero_weights=True, + ) + + +class UnweightedNEGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = UnweightedNEMetric + task_name: str = "unweighted_ne" + + def test_sync_unweighted_ne(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=UnweightedNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestUnweightedNEMetric, + metric_name=UnweightedNEGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/tests/test_weighted_avg.py b/torchrec/metrics/tests/test_weighted_avg.py new file mode 100644 index 000000000..226c06748 --- /dev/null +++ b/torchrec/metrics/tests/test_weighted_avg.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
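+
+# Weighted average tests: the reference computation used here is
+#     weighted_avg = sum(predictions * weights) / sum(weights)
+# For the first case in generate_model_outputs_cases below, predictions
+# [0.2, 0.6, 0.8, 0.4, 0.9] with weights [0.1, 0.1, 0.1, 0.1, 0.6] give
+# (0.02 + 0.06 + 0.08 + 0.04 + 0.54) / 1.0 = 0.74.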
+ +# pyre-strict + +import unittest +from typing import Dict, Iterable, Optional, Type, Union + +import torch +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric, RecTaskInfo +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_value_test_launcher, + TestMetric, +) +from torchrec.metrics.weighted_avg import get_mean, WeightedAvgMetric + + +WORLD_SIZE = 4 + + +class TestWeightedAvgMetric(TestMetric): + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + + return { + "weighted_sum": (predictions * weights).sum(dim=-1), + "weighted_num_samples": weights.sum(dim=-1), + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + return get_mean(states["weighted_sum"], states["weighted_num_samples"]) + + +class WeightedAvgMetricTest(unittest.TestCase): + target_clazz: Type[RecMetric] = WeightedAvgMetric + target_compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION + task_name: str = "weighted_avg" + + def test_weighted_avg_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=WeightedAvgMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestWeightedAvgMetric, + metric_name=WeightedAvgMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_weighted_avg_fused_tasks(self) -> None: + rec_metric_value_test_launcher( + target_clazz=WeightedAvgMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestWeightedAvgMetric, + metric_name=WeightedAvgMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_weighted_avg_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=WeightedAvgMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + test_clazz=TestWeightedAvgMetric, + metric_name=WeightedAvgMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_weighted_avg_update_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=WeightedAvgMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestWeightedAvgMetric, + metric_name=WeightedAvgMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=5, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + rec_metric_value_test_launcher( + target_clazz=WeightedAvgMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestWeightedAvgMetric, + metric_name=WeightedAvgMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=100, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + batch_window_size=10, + ) + + +def generate_model_outputs_cases() -> Iterable[Dict[str, Optional[torch.Tensor]]]: + return [ + # random_inputs + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + 
"weights": torch.tensor([[0.1, 0.1, 0.1, 0.1, 0.6]]), + "expected_weighted_avg": torch.tensor([0.74]), + }, + # no weight + { + "labels": torch.tensor([[1, 0, 1, 0, 1, 0]]), + "predictions": torch.tensor([[0.5] * 6]), + "weights": None, + "expected_weighted_avg": torch.tensor([0.5]), + }, + # all weights are zero + { + "labels": torch.tensor([[1, 1, 1, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0] * 5]), + "expected_weighted_avg": torch.tensor([float("nan")]), + }, + ] + + +class WeightedAvgValueTest(unittest.TestCase): + r"""This set of tests verify the computation logic of weighted avg in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. + """ + + @torch.no_grad() + def _test_weighted_avg_helper( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + expected_weighted_avg: torch.Tensor, + ) -> None: + num_task = labels.shape[0] + batch_size = labels.shape[0] + task_list = [] + inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = { + "predictions": {}, + "labels": {}, + "weights": {}, + } + for i in range(num_task): + task_info = RecTaskInfo( + name=f"Task:{i}", + label_name="label", + prediction_name="prediction", + weight_name="weight", + ) + task_list.append(task_info) + # pyre-ignore + inputs["predictions"][task_info.name] = predictions[i] + # pyre-ignore + inputs["labels"][task_info.name] = labels[i] + if weights is None: + # pyre-ignore + inputs["weights"] = None + else: + # pyre-ignore + inputs["weights"][task_info.name] = weights[i] + + weighted_avg = WeightedAvgMetric( + world_size=1, + my_rank=0, + batch_size=batch_size, + tasks=task_list, + ) + weighted_avg.update(**inputs) + actual_weighted_avg = weighted_avg.compute() + + for task_id, task in enumerate(task_list): + cur_actual_weighted_avg = actual_weighted_avg[ + f"weighted_avg-{task.name}|window_weighted_avg" + ] + cur_expected_weighted_avg = expected_weighted_avg[task_id].unsqueeze(dim=0) + if cur_expected_weighted_avg.isnan().any(): + self.assertTrue(cur_actual_weighted_avg.isnan().any()) + else: + torch.testing.assert_close( + cur_actual_weighted_avg, + cur_expected_weighted_avg, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + msg=f"Actual: {cur_actual_weighted_avg}, Expected: {cur_expected_weighted_avg}", + ) + + def test_weighted_avg(self) -> None: + test_data = generate_model_outputs_cases() + for inputs in test_data: + try: + # pyre-ignore + self._test_weighted_avg_helper(**inputs) + except AssertionError: + print("Assertion error caught with data set ", inputs) + raise diff --git a/torchrec/metrics/tests/test_xauc.py b/torchrec/metrics/tests/test_xauc.py new file mode 100644 index 000000000..a5263150c --- /dev/null +++ b/torchrec/metrics/tests/test_xauc.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import unittest +from typing import Dict + +import torch +from torchrec.metrics.metrics_config import DefaultTaskInfo +from torchrec.metrics.xauc import XAUCMetric + + +WORLD_SIZE = 4 +BATCH_SIZE = 10 + + +def generate_model_output() -> Dict[str, torch._tensor.Tensor]: + return { + "predictions": torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]), + "labels": torch.tensor([[0.2, 0.1, 0.3, 0.5, 0.25]]), + "weights": torch.tensor([[1.0, 1.0, 1.0, 0.0, 1.0]]), + "expected_xauc": torch.tensor([0.6667]), + } + + +class XAUCMetricTest(unittest.TestCase): + def setUp(self) -> None: + self.xauc = XAUCMetric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + ) + + def test_xauc(self) -> None: + model_output = generate_model_output() + self.xauc.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + ) + metric = self.xauc.compute() + actual_metric = metric[f"xauc-{DefaultTaskInfo.name}|lifetime_xauc"] + expected_metric = model_output["expected_xauc"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) diff --git a/torchrec/metrics/throughput.py b/torchrec/metrics/throughput.py index da1f022ee..758250426 100644 --- a/torchrec/metrics/throughput.py +++ b/torchrec/metrics/throughput.py @@ -5,16 +5,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + #!/usr/bin/env python3 +import copy import logging import math import time -from collections import deque -from typing import Deque, Dict +from collections import deque, OrderedDict +from typing import Any, Deque, Dict, List, Optional import torch import torch.nn as nn +from torchrec.distributed.utils import none_throws +from torchrec.metrics.metrics_config import BatchSizeStage from torchrec.metrics.metrics_namespace import ( compose_metric_key, MetricName, @@ -59,7 +64,6 @@ class ThroughputMetric(nn.Module): _namespace: MetricNamespace = MetricNamespace.THROUGHPUT _metric_name: MetricName = MetricName.THROUGHPUT - _batch_examples: int _window_seconds: int _warmup_steps: int _window_time_lapse_buffer: Deque[float] @@ -67,7 +71,10 @@ class ThroughputMetric(nn.Module): _previous_ts: float _lifetime_throughput_key: str _window_throughput_key: str + _attempt_throughput_key: str _total_examples_key: str + _attempt_examples_key: str + _batch_size_key: str _steps: int def __init__( @@ -77,6 +84,7 @@ def __init__( world_size: int, window_seconds: int, warmup_steps: int = 100, + batch_size_stages: Optional[List[BatchSizeStage]] = None, ) -> None: super().__init__() if window_seconds < 1: @@ -96,9 +104,20 @@ def __init__( ) window_seconds = MAX_WINDOW_TS - self._batch_examples = batch_size * world_size + self._batch_size = batch_size + self._world_size = world_size self._window_seconds = window_seconds self._warmup_steps = warmup_steps + self._batch_size_stages: Optional[List[BatchSizeStage]] = copy.deepcopy( + batch_size_stages + ) + + if self._batch_size_stages is not None: + # Keep track of the number of batches if using batch_size_stages + self._num_batch: int = 0 + + self._register_load_state_dict_pre_hook(self.load_state_dict_hook) + self.register_state_dict_post_hook(self.state_dict_hook) 
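# NOTE: the buffers registered just below ("total_examples", "warmup_examples",
# "time_lapse_after_warmup") are persistent, so they survive checkpoint
# save/restore and drive the lifetime throughput. The "attempt_*" buffers added
# a few lines further down are registered with persistent=False, so they reset
# whenever a new training attempt starts; they back the attempt examples and
# attempt throughput keys reported by compute().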
self.register_buffer("total_examples", torch.tensor(0, dtype=torch.long)) self.register_buffer("warmup_examples", torch.tensor(0, dtype=torch.long)) @@ -106,6 +125,20 @@ def __init__( "time_lapse_after_warmup", torch.tensor(0, dtype=torch.double) ) + self.register_buffer( + "attempt_examples", torch.tensor(0, dtype=torch.long), persistent=False + ) + self.register_buffer( + "attempt_warmup_examples", + torch.tensor(0, dtype=torch.long), + persistent=False, + ) + self.register_buffer( + "attempt_time_lapse_after_warmup", + torch.tensor(0, dtype=torch.double), + persistent=False, + ) + self._window_time_lapse_buffer = deque(maxlen=MAX_WINDOW_TS) self._window_time_lapse = 0 self._previous_ts = 0 @@ -122,14 +155,55 @@ def __init__( self._metric_name, MetricPrefix.WINDOW, ) + self._attempt_throughput_key = compose_metric_key( + self._namespace, + str(self._namespace), + self._metric_name, + MetricPrefix.ATTEMPT, + ) self._total_examples_key = compose_metric_key( self._namespace, str(self._namespace), MetricName.TOTAL_EXAMPLES, ) - + self._attempt_examples_key = compose_metric_key( + self._namespace, + str(self._namespace), + MetricName.ATTEMPT_EXAMPLES, + ) + self._batch_size_key = compose_metric_key( + self._namespace, + str(self._namespace), + MetricName.BATCH_SIZE, + ) self._steps = 0 + def _get_batch_size(self) -> int: + # No batch size stages, use the default batch size + if not self._batch_size_stages: + return self._batch_size + + # Get batch size from batch_size_stages + assert self._num_batch is not None, "num_batch should not be None" + batch_size_stages = none_throws(self._batch_size_stages) + while self._batch_size_stages: + stage = self._batch_size_stages[0] + # Reach the last stage + if stage.max_iters is None: + assert len(batch_size_stages) == 1 + return stage.batch_size + # This stage finished + if stage.max_iters < self._num_batch: + batch_size_stages.pop(0) + # Move to the next stage + continue + # In this stage + return stage.batch_size + raise AssertionError("Unreachable, batch_size_stages should always has 1 item") + + def _batch_examples(self) -> int: + return self._get_batch_size() * self._world_size + def _check_window(self) -> None: while self._window_time_lapse > self._window_seconds: self._window_time_lapse -= self._window_time_lapse_buffer.popleft() @@ -137,32 +211,44 @@ def _check_window(self) -> None: def update(self) -> None: ts = time.monotonic() self._steps += 1 - self.total_examples += self._batch_examples + if self._batch_size_stages is not None: + self._num_batch += 1 + batch_examples = self._batch_examples() + self.total_examples += batch_examples + self.attempt_examples += batch_examples if self._steps <= self._warmup_steps: - self.warmup_examples += self._batch_examples + self.warmup_examples += batch_examples + self.attempt_warmup_examples += batch_examples if self._steps == self._warmup_steps: self._previous_ts = ts else: time_lapse = ts - self._previous_ts self.time_lapse_after_warmup += time_lapse + self.attempt_time_lapse_after_warmup += time_lapse self._window_time_lapse += time_lapse self._window_time_lapse_buffer.append(time_lapse) self._check_window() self._previous_ts = ts def compute(self) -> Dict[str, torch.Tensor]: - ret = {self._total_examples_key: self.total_examples} + ret = { + self._total_examples_key: self.total_examples, + self._attempt_examples_key: self.attempt_examples, + } if self._steps > self._warmup_steps and ( not math.isclose(self.time_lapse_after_warmup.item(), 0) ): lifetime_throughput = ( self.total_examples - 
self.warmup_examples ) / self.time_lapse_after_warmup + attempt_throughput = ( + self.attempt_examples - self.attempt_warmup_examples + ) / self.attempt_time_lapse_after_warmup if not math.isclose(self._window_time_lapse, 0): window_throughput = ( len(self._window_time_lapse_buffer) - * self._batch_examples + * self._batch_examples() / self._window_time_lapse ) else: @@ -170,12 +256,59 @@ def compute(self) -> Dict[str, torch.Tensor]: if not math.isclose(lifetime_throughput.item(), 0): ret.update( { - self._lifetime_throughput_key: torch.tensor( - lifetime_throughput, dtype=torch.double - ), + self._lifetime_throughput_key: lifetime_throughput.clone().detach(), self._window_throughput_key: torch.tensor( window_throughput, dtype=torch.double ), } ) + if not math.isclose(attempt_throughput.item(), 0): + ret.update( + { + self._attempt_throughput_key: attempt_throughput.clone().detach(), + } + ) + # If using batch_size_stages, also report the current batch size + # that was used for the throughput calculation + if self._batch_size_stages is not None: + ret.update( + { + self._batch_size_key: torch.tensor( + self._get_batch_size(), dtype=torch.int32 + ), + } + ) + return ret + + @staticmethod + def state_dict_hook( + module: nn.Module, + state_dict: OrderedDict[str, torch.Tensor], + prefix: str, + local_metadata: Dict[str, Any], + ) -> None: + if module._batch_size_stages is not None: + # Save the number of batches used for the throughput calculation to the state dict + num_batch_key = f"{prefix}num_batch" + state_dict[num_batch_key] = torch.tensor( + module._num_batch, dtype=torch.long + ) + + def load_state_dict_hook( + self, + state_dict: OrderedDict[str, torch.Tensor], + prefix: str, + local_metadata: Dict[str, Any], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + key = f"{prefix}num_batch" + if key in state_dict: + # If present, pop the number of batches used for the throughput calculation from the state dict + num_batch_tensor = state_dict.pop(key) + # Apply the number of batches to the module if using batch_size_stages + if self._batch_size_stages is not None: + self._num_batch = int(num_batch_tensor.item()) diff --git a/torchrec/metrics/tower_qps.py b/torchrec/metrics/tower_qps.py new file mode 100644 index 000000000..4411dcba4 --- /dev/null +++ b/torchrec/metrics/tower_qps.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
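# Minimal sketch of the tower QPS formula this file implements: examples summed
# across trainers divided by the max per-trainer time lapse. The numbers mirror
# the 4-trainer example in the TowerQPSMetric docstring below.
import torch

num_examples = torch.tensor([3, 3, 3, 3], dtype=torch.long)  # per trainer
time_lapse = torch.tensor([2.41, 4.29, 2.51, 4.12], dtype=torch.double)
tower_qps = num_examples.sum() / time_lapse.max()  # 12 / 4.29, roughly 2.80
assert torch.isclose(tower_qps, torch.tensor(2.80, dtype=torch.double), atol=1e-2)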
+ +# pyre-strict + +import time +from typing import Any, cast, Dict, List, Optional, Type + +import torch +import torch.distributed as dist + +from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, + RecModelOutput, +) + + +WARMUP_STEPS = 100 + +NUM_EXAMPLES = "num_examples" +WARMUP_EXAMPLES = "warmup_examples" +TIME_LAPSE = "time_lapse" + + +def _compute_tower_qps( + num_examples: torch.Tensor, time_lapse: torch.Tensor +) -> torch.Tensor: + return torch.where(time_lapse <= 0.0, 0.0, num_examples / time_lapse).double() + + +def _max_reduction(state: torch.Tensor) -> torch.Tensor: + return torch.max(state, dim=0).values + + +class TowerQPSMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for tower QPS. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + """ + + _previous_ts: float + _steps: int + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self._warmup_steps: int = kwargs.pop("warmup_steps") + super().__init__(*args, **kwargs) + self._add_state( + NUM_EXAMPLES, + torch.zeros(self._n_tasks, dtype=torch.long), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + WARMUP_EXAMPLES, + torch.zeros(self._n_tasks, dtype=torch.long), + add_window_state=False, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + TIME_LAPSE, + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx=_max_reduction, + persistent=True, + ) + self._previous_ts = 0 + self._steps = 0 + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + self._steps += 1 + num_examples_scalar = labels.shape[-1] + num_examples = torch.tensor(num_examples_scalar, dtype=torch.long) + self_num_examples = getattr(self, NUM_EXAMPLES) + self_num_examples += num_examples + ts = time.monotonic() + if self._steps <= self._warmup_steps: + self_warmup_examples = getattr(self, WARMUP_EXAMPLES) + self_warmup_examples += num_examples + if self._steps == self._warmup_steps: + self._previous_ts = ts + else: + self._aggregate_window_state( + NUM_EXAMPLES, num_examples, num_examples_scalar + ) + time_lapse = torch.tensor(ts - self._previous_ts, dtype=torch.double) + self_time_lapse = getattr(self, TIME_LAPSE) + self_time_lapse += time_lapse + self._aggregate_window_state(TIME_LAPSE, time_lapse, num_examples_scalar) + self._previous_ts = ts + + def _compute(self) -> List[MetricComputationReport]: + return [ + MetricComputationReport( + name=MetricName.TOWER_QPS, + metric_prefix=MetricPrefix.LIFETIME, + value=_compute_tower_qps( + cast(torch.Tensor, self.num_examples) + - cast(torch.Tensor, self.warmup_examples), + cast(torch.Tensor, self.time_lapse), + ), + ), + MetricComputationReport( + name=MetricName.TOWER_QPS, + metric_prefix=MetricPrefix.WINDOW, + value=_compute_tower_qps( + self.get_window_state(NUM_EXAMPLES), + self.get_window_state(TIME_LAPSE), + ), + ), + MetricComputationReport( + name=MetricName.TOTAL_EXAMPLES, + metric_prefix=MetricPrefix.DEFAULT, + value=cast(torch.Tensor, self.num_examples).detach(), + ), + ] + + +class TowerQPSMetric(RecMetric): + r""" + TowerQPSMetric defines 
the tower QPS metric. + Tower QPS's formula is training example count / time + where training example count = sum(examples for trainer 1, ... examples for trainer n) + and time = max(time for trainer 1, ... time for trainer n) + It's mostly used for cases where there's no fixed batch size + For example for Pyper MTML models, given the same input, different tasks may have + different numbers of examples to process + + Args: + world_size (int): the number of trainers. + my_rank (int): the rank of this trainer. + batch_size (int): batch size used by this trainer. + tasks (List[RecTaskInfo]): the information of the model tasks. + compute_mode (RecComputeMode): the computation mode. See RecComputeMode. + window_size (int): the window size for the window metric. + fused_update_limit (int): the maximum number of updates to be fused. + process_group (Optional[ProcessGroup]): the process group used for the + communication. Will use the default process group if not specified. + + Call Args: + Not supported. + + Returns: + Not supported. + + Example:: + + For world_size = 4, suppose we have 1 step after warmup + predictions = [ + [0.8033, 0.0662, 0.7559], + [0.1821, 0.9652, 0.4602], + [0.8545, 0.4758, 0.2220], + [0.1021, 0.2469, 0.7259], + ], + previous_ts = [278.94, 312.16, 286.96, 291.43] + ts = [281.35, 316.45, 289.47, 295.55] + + num_examples = [3, 3, 3, 3] + time_lapse = [2.41, 4.29, 2.51, 4.12] + + tower_qps = torch.sum(num_examples) / torch.max(time_lapse) = 2.80 + """ + + _namespace: MetricNamespace = MetricNamespace.TOWER_QPS + _computation_class: Type[RecMetricComputation] = TowerQPSMetricComputation + + def __init__( + self, + world_size: int, + my_rank: int, + batch_size: int, + tasks: List[RecTaskInfo], + compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, + window_size: int = 100, + fused_update_limit: int = 0, + process_group: Optional[dist.ProcessGroup] = None, + warmup_steps: int = WARMUP_STEPS, + **kwargs: Any, + ) -> None: + if fused_update_limit > 0: + raise RecMetricException("Fused update is not supported for tower QPS") + + kwargs["warmup_steps"] = warmup_steps + + super().__init__( + world_size=world_size, + my_rank=my_rank, + batch_size=batch_size, + tasks=tasks, + compute_mode=compute_mode, + window_size=window_size, + fused_update_limit=fused_update_limit, + process_group=process_group, + **kwargs, + ) + + def update( + self, + *, + predictions: Optional[RecModelOutput], + labels: RecModelOutput, + weights: Optional[RecModelOutput], + **kwargs: Dict[str, Any], + ) -> None: + with torch.no_grad(): + if self._compute_mode in [ + RecComputeMode.FUSED_TASKS_COMPUTATION, + RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + ]: + if not isinstance(labels, torch.Tensor): + raise RecMetricException( + "Fused computation only support where 'labels' is a tensor" + ) + labels = labels.view(-1, self._batch_size) + if self._should_validate_update: + # Set the default value to be all True. 
When weights is None, it's considered + # to be a valid input, and we'll use the default value + has_valid_weights = torch.ones( + len(self._tasks), + dtype=torch.bool, + device=self._metrics_computations[0].has_valid_update.device, + ) + if weights is not None: + if not isinstance(weights, torch.Tensor): + raise RecMetricException( + "Fused computation only support where 'weights' is a tensor" + ) + has_valid_weights = torch.gt( + torch.count_nonzero( + weights.view(-1, self._batch_size), dim=-1 + ), + 0, + ) + + if torch.any(has_valid_weights): + self._metrics_computations[0].update( + predictions=None, labels=labels, weights=None + ) + self._metrics_computations[0].has_valid_update.logical_or_( + has_valid_weights + ) + else: + self._metrics_computations[0].update( + predictions=None, labels=labels, weights=None + ) + else: + for task, metric_ in zip(self._tasks, self._metrics_computations): + if task.name not in labels: + continue + # pyre-fixme[6]: For 1st argument expected `Union[None, + # List[typing.Any], int, slice, Tensor, typing.Tuple[typing.Any, + # ...]]` but got `str`. + task_labels = labels[task.name].view(1, -1) + if self._should_validate_update: + has_valid_weights = torch.ones( + 1, dtype=torch.bool, device=metric_.has_valid_update.device + ) + if weights is not None and task.name in weights: + has_valid_weights = torch.gt( + torch.count_nonzero( + # pyre-fixme[6]: For 1st argument expected + # `Union[None, List[typing.Any], int, slice, + # Tensor, typing.Tuple[typing.Any, ...]]` but got + # `str`. + weights[task.name].view(1, -1), + dim=-1, + ), + 0, + ) + if has_valid_weights[0]: + metric_.has_valid_update.logical_or_(has_valid_weights) + else: + continue + metric_.update( + predictions=None, + labels=task_labels, + weights=None, + ) diff --git a/torchrec/metrics/unweighted_ne.py b/torchrec/metrics/unweighted_ne.py new file mode 100644 index 000000000..74f77ce9b --- /dev/null +++ b/torchrec/metrics/unweighted_ne.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
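# Minimal sketch, with made-up inputs, of the normalized-entropy formula
# implemented below (weights forced to all ones): NE = CE(predictions, labels)
# divided by CE of the best constant predictor, i.e. the model's log loss
# normalized by the log loss of always predicting the mean label.
import torch

labels = torch.tensor([1.0, 0.0, 1.0, 1.0])
predictions = torch.tensor([0.9, 0.2, 0.7, 0.6]).double()
eta = 1e-12

ce_sum = (-labels * torch.log2(predictions)
          - (1.0 - labels) * torch.log2(1.0 - predictions)).sum()
mean_label = labels.mean().clamp(min=eta, max=1 - eta)
pos, neg = labels.sum(), (1.0 - labels).sum()
ce_norm = -pos * torch.log2(mean_label) - neg * torch.log2(1.0 - mean_label)
ne = ce_sum / ce_norm
assert float(ne) < 1.0  # better than the constant predictor on this toy batch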
+ +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) +from torchrec.pt2.utils import pt2_compile_callable + + +def compute_cross_entropy( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + eta: float, +) -> torch.Tensor: + predictions = predictions.double() + predictions.clamp_(min=eta, max=1 - eta) + cross_entropy = -weights * labels * torch.log2(predictions) - weights * ( + 1.0 - labels + ) * torch.log2(1.0 - predictions) + return cross_entropy + + +def _compute_cross_entropy_norm( + mean_label: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, +) -> torch.Tensor: + mean_label = mean_label.double() + mean_label.clamp_(min=eta, max=1 - eta) + return -pos_labels * torch.log2(mean_label) - neg_labels * torch.log2( + 1.0 - mean_label + ) + + +@torch.fx.wrap +def compute_ne( + ce_sum: torch.Tensor, + weighted_num_samples: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, + allow_missing_label_with_zero_weight: bool = False, +) -> torch.Tensor: + if allow_missing_label_with_zero_weight and not weighted_num_samples.all(): + # If nan were to occur, return a dummy value instead of nan if + # allow_missing_label_with_zero_weight is True + return torch.tensor([eta]) + + # Goes into this block if all elements in weighted_num_samples > 0 + weighted_num_samples = weighted_num_samples.double().clamp(min=eta) + mean_label = pos_labels / weighted_num_samples + ce_norm = _compute_cross_entropy_norm(mean_label, pos_labels, neg_labels, eta) + return ce_sum / ce_norm + + +def get_unweighted_ne_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + eta: float, +) -> Dict[str, torch.Tensor]: + # Allow for unweighted NE computation by passing in a weights tensor of all ones + weights = torch.ones_like(labels) + cross_entropy = compute_cross_entropy( + labels, + predictions, + weights, + eta, + ) + return { + "cross_entropy_sum": torch.sum(cross_entropy, dim=-1), + "weighted_num_samples": torch.sum(weights, dim=-1), + "pos_labels": torch.sum(weights * labels, dim=-1), + "neg_labels": torch.sum(weights * (1.0 - labels), dim=-1), + } + + +class UnweightedNEMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for Unweighted NE, i.e. Normalized Entropy. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + + Args: + allow_missing_label_with_zero_weight (bool): allow missing label to have weight 0, instead of throwing exception. 
+ """ + + def __init__( + self, + *args: Any, + allow_missing_label_with_zero_weight: bool = False, + **kwargs: Any, + ) -> None: + self._allow_missing_label_with_zero_weight: bool = ( + allow_missing_label_with_zero_weight + ) + super().__init__(*args, **kwargs) + self._add_state( + "cross_entropy_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_num_samples", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "pos_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "neg_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self.eta = 1e-12 + + @pt2_compile_callable + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for UnweightedNEMetricComputation update. Weight will not be used for this metric." + ) + states = get_unweighted_ne_states( + labels, + predictions, + weights, + self.eta, + ) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + reports = [ + MetricComputationReport( + name=MetricName.UNWEIGHTED_NE, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_ne( + cast(torch.Tensor, self.cross_entropy_sum), + cast(torch.Tensor, self.weighted_num_samples), + cast(torch.Tensor, self.pos_labels), + cast(torch.Tensor, self.neg_labels), + self.eta, + self._allow_missing_label_with_zero_weight, + ), + ), + MetricComputationReport( + name=MetricName.UNWEIGHTED_NE, + metric_prefix=MetricPrefix.WINDOW, + value=compute_ne( + self.get_window_state("cross_entropy_sum"), + self.get_window_state("weighted_num_samples"), + self.get_window_state("pos_labels"), + self.get_window_state("neg_labels"), + self.eta, + self._allow_missing_label_with_zero_weight, + ), + ), + ] + return reports + + +class UnweightedNEMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.UNWEIGHTED_NE + _computation_class: Type[RecMetricComputation] = UnweightedNEMetricComputation diff --git a/torchrec/metrics/weighted_avg.py b/torchrec/metrics/weighted_avg.py new file mode 100644 index 000000000..4d8466a3a --- /dev/null +++ b/torchrec/metrics/weighted_avg.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, +) + + +def get_mean(value_sum: torch.Tensor, num_samples: torch.Tensor) -> torch.Tensor: + return value_sum / num_samples + + +class WeightedAvgMetricComputation(RecMetricComputation): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._add_state( + "weighted_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_num_samples", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + num_samples = labels.shape[0] + predictions = cast(torch.Tensor, predictions) + weights = cast(torch.Tensor, weights) + states = { + "weighted_sum": (predictions * weights).sum(dim=-1), + "weighted_num_samples": weights.sum(dim=-1), + } + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + return [ + MetricComputationReport( + name=MetricName.WEIGHTED_AVG, + metric_prefix=MetricPrefix.LIFETIME, + value=get_mean( + cast(torch.Tensor, self.weighted_sum), + cast(torch.Tensor, self.weighted_num_samples), + ), + ), + MetricComputationReport( + name=MetricName.WEIGHTED_AVG, + metric_prefix=MetricPrefix.WINDOW, + value=get_mean( + self.get_window_state("weighted_sum"), + self.get_window_state("weighted_num_samples"), + ), + ), + ] + + +class WeightedAvgMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.WEIGHTED_AVG + _computation_class: Type[RecMetricComputation] = WeightedAvgMetricComputation diff --git a/torchrec/metrics/xauc.py b/torchrec/metrics/xauc.py new file mode 100644 index 000000000..3e27adc0d --- /dev/null +++ b/torchrec/metrics/xauc.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
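# Minimal brute-force sketch of what the meshgrid-based helpers below compute:
# XAUC is the weighted fraction of example pairs whose predictions are ordered
# the same way as their labels (with ties counted when both are tied). The
# inputs reproduce the 0.6667 expectation from test_xauc.py above.
import torch

predictions = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
labels = torch.tensor([0.2, 0.1, 0.3, 0.5, 0.25])
weights = torch.tensor([1.0, 1.0, 1.0, 0.0, 1.0])

p, lab, w = predictions.tolist(), labels.tolist(), weights.tolist()
error_sum, num_pairs = 0.0, 0.0
for i in range(len(p)):
    for j in range(i + 1, len(p)):
        pair_weight = w[i] * w[j]
        agree = (p[i] - p[j]) * (lab[i] - lab[j]) > 0 or (
            p[i] == p[j] and lab[i] == lab[j]
        )
        error_sum += pair_weight * float(agree)
        num_pairs += pair_weight
xauc = error_sum / num_pairs  # 4 agreeing pairs / 6 weighted pairs, about 0.6667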
+ +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +ERROR_SUM = "error_sum" +WEIGHTED_NUM_PAIRS = "weighted_num_pairs" + + +def compute_xauc( + error_sum: torch.Tensor, weighted_num_pairs: torch.Tensor +) -> torch.Tensor: + return torch.where( + weighted_num_pairs == 0.0, 0.0, error_sum / weighted_num_pairs + ).double() + + +def compute_error_sum( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor +) -> torch.Tensor: + predictions = predictions.double() + + errors = [] + for predictions_i, labels_i, weights_i in zip(predictions, labels, weights): + preds_x, preds_y = torch.meshgrid(predictions_i, predictions_i) + labels_x, labels_y = torch.meshgrid(labels_i, labels_i) + weights_x, weights_y = torch.meshgrid(weights_i, weights_i) + weights_flag = weights_x * weights_y + match = torch.logical_or( + torch.logical_and(preds_x > preds_y, labels_x > labels_y), + torch.logical_and(preds_x < preds_y, labels_x < labels_y), + ) + match = ( + weights_flag + * torch.logical_or( + match, torch.logical_and(preds_x == preds_y, labels_x == labels_y) + ).double() + ) + errors.append(torch.sum(torch.triu(match, diagonal=1)).view(1)) + + return torch.cat(errors) + + +def compute_weighted_num_pairs(weights: torch.Tensor) -> torch.Tensor: + num_pairs = [] + for weight_i in weights: + weights_x, weights_y = torch.meshgrid(weight_i, weight_i) + weights_flag = weights_x * weights_y + num_pairs.append(torch.sum(torch.triu(weights_flag, diagonal=1)).view(1)) + + return torch.cat(num_pairs) + + +def get_xauc_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor +) -> Dict[str, torch.Tensor]: + return { + "error_sum": compute_error_sum(labels, predictions, weights), + "weighted_num_pairs": compute_weighted_num_pairs(weights), + } + + +class XAUCMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for XAUC. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. 
+ """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._add_state( + "error_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_num_pairs", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for XAUCMetricComputation update" + ) + states = get_xauc_states(labels, predictions, weights) + num_samples = predictions.shape[-1] + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + return [ + MetricComputationReport( + name=MetricName.XAUC, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_xauc( + cast(torch.Tensor, self.error_sum), + cast(torch.Tensor, self.weighted_num_pairs), + ), + ), + MetricComputationReport( + name=MetricName.XAUC, + metric_prefix=MetricPrefix.WINDOW, + value=compute_xauc( + self.get_window_state(ERROR_SUM), + self.get_window_state(WEIGHTED_NUM_PAIRS), + ), + ), + ] + + +class XAUCMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.XAUC + _computation_class: Type[RecMetricComputation] = XAUCMetricComputation diff --git a/torchrec/models/__init__.py b/torchrec/models/__init__.py index 13b49bcfb..f0209b073 100644 --- a/torchrec/models/__init__.py +++ b/torchrec/models/__init__.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + """Torchrec Models Torchrec provides the architecture for two popular recsys models; @@ -24,4 +26,4 @@ * num_features: number of dense features """ -from . import deepfm, dlrm # noqa # noqa +from torchrec.models import deepfm, dlrm # noqa diff --git a/torchrec/models/deepfm.py b/torchrec/models/deepfm.py index 683e79246..322992f89 100644 --- a/torchrec/models/deepfm.py +++ b/torchrec/models/deepfm.py @@ -5,13 +5,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + from typing import List import torch from torch import nn -from torchrec import EmbeddingBagCollection, KeyedJaggedTensor from torchrec.modules.deepfm import DeepFM, FactorizationMachine -from torchrec.sparse.jagged_tensor import KeyedTensor +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor class SparseArch(nn.Module): @@ -82,8 +84,12 @@ class DenseArch(nn.Module): B = 20 D = 3 in_features = 10 - dense_arch = DenseArch(in_features=10, hidden_layer_size=10, embedding_dim=D) - dense_embedded = dense_arch(torch.rand((B, 10))) + dense_arch = DenseArch( + in_features=in_features, hidden_layer_size=10, embedding_dim=D + ) + + dense_arch_input = torch.rand((B, in_features)) + dense_embedded = dense_arch(dense_arch_input) """ def __init__( diff --git a/torchrec/models/dlrm.py b/torchrec/models/dlrm.py index b415e8904..ad1975eba 100644 --- a/torchrec/models/dlrm.py +++ b/torchrec/models/dlrm.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Dict, List, Optional, Tuple import torch @@ -50,7 +52,7 @@ class SparseArch(nn.Module): name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] ) - ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + embedding_bag_collection = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) sparse_arch = SparseArch(embedding_bag_collection) # 0 1 2 <-- batch @@ -98,14 +100,12 @@ def forward( sparse_features: KeyedTensor = self.embedding_bag_collection(features) - B: int = features.stride() - sparse: Dict[str, torch.Tensor] = sparse_features.to_dict() sparse_values: List[torch.Tensor] = [] for name in self.sparse_feature_names: sparse_values.append(sparse[name]) - return torch.cat(sparse_values, dim=1).reshape(B, self.F, self.D) + return torch.cat(sparse_values, dim=1).reshape(-1, self.F, self.D) @property def sparse_feature_names(self) -> List[str]: @@ -185,8 +185,10 @@ class InteractionArch(nn.Module): def __init__(self, num_sparse_features: int) -> None: super().__init__() self.F: int = num_sparse_features - self.triu_indices: torch.Tensor = torch.triu_indices( - self.F + 1, self.F + 1, offset=1 + self.register_buffer( + "triu_indices", + torch.triu_indices(self.F + 1, self.F + 1, offset=1), + persistent=False, ) def forward( @@ -470,21 +472,21 @@ class DLRM(nn.Module): D = 8 eb1_config = EmbeddingBagConfig( - name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"] + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"] ) eb2_config = EmbeddingBagConfig( - name="t2", - embedding_dim=D, - num_embeddings=100, - feature_names=["f2"], + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], ) ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) model = DLRM( - embedding_bag_collection=ebc, - dense_in_features=100, - dense_arch_layer_sizes=[20], - over_arch_layer_sizes=[5, 1], + embedding_bag_collection=ebc, + dense_in_features=100, + dense_arch_layer_sizes=[20, D], + over_arch_layer_sizes=[5, 1], ) features = torch.rand((B, 100)) @@ -495,14 +497,14 @@ class DLRM(nn.Module): # ^ # feature sparse_features = KeyedJaggedTensor.from_offsets_sync( - keys=["f1", "f3"], - values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]), - offsets=torch.tensor([0, 2, 4, 6, 8]), + keys=["f1", "f2"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]), + 
offsets=torch.tensor([0, 2, 4, 6, 8]), ) logits = model( - dense_features=features, - sparse_features=sparse_features, + dense_features=features, + sparse_features=sparse_features, ) """ diff --git a/torchrec/models/experimental/__init__.py b/torchrec/models/experimental/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/models/experimental/test_transformerdlrm.py b/torchrec/models/experimental/test_transformerdlrm.py new file mode 100644 index 000000000..07f49ecf2 --- /dev/null +++ b/torchrec/models/experimental/test_transformerdlrm.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +import torch +from torchrec.datasets.utils import Batch +from torchrec.models.dlrm import DLRMTrain +from torchrec.models.experimental.transformerdlrm import ( + DLRM_Transformer, + InteractionTransformerArch, +) +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class InteractionArchTransformerTest(unittest.TestCase): + def test_basic(self) -> None: + D = 8 + B = 10 + # multi-head attentions + nhead = 8 + ntransformer_layers = 4 + keys = ["f1", "f2"] + F = len(keys) + inter_arch = InteractionTransformerArch( + num_sparse_features=F, + embedding_dim=D, + nhead=nhead, + ntransformer_layers=ntransformer_layers, + ) + dense_features = torch.rand((B, D)) + sparse_features = torch.rand((B, F, D)) + concat_dense = inter_arch(dense_features, sparse_features) + # B X (D + F + F choose 2) + self.assertEqual(concat_dense.size(), (B, D * (F + 1))) + + def test_larger(self) -> None: + D = 16 + B = 20 + # multi-head attentions + nhead = 8 + ntransformer_layers = 4 + keys = ["f1", "f2", "f3", "f4"] + F = len(keys) + inter_arch = InteractionTransformerArch( + num_sparse_features=F, + embedding_dim=D, + nhead=nhead, + ntransformer_layers=ntransformer_layers, + ) + dense_features = torch.rand((B, D)) + sparse_features = torch.rand((B, F, D)) + concat_dense = inter_arch(dense_features, sparse_features) + self.assertEqual(concat_dense.size(), (B, D * (F + 1))) + + def test_correctness_disabled_in_oss_compatibility(self) -> None: + D = 4 + B = 3 + # multi-head attentions + nhead = 4 + ntransformer_layers = 4 + keys = [ + "f1", + "f2", + "f3", + "f4", + ] + F = len(keys) + # place the manual_seed before the InteractionTransformerArch object to generate the same initialization random values in the Transformer + torch.manual_seed(0) + inter_arch = InteractionTransformerArch( + num_sparse_features=F, + embedding_dim=D, + nhead=nhead, + ntransformer_layers=ntransformer_layers, + ) + dense_features = torch.rand((B, D)) + sparse_features = torch.rand((B, F, D)) + concat_dense = inter_arch(dense_features, sparse_features) + self.assertEqual(concat_dense.size(), (B, D * (F + 1))) + expected = torch.tensor( + [ + [ + -0.4411, + 0.2487, + -1.2685, + 1.4610, + 1.3110, + 0.5152, + -0.4960, + -1.3303, + -0.3962, + -0.0623, + -1.1371, + 1.5956, + 0.2431, + -1.6820, + 0.5242, + 0.9148, + 1.3033, + 0.6409, + -0.9577, + -0.9866, + ], + [ + -1.0850, + -0.0366, + -0.4862, + 1.6078, + 1.1254, + -0.9989, + -0.9927, + 0.8661, + -0.1704, + 1.0223, + -1.5580, + 0.7060, + -0.3081, + -1.3686, + 0.2788, + 1.3979, + 0.0328, + 1.5470, + 
-0.3670, + -1.2128, + ], + [ + -1.5917, + -0.0995, + 0.7302, + 0.9609, + 0.6606, + 1.0238, + -0.1017, + -1.5827, + -0.6761, + -1.0771, + 0.2262, + 1.5269, + -0.5671, + -1.2114, + 1.4503, + 0.3281, + -0.6540, + -1.2925, + 0.9134, + 1.0331, + ], + ] + ) + self.assertTrue( + torch.allclose( + concat_dense, + expected, + rtol=1e-4, + atol=1e-4, + ) + ) + + def test_numerical_stability_disabled_in_oss_compatibility(self) -> None: + D = 4 + B = 3 + # multi-head attentions + nhead = 4 + ntransformer_layers = 4 + keys = ["f1", "f2"] + F = len(keys) + torch.manual_seed(0) + inter_arch = InteractionTransformerArch( + num_sparse_features=F, + embedding_dim=D, + nhead=nhead, + ntransformer_layers=ntransformer_layers, + ) + dense_features = 10 * torch.rand(B, D) + sparse_features = 10 * torch.rand(B, F, D) + concat_dense = inter_arch(dense_features, sparse_features) + expected = torch.LongTensor( + [ + [0, 1, -1, 0, 0, 0, 0, -1, 0, 0, -1, 1], + [0, 0, 0, 1, -1, 0, 1, 0, 0, 0, 1, 0], + [-1, 0, 0, 0, 0, 0, -1, 1, 0, 1, -1, 0], + ] + ) + self.assertTrue(torch.equal(concat_dense.long(), expected)) + + +class DLRMTransformerTest(unittest.TestCase): + def test_basic_disabled_in_oss_compatibility(self) -> None: + torch.manual_seed(0) + B = 2 + D = 8 + dense_in_features = 100 + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + sparse_nn = DLRM_Transformer( + embedding_bag_collection=ebc, + dense_in_features=dense_in_features, + dense_arch_layer_sizes=[20, D], + over_arch_layer_sizes=[5, 1], + ) + features = torch.rand((B, dense_in_features)) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f3", "f2"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]), + offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]), + ) + logits = sparse_nn( + dense_features=features, + sparse_features=sparse_features, + ) + self.assertEqual(logits.size(), (B, 1)) + expected_logits = torch.tensor([[-0.2593], [-0.1043]]) + self.assertTrue( + torch.allclose( + logits, + expected_logits, + rtol=1e-4, + atol=1e-4, + ) + ) + + def test_one_sparse(self) -> None: + B = 2 + D = 8 + dense_in_features = 100 + eb1_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + ebc = EmbeddingBagCollection(tables=[eb1_config]) + sparse_nn = DLRM_Transformer( + embedding_bag_collection=ebc, + dense_in_features=dense_in_features, + dense_arch_layer_sizes=[20, D], + over_arch_layer_sizes=[5, 1], + ) + features = torch.rand((B, dense_in_features)) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f2"], + values=torch.tensor(range(3)), + offsets=torch.tensor([0, 2, 3]), + ) + logits = sparse_nn( + dense_features=features, + sparse_features=sparse_features, + ) + self.assertEqual(logits.size(), (B, 1)) + + def test_no_sparse(self) -> None: + ebc = EmbeddingBagCollection(tables=[]) + D_unused = 1 + with self.assertRaises(AssertionError): + DLRM_Transformer( + embedding_bag_collection=ebc, + dense_in_features=100, + dense_arch_layer_sizes=[20, D_unused], + over_arch_layer_sizes=[5, 1], + ) + + +class DLRMTransformerTrainTest(unittest.TestCase): + def test_basic(self) -> None: + B = 2 + D = 8 + dense_in_features = 100 + eb1_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], 
+ ) + ebc = EmbeddingBagCollection(tables=[eb1_config]) + dlrm_module = DLRM_Transformer( + embedding_bag_collection=ebc, + dense_in_features=dense_in_features, + dense_arch_layer_sizes=[20, D], + over_arch_layer_sizes=[5, 1], + ) + dlrm = DLRMTrain(dlrm_module) + features = torch.rand((B, dense_in_features)) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f2"], + values=torch.tensor(range(3)), + offsets=torch.tensor([0, 2, 3]), + ) + batch = Batch( + dense_features=features, + sparse_features=sparse_features, + labels=torch.randint(2, (B,)), + ) + _, (_, logits, _) = dlrm(batch) + self.assertEqual(logits.size(), (B,)) diff --git a/torchrec/models/experimental/transformerdlrm.py b/torchrec/models/experimental/transformerdlrm.py new file mode 100644 index 000000000..9bbec1ae0 --- /dev/null +++ b/torchrec/models/experimental/transformerdlrm.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import List, Optional + +import torch +from torch import nn +from torchrec.models.dlrm import DLRM, OverArch +from torchrec.modules.embedding_modules import EmbeddingBagCollection + + +class InteractionTransformerArch(nn.Module): + """ + Processes the output of both `SparseArch` (sparse_features) and `DenseArch` + (dense_features). Returns the output of the nn.transformerencoder, + that takes the combined values of both sparse features and the output of the dense layer, + and the dense layer itself (i.e. concat(dense layer output, transformer encoder output). + Note: This model is for benchmarking purposes only, i.e. to measure the performance of transformer + embeddings using the dlrm models. + It is not intended to increase model convergence metrics. + Implemented TE as described here: + https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html?highlight=transformer+encoder#torch.nn.TransformerEncoder + BERT Transformer Paper: https://arxiv.org/abs/1810.04805 + Attention is All you Need: https://arxiv.org/abs/1706.03762 + + + .. note:: + The dimensionality of the `dense_features` (D) is expected to match the + dimensionality of the `sparse_features` so that the dot products between them + can be computed. + Args: + num_sparse_features (int): F. + embedding_dim: int, + nhead: int, #number of attention heads + ntransformer_layers: int, #number of transformer layers. 
+ Example:: + D = 8 #must divisible by number of transformer heads + B = 10 + keys = ["f1", "f2"] + F = len(keys) + inter_arch = InteractionTransormerArch(num_sparse_features=len(keys)) + dense_features = torch.rand((B, D)) + sparse_features = torch.rand((B, F, D)) + # B X (D * (F + 1)) + concat_dense = inter_arch(dense_features, sparse_features) + """ + + def __init__( + self, + num_sparse_features: int, + embedding_dim: int, + nhead: int = 8, + ntransformer_layers: int = 4, + ) -> None: + super().__init__() + self.F: int = num_sparse_features + self.nhead = nhead + self.ntransformer_layers = ntransformer_layers + transformer_encoder_layer = nn.TransformerEncoderLayer( + d_model=embedding_dim, + nhead=self.nhead, + ) + self.interarch_TE = nn.TransformerEncoder( + transformer_encoder_layer, num_layers=self.ntransformer_layers + ) + + def forward( + self, dense_features: torch.Tensor, sparse_features: torch.Tensor + ) -> torch.Tensor: + """ + Args: + dense_features (torch.Tensor): an input tensor of size B X D. + sparse_features (torch.Tensor): an input tensor of size B X F X D. + Returns: + torch.Tensor: an output tensor of size B X (D + F + F choose 2). + """ + if self.F <= 0: + return dense_features + (B, D) = dense_features.shape + combined_values = torch.cat( + (dense_features.unsqueeze(1), sparse_features), dim=1 + ) + # Transformer for Interactions + transformer_interactions = self.interarch_TE(combined_values) + interactions_flat = torch.reshape(transformer_interactions, (B, -1)) + return interactions_flat + + +class DLRM_Transformer(DLRM): + """ + Recsys model from "Deep Learning Recommendation Model for Personalization and + Recommendation Systems" (https://arxiv.org/abs/1906.00091). Processes sparse + features by learning pooled embeddings for each feature. On the interaction layer, + the relationship between dense features and sparse features is learned through a transformer encoder layer + https://arxiv.org/abs/1706.03762. + The module assumes all sparse features have the same embedding dimension + (i.e. each EmbeddingBagConfig uses the same embedding_dim). + The following notation is used throughout the documentation for the models: + * F: number of sparse features + * D: embedding_dimension of sparse features + * B: batch size + * num_features: number of dense features + Args: + embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags + used to define `SparseArch`. + dense_in_features (int): the dimensionality of the dense input features. + dense_arch_layer_sizes (List[int]): the layer sizes for the `DenseArch`. + over_arch_layer_sizes (List[int]): the layer sizes for the `OverArch`. + The output dimension of the `InteractionArch` should not be manually + specified here. + nhead: int: Number of multi-attention heads + ntransformer_layers: int: Number of transformer encoder layers + dense_device (Optional[torch.device]): default compute device. 
+ Example:: + B = 2 + D = 8 + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + model = DLRM_Transformer( + embedding_bag_collection=ebc, + dense_in_features=100, + dense_arch_layer_sizes=[20], + over_arch_layer_sizes=[5, 1], + ) + features = torch.rand((B, 100)) + # 0 1 + # 0 [1,2] [4,5] + # 1 [4,3] [2,9] + # ^ + # feature + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f3"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]), + offsets=torch.tensor([0, 2, 4, 6, 8]), + ) + logits = model( + dense_features=features, + sparse_features=sparse_features, + ) + """ + + def __init__( + self, + embedding_bag_collection: EmbeddingBagCollection, + dense_in_features: int, + dense_arch_layer_sizes: List[int], + over_arch_layer_sizes: List[int], + nhead: int = 8, + ntransformer_layers: int = 4, + dense_device: Optional[torch.device] = None, + ) -> None: + # initialize DLRM + # sparse arch and dense arch are initialized via DLRM + super().__init__( + embedding_bag_collection, + dense_in_features, + dense_arch_layer_sizes, + over_arch_layer_sizes, + dense_device, + ) + embedding_dim: int = embedding_bag_collection.embedding_bag_configs()[ + 0 + ].embedding_dim + num_sparse_features: int = len(self.sparse_arch.sparse_feature_names) + self.inter_arch = InteractionTransformerArch( + num_sparse_features=num_sparse_features, + embedding_dim=embedding_dim, + nhead=nhead, + ntransformer_layers=ntransformer_layers, + ) + over_in_features: int = (num_sparse_features + 1) * embedding_dim + self.over_arch = OverArch( + in_features=over_in_features, + layer_sizes=over_arch_layer_sizes, + device=dense_device, + ) diff --git a/torchrec/models/tests/test_deepfm.py b/torchrec/models/tests/test_deepfm.py index c3f4cd37b..154b4c1c3 100644 --- a/torchrec/models/tests/test_deepfm.py +++ b/torchrec/models/tests/test_deepfm.py @@ -5,17 +5,40 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
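# Minimal sketch, with made-up sizes, of why DLRM_Transformer sizes its over
# arch input as (num_sparse_features + 1) * embedding_dim: the interaction arch
# concatenates the dense embedding with the F sparse embeddings into F + 1
# tokens of width D, runs them through the transformer encoder (which preserves
# the shape), and flattens the result per example.
import torch
from torch import nn

B, F, D, nhead = 2, 3, 8, 2
dense = torch.rand(B, D)
sparse = torch.rand(B, F, D)
combined = torch.cat((dense.unsqueeze(1), sparse), dim=1)  # (B, F + 1, D)
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=D, nhead=nhead), num_layers=1
)
flattened = encoder(combined).reshape(B, -1)  # (B, (F + 1) * D)
assert flattened.shape == (B, (F + 1) * D)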
+# pyre-strict + import unittest import torch from torch.testing import FileCheck # @manual from torchrec.fx import symbolic_trace, Tracer -from torchrec.models.deepfm import FMInteractionArch, SimpleDeepFMNN +from torchrec.models.deepfm import DenseArch, FMInteractionArch, SimpleDeepFMNN from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor +class DenseArchTest(unittest.TestCase): + def test_basic(self) -> None: + torch.manual_seed(0) + + B = 20 + D = 3 + in_features = 10 + dense_arch = DenseArch( + in_features=in_features, hidden_layer_size=10, embedding_dim=D + ) + + dense_arch_input = torch.rand((B, in_features)) + dense_embedded = dense_arch(dense_arch_input) + self.assertEqual(dense_embedded.size(), (B, D)) + + # check tracer compatibility + gm = torch.fx.GraphModule(dense_arch, Tracer().trace(dense_arch)) + script = torch.jit.script(gm) + script(dense_arch_input) + + class FMInteractionArchTest(unittest.TestCase): def test_basic(self) -> None: torch.manual_seed(0) @@ -129,7 +152,7 @@ def test_fx(self) -> None: deep_fm_dimension=5, ) gm = symbolic_trace(deepfm_nn) - FileCheck().check("KeyedJaggedTensor").check("cat").check("f2").run(gm.code) + FileCheck().check("KeyedJaggedTensor").check("f2").run(gm.code) features = torch.rand((B, num_dense_features)) sparse_features = KeyedJaggedTensor.from_offsets_sync( diff --git a/torchrec/models/tests/test_dlrm.py b/torchrec/models/tests/test_dlrm.py index cfb5a3dff..e01976404 100644 --- a/torchrec/models/tests/test_dlrm.py +++ b/torchrec/models/tests/test_dlrm.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import unittest import torch @@ -12,6 +14,8 @@ from torch.testing import FileCheck # @manual from torchrec.datasets.utils import Batch from torchrec.fx import symbolic_trace +from torchrec.ir.serializer import JsonSerializer +from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules from torchrec.models.dlrm import ( choose, DenseArch, @@ -66,14 +70,14 @@ def test_basic(self) -> None: expected_values = torch.tensor( [ [ - [-0.7499, -1.2665, 1.0143], - [-0.7499, -1.2665, 1.0143], - [3.2276, 2.9643, -0.3816], + [-0.4518, -0.0242, 0.3637], + [-0.4518, -0.0242, 0.3637], + [0.2940, -0.2385, -0.0074], ], [ - [0.0082, 0.6241, -0.1119], - [0.0082, 0.6241, -0.1119], - [2.0722, -2.2734, -1.6307], + [-0.5452, -0.0231, -0.1907], + [-0.5452, -0.0231, -0.1907], + [-0.3530, -0.5551, -0.3342], ], ] ) @@ -396,7 +400,7 @@ def test_basic(self) -> None: ) self.assertEqual(logits.size(), (B, 1)) - expected_logits = torch.tensor([[0.5805], [0.5909]]) + expected_logits = torch.tensor([[-0.2593], [-0.2487]]) self.assertTrue( torch.allclose( logits, @@ -470,7 +474,7 @@ def test_fx(self) -> None: over_arch_layer_sizes=[5, 1], ) gm = symbolic_trace(sparse_nn) - FileCheck().check("KeyedJaggedTensor").check("cat").check("f2").run(gm.code) + FileCheck().check("KeyedJaggedTensor").check("f2").run(gm.code) features = torch.rand((B, dense_in_features)) sparse_features = KeyedJaggedTensor.from_offsets_sync( @@ -801,7 +805,7 @@ def test_basic(self) -> None: ) self.assertEqual(logits.size(), (B, 1)) - expected_logits = torch.tensor([[-0.0036], [-0.0260]]) + expected_logits = torch.tensor([[-0.4603], [-0.4639]]) self.assertTrue( torch.allclose( logits, @@ -1097,7 +1101,7 @@ def test_basic(self) -> None: ) self.assertEqual(logits.size(), (B, 1)) - expected_logits = torch.tensor([[1.5232], [0.1726]]) + expected_logits = torch.tensor([[0.0455], [0.0408]]) self.assertTrue( torch.allclose( logits, @@ -1172,3 +1176,110 @@ def test_basic(self) -> None: _, (_, logits, _) = dlrm(batch) self.assertEqual(logits.size(), (B,)) + + +class DLRMExampleTest(unittest.TestCase): + def test_basic(self) -> None: + B = 2 + D = 8 + + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + model = DLRM( + embedding_bag_collection=ebc, + dense_in_features=100, + dense_arch_layer_sizes=[20, D], + over_arch_layer_sizes=[5, 1], + ) + + features = torch.rand((B, 100)) + + # 0 1 + # 0 [1,2] [4,5] + # 1 [4,3] [2,9] + # ^ + # feature + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]), + offsets=torch.tensor([0, 2, 4, 6, 8]), + ) + + logits = model( + dense_features=features, + sparse_features=sparse_features, + ) + + self.assertEqual(logits.size(), (B, 1)) + + def test_export_serialization(self) -> None: + B = 2 + D = 8 + + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + model = DLRM( + embedding_bag_collection=ebc, + dense_in_features=100, + dense_arch_layer_sizes=[20, D], + over_arch_layer_sizes=[5, 1], + ) + + features = torch.rand((B, 100)) + + # 0 1 + # 0 [1,2] [4,5] + # 
1 [4,3] [2,9] + # ^ + # feature + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]), + offsets=torch.tensor([0, 2, 4, 6, 8]), + ) + + logits = model( + dense_features=features, + sparse_features=sparse_features, + ) + + self.assertEqual(logits.size(), (B, 1)) + + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) + + ep = torch.export.export( + model, + (features, sparse_features), + {}, + strict=False, + # Allows KJT to not be unflattened and run a forward on unflattened EP + preserve_module_call_signature=(tuple(sparse_fqns)), + ) + + # Run forward on ExportedProgram + ep_output = ep.module()(features, sparse_features) + self.assertEqual(ep_output.size(), (B, 1)) + + unflatten_model = torch.export.unflatten(ep) + deserialized_model = decapsulate_ir_modules(unflatten_model, JsonSerializer) + deserialized_logits = deserialized_model(features, sparse_features) + + self.assertEqual(deserialized_logits.size(), (B, 1)) diff --git a/torchrec/modules/__init__.py b/torchrec/modules/__init__.py index 8bb52e5dc..c1852cee7 100644 --- a/torchrec/modules/__init__.py +++ b/torchrec/modules/__init__.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + """Torchrec Common Modules The torchrec modules contain a collection of various modules. diff --git a/torchrec/modules/activation.py b/torchrec/modules/activation.py index 632f648d4..6a541948d 100644 --- a/torchrec/modules/activation.py +++ b/torchrec/modules/activation.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + """ Activation Modules """ diff --git a/torchrec/modules/crossnet.py b/torchrec/modules/crossnet.py index 9c2a8d52e..43771be5d 100644 --- a/torchrec/modules/crossnet.py +++ b/torchrec/modules/crossnet.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + # Sphinx Documentation Text (for user-facing classes only) """ @@ -28,10 +30,10 @@ class CrossNet(torch.nn.Module): (NxN), such that the crossing effect can cover all bits on each layer. On each layer l, the tensor is transformed into: - .. math :: x_{l+1} = x_0 * (W_l \dot x_l + b_l) + x_l + .. math :: x_{l+1} = x_0 * (W_l \cdot x_l + b_l) + x_l where :math:`W_l` is a square matrix :math:`(NxN)`, :math:`*` means element-wise - multiplication, :math:`\dot` means matrix multiplication. + multiplication, :math:`\cdot` means matrix multiplication. Args: in_features (int): the dimension of the input. @@ -81,9 +83,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: x_l = x_0 for layer in range(self._num_layers): - # pyre-ignore[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. xl_w = torch.matmul(self.kernels[layer], x_l) # (B, N, 1) - # pyre-ignore[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. x_l = x_0 * (xl_w + self.bias[layer]) + x_l # (B, N, 1) return torch.squeeze(x_l, dim=2) @@ -92,15 +94,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class LowRankCrossNet(torch.nn.Module): r""" Low Rank Cross Net is a highly efficient cross net. 
Instead of using full rank cross - matrices (NxN) at each layer, it will use two kernels :math:`W (N * r)` and - :math:`V (r * N)`, where `r << N`, to simplify the matrix multiplication. + matrices (NxN) at each layer, it will use two kernels :math:`W (N x r)` and + :math:`V (r x N)`, where `r << N`, to simplify the matrix multiplication. On each layer l, the tensor is transformed into: - .. math:: x_{l+1} = x_0 * (W_l \dot (V_l \dot x_l) + b_l) + x_l + .. math:: x_{l+1} = x_0 * (W_l \cdot (V_l \cdot x_l) + b_l) + x_l where :math:`W_l` is either a vector, :math:`*` means element-wise multiplication, - and :math:`\dot` means matrix multiplication. + and :math:`\cdot` means matrix multiplication. NOTE: Rank `r` should be chosen smartly. Usually, we expect `r < N/2` to have @@ -110,8 +112,8 @@ class LowRankCrossNet(torch.nn.Module): Args: in_features (int): the dimension of the input. num_layers (int): the number of layers in the module. - low_rank (int): the rank setup of the cross matrix (default = 0). - Value must be always >= 0. + low_rank (int): the rank setup of the cross matrix (default = 1). + Value must be always >= 1. Example:: @@ -134,7 +136,7 @@ def __init__( self._num_layers = num_layers self._low_rank = low_rank - self.W_kernels: torch.nn.Module = torch.nn.ParameterList( + self.W_kernels: torch.nn.ParameterList = torch.nn.ParameterList( [ torch.nn.Parameter( torch.nn.init.xavier_normal_( @@ -144,7 +146,7 @@ def __init__( for i in range(self._num_layers) ] ) - self.V_kernels: torch.nn.Module = torch.nn.ParameterList( + self.V_kernels: torch.nn.ParameterList = torch.nn.ParameterList( [ torch.nn.Parameter( torch.nn.init.xavier_normal_( @@ -154,9 +156,9 @@ def __init__( for i in range(self._num_layers) ] ) - self.bias: torch.nn.Module = torch.nn.ParameterList( + self.bias: torch.nn.ParameterList = torch.nn.ParameterList( [ - torch.nn.Parameter(torch.nn.init.zeros_(torch.empty(in_features, 1))) + torch.nn.Parameter(torch.nn.init.zeros_(torch.empty(in_features))) for i in range(self._num_layers) ] ) @@ -170,22 +172,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: tensor with shape [batch_size, in_features]. """ - x_0 = input.unsqueeze(2) # (B, N, 1) + x_0 = input x_l = x_0 for layer in range(self._num_layers): - xl_w = torch.matmul( - # pyre-ignore[29]: `Union[torch.Tensor, torch.nn.Module]` is not a - # function. - self.W_kernels[layer], - # pyre-ignore[29]: `Union[torch.Tensor, torch.nn.Module]` is not a - # function. - torch.matmul(self.V_kernels[layer], x_l), - ) # (B, N, 1) - # pyre-ignore[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. - x_l = x_0 * (xl_w + self.bias[layer]) + x_l # (B, N, 1) + x_l_v = torch.nn.functional.linear(x_l, self.V_kernels[layer]) + x_l_w = torch.nn.functional.linear(x_l_v, self.W_kernels[layer]) + x_l = x_0 * (x_l_w + self.bias[layer]) + x_l # (B, N) - return torch.squeeze(x_l, dim=2) # (B, N) + return x_l class VectorCrossNet(torch.nn.Module): @@ -257,12 +252,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: for layer in range(self._num_layers): xl_w = torch.tensordot( x_l, - # pyre-ignore[29]: `Union[torch.Tensor, torch.nn.Module]` is not a - # function. + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. self.kernels[layer], dims=([1], [0]), ) # (B, 1, 1) - # pyre-ignore[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. 
x_l = torch.matmul(x_0, xl_w) + self.bias[layer] + x_l # (B, N, 1) return torch.squeeze(x_l, dim=2) # (B, N) @@ -286,7 +280,7 @@ class LowRankMixtureCrossNet(torch.nn.Module): and each :math:`expert_i` is defined as: - .. math:: expert_i = x_0 * (U_{li} \dot g(C_{li} \dot g(V_{li} \dot x_l)) + b_l) + .. math:: expert_i = x_0 * (U_{li} \cdot g(C_{li} \cdot g(V_{li} \cdot x_l)) + b_l) where :math:`U_{li} (N, r)`, :math:`C_{li} (r, r)` and :math:`V_{li} (r, N)` are low-rank matrices, :math:`*` means element-wise multiplication, :math:`x` means @@ -298,8 +292,8 @@ class LowRankMixtureCrossNet(torch.nn.Module): Args: in_features (int): the dimension of the input. num_layers (int): the number of layers in the module. - low_rank (int): the rank setup of the cross matrix (default = 0). - Value must be always >= 0 + low_rank (int): the rank setup of the cross matrix (default = 1). + Value must be always >= 1 activation (Union[torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]): the non-linear activation function, used in defining experts. Default is relu. @@ -413,21 +407,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: experts = [] for i in range(self._num_experts): expert = torch.matmul( - # pyre-ignore[29] + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. self.V_kernels[layer][i], x_l, ) # (B, r, 1) expert = torch.matmul( - # pyre-ignore[29] + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. self.C_kernels[layer][i], self._activation(expert), ) # (B, r, 1) expert = torch.matmul( - # pyre-ignore[29] + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. self.U_kernels[layer][i], self._activation(expert), ) # (B, N, 1) - # pyre-ignore[29] + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. expert = x_0 * (expert + self.bias[layer]) # (B, N, 1) experts.append(expert.squeeze(2)) # (B, N) experts = torch.stack(experts, 2) # (B, N, K) diff --git a/torchrec/modules/deepfm.py b/torchrec/modules/deepfm.py index 1792bfa45..b99b15e3b 100644 --- a/torchrec/modules/deepfm.py +++ b/torchrec/modules/deepfm.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + """ Deep Factorization-Machine Modules diff --git a/torchrec/modules/embedding_configs.py b/torchrec/modules/embedding_configs.py index 1a088e903..b665257a8 100644 --- a/torchrec/modules/embedding_configs.py +++ b/torchrec/modules/embedding_configs.py @@ -5,46 +5,64 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from dataclasses import dataclass, field from enum import Enum, unique +from functools import partial from math import sqrt -from typing import Dict, List, Optional +from typing import Callable, Dict, List, NamedTuple, Optional import torch from fbgemm_gpu.split_embedding_configs import SparseType -from fbgemm_gpu.split_table_batched_embeddings_ops import PoolingMode +from fbgemm_gpu.split_table_batched_embeddings_ops_training import PoolingMode +from torchrec.types import DataType @unique class PoolingType(Enum): + """ + Pooling type for embedding table. + + Args: + SUM (str): sum pooling. + MEAN (str): mean pooling. + NONE (str): no pooling. 
+ """ + SUM = "SUM" MEAN = "MEAN" NONE = "NONE" -@unique -class DataType(Enum): +# TODO - duplicated, move elsewhere to remove circular dependencies +class ShardingType(Enum): """ - Our fusion implementation supports only certain types of data - so it makes sense to retrict in a non-fused version as well. + Well-known sharding types, used by inter-module optimizations. """ - FP32 = "FP32" - FP16 = "FP16" - INT64 = "INT64" - INT32 = "INT32" - INT8 = "INT8" - INT4 = "INT4" - INT2 = "INT2" - - def __str__(self) -> str: - return self.value + # Replicated on all ranks + DATA_PARALLEL = "data_parallel" + # Placed on a single rank + TABLE_WISE = "table_wise" + # Placed on multiple ranks as different sharded tables + COLUMN_WISE = "column_wise" + # Range-split on the first dimension across all ranks + ROW_WISE = "row_wise" + # Row-wise on the same node and table-wise across nodes + # Useful when having multiple ranks per node + # and comms within a single node are more efficient than across nodes. + TABLE_ROW_WISE = "table_row_wise" + # Column-wise on the same node and table-wise across nodes + TABLE_COLUMN_WISE = "table_column_wise" DATA_TYPE_NUM_BITS: Dict[DataType, int] = { DataType.FP32: 32, DataType.FP16: 16, + DataType.BF16: 16, DataType.INT8: 8, + DataType.UINT8: 8, DataType.INT4: 4, DataType.INT2: 2, } @@ -55,12 +73,16 @@ def dtype_to_data_type(dtype: torch.dtype) -> DataType: return DataType.FP32 elif dtype == torch.float16 or dtype == torch.half: return DataType.FP16 + elif dtype == torch.bfloat16: + return DataType.BF16 elif dtype in {torch.int, torch.int32}: return DataType.INT32 elif dtype in {torch.long, torch.int64}: return DataType.INT64 - elif dtype in {torch.quint8, torch.qint8, torch.int8, torch.uint8}: + elif dtype in {torch.quint8, torch.qint8, torch.int8}: return DataType.INT8 + elif dtype == torch.uint8: + return DataType.UINT8 elif dtype == torch.quint4x2: return DataType.INT4 elif dtype == torch.quint2x4: @@ -69,21 +91,30 @@ def dtype_to_data_type(dtype: torch.dtype) -> DataType: raise Exception(f"Invalid data type {dtype}") -def pooling_type_to_pooling_mode(pooling_type: PoolingType) -> PoolingMode: - if pooling_type == PoolingType.SUM: +def pooling_type_to_pooling_mode( + pooling_type: PoolingType, sharding_type: Optional[ShardingType] = None +) -> PoolingMode: + if pooling_type.value == PoolingType.SUM.value: return PoolingMode.SUM - elif pooling_type == PoolingType.MEAN: + elif pooling_type.value == PoolingType.MEAN.value: + if sharding_type is not None and sharding_type.value in [ + ShardingType.TABLE_ROW_WISE.value, + ShardingType.ROW_WISE.value, + ]: + # Mean pooling is not supported in TBE for TWRW/RW sharding. + # Pass 'SUM' as a workaround, and apply mean pooling as a callback in EBC. 
+ return PoolingMode.SUM return PoolingMode.MEAN - elif pooling_type == PoolingType.NONE: + elif pooling_type.value == PoolingType.NONE.value: return PoolingMode.NONE else: raise Exception(f"Invalid pooling type {pooling_type}") def pooling_type_to_str(pooling_type: PoolingType) -> str: - if pooling_type == PoolingType.SUM: + if pooling_type.value == PoolingType.SUM.value: return "sum" - elif pooling_type == PoolingType.MEAN: + elif pooling_type.value == PoolingType.MEAN.value: return "mean" else: raise ValueError(f"Unsupported pooling type {pooling_type}") @@ -94,7 +125,9 @@ def data_type_to_sparse_type(data_type: DataType) -> SparseType: return SparseType.FP32 elif data_type == DataType.FP16: return SparseType.FP16 - elif data_type == DataType.INT8: + elif data_type == DataType.BF16: + return SparseType.BF16 + elif data_type == DataType.INT8 or data_type == DataType.UINT8: return SparseType.INT8 elif data_type == DataType.INT4: return SparseType.INT4 @@ -105,19 +138,23 @@ def data_type_to_sparse_type(data_type: DataType) -> SparseType: def data_type_to_dtype(data_type: DataType) -> torch.dtype: - if data_type == DataType.FP32: + if data_type.value == DataType.FP32.value: return torch.float32 - elif data_type == DataType.FP16: + elif data_type.value == DataType.FP16.value: return torch.float16 - elif data_type == DataType.INT64: + elif data_type.value == DataType.BF16.value: + return torch.bfloat16 + elif data_type.value == DataType.INT64.value: return torch.int64 - elif data_type == DataType.INT32: + elif data_type.value == DataType.INT32.value: return torch.int32 - elif data_type == DataType.INT8: + elif data_type.value == DataType.INT8.value: return torch.int8 - elif data_type == DataType.INT4: + elif data_type.value == DataType.UINT8.value: + return torch.uint8 + elif data_type.value == DataType.INT4.value: return torch.quint4x2 - elif data_type == DataType.INT2: + elif data_type.value == DataType.INT2.value: return torch.quint2x4 else: raise ValueError(f"DataType {data_type} cannot be converted to dtype") @@ -125,6 +162,23 @@ def data_type_to_dtype(data_type: DataType) -> torch.dtype: @dataclass class BaseEmbeddingConfig: + """ + Base class for embedding configs. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): embedding dimension. + name (str): name of the embedding table. + data_type (DataType): data type of the embedding table. + feature_names (List[str]): list of feature names. + weight_init_max (Optional[float]): max value for weight initialization. + weight_init_min (Optional[float]): min value for weight initialization. + num_embeddings_post_pruning (Optional[int]): number of embeddings after pruning for inference. + If None, no pruning is applied. + init_fn (Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]]): init function for embedding weights. + need_pos (bool): whether table is position weighted. 
+ """ + num_embeddings: int embedding_dim: int name: str = "" @@ -132,10 +186,16 @@ class BaseEmbeddingConfig: feature_names: List[str] = field(default_factory=list) weight_init_max: Optional[float] = None weight_init_min: Optional[float] = None + num_embeddings_post_pruning: Optional[int] = None + + init_fn: Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]] = None # when the position_weighted feature is in this table config, # enable this flag to support rw_sharding need_pos: bool = False + # handle the special case + input_dim: Optional[int] = None + def get_weight_init_max(self) -> float: if self.weight_init_max is None: return sqrt(1 / self.num_embeddings) @@ -151,6 +211,14 @@ def get_weight_init_min(self) -> float: def num_features(self) -> int: return len(self.feature_names) + def __post_init__(self) -> None: + if self.init_fn is None: + self.init_fn = partial( + torch.nn.init.uniform_, + a=self.get_weight_init_min(), + b=self.get_weight_init_max(), + ) + # this class will be deprecated after migration # and all the following code in sharding itself @@ -165,9 +233,28 @@ class EmbeddingTableConfig(BaseEmbeddingConfig): @dataclass class EmbeddingBagConfig(BaseEmbeddingConfig): + """ + EmbeddingBagConfig is a dataclass that represents a single embedding table, + where outputs are meant to be pooled. + + Args: + pooling (PoolingType): pooling type. + """ + pooling: PoolingType = PoolingType.SUM @dataclass class EmbeddingConfig(BaseEmbeddingConfig): + """ + EmbeddingConfig is a dataclass that represents a single embedding table. + + """ + pass + + +class QuantConfig(NamedTuple): + activation: torch.quantization.PlaceholderObserver + weight: torch.quantization.PlaceholderObserver + per_table_weight_dtype: Optional[Dict[str, torch.dtype]] = None diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py index 16d7a6f35..d110fd57f 100644 --- a/torchrec/modules/embedding_modules.py +++ b/torchrec/modules/embedding_modules.py @@ -5,8 +5,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import abc -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -17,6 +19,34 @@ pooling_type_to_str, ) from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor +from torchrec.sparse.tensor_dict import maybe_td_to_kjt + + +@torch.fx.wrap +def reorder_inverse_indices( + inverse_indices: Optional[Tuple[List[str], torch.Tensor]], + feature_names: List[str], +) -> torch.Tensor: + if inverse_indices is None: + return torch.empty(0) + index_per_name = {name: i for i, name in enumerate(inverse_indices[0])} + index = torch.tensor( + [index_per_name[name.split("@")[0]] for name in feature_names], + device=inverse_indices[1].device, + ) + return torch.index_select(inverse_indices[1], 0, index) + + +@torch.fx.wrap +def process_pooled_embeddings( + pooled_embeddings: List[torch.Tensor], + inverse_indices: torch.Tensor, +) -> torch.Tensor: + if inverse_indices.numel() > 0: + pooled_embeddings = torch.ops.fbgemm.group_index_select_dim0( + pooled_embeddings, list(torch.unbind(inverse_indices)) + ) + return torch.cat(pooled_embeddings, dim=1) class EmbeddingBagCollectionInterface(abc.ABC, nn.Module): @@ -68,18 +98,30 @@ class EmbeddingBagCollection(EmbeddingBagCollectionInterface): """ EmbeddingBagCollection represents a collection of pooled embeddings (`EmbeddingBags`). 
- It processes sparse data in the form of `KeyedJaggedTensor` with values of the form - [F X B X L] where: + NOTE: + EmbeddingBagCollection is an unsharded module and is not performance optimized. + For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingBagCollection. + + + It is callable on arguments representing sparse data in the form of `KeyedJaggedTensor` with values of the shape + `(F, B, L[f][i])` where: - * F: features (keys) - * B: batch size - * L: length of sparse features (jagged) + * `F`: number of features (keys) + * `B`: batch size + * `L[f][i]`: length of sparse features (potentially distinct for each feature `f` and batch index `i`, that is, jagged) - and outputs a `KeyedTensor` with values of the form [B * (F * D)] where: + and outputs a `KeyedTensor` with values with shape `(B, D)` where: - * F: features (keys) - * D: each feature's (key's) embedding dimension - * B: batch size + * `B`: batch size + * `D`: sum of embedding dimensions of all embedding tables, that is, `sum([config.embedding_dim for config in tables])` + + Assuming the argument is a `KeyedJaggedTensor` `J` with `F` features, batch size `B` and `L[f][i]` sparse lengths + such that `J[f][i]` is the bag for feature `f` and batch index `i`, the output `KeyedTensor` `KT` is defined as follows: + `KT[i]` = `torch.cat([emb[f](J[f][i]) for f in J.keys()])` where `emb[f]` is the `EmbeddingBag` corresponding to the feature `f`. + + Note that `J[f][i]` is a variable-length list of integer values (a bag), and `emb[f](J[f][i])` is pooled embedding + produced by reducing the embeddings of each of the values in `J[f][i]` + using the `EmbeddingBag` `emb[f]`'s mode (default is the mean). Args: tables (List[EmbeddingBagConfig]): list of embedding tables. @@ -97,28 +139,35 @@ class EmbeddingBagCollection(EmbeddingBagCollectionInterface): ebc = EmbeddingBagCollection(tables=[table_0, table_1]) - # 0 1 2 <-- batch - # "f1" [0,1] None [2] - # "f2" [3] [4] [5,6,7] + # i = 0 i = 1 i = 2 <-- batch indices + # "f1" [0,1] None [2] + # "f2" [3] [4] [5,6,7] # ^ - # feature + # features features = KeyedJaggedTensor( keys=["f1", "f2"], - values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), - offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + values=torch.tensor([0, 1, 2, # feature 'f1' + 3, 4, 5, 6, 7]), # feature 'f2' + # i = 1 i = 2 i = 3 <--- batch indices + offsets=torch.tensor([ + 0, 2, 2, # 'f1' bags are values[0:2], values[2:2], and values[2:3] + 3, 4, 5, 8]), # 'f2' bags are values[3:4], values[4:5], and values[5:8] ) pooled_embeddings = ebc(features) print(pooled_embeddings.values()) - tensor([[-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783], - [ 0.0000, 0.0000, 0.0000, 0.1598, 0.0695, 1.3265, -0.1011], - [-0.4256, -1.1846, -2.1648, -1.0893, 0.3590, -1.9784, -0.7681]], + tensor([ + # f1 pooled embeddings f2 pooled embeddings + # from bags (dim. 3) from bags (dim. 4) + [-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783], # i = 0 + [ 0.0000, 0.0000, 0.0000, 0.1598, 0.0695, 1.3265, -0.1011], # i = 1 + [-0.4256, -1.1846, -2.1648, -1.0893, 0.3590, -1.9784, -0.7681]], # i = 2 grad_fn=) print(pooled_embeddings.keys()) ['f1', 'f2'] print(pooled_embeddings.offset_per_key()) - tensor([0, 3, 7]) + tensor([0, 3, 7]) # embeddings have dimensions 3 and 4, so embeddings are at [0, 3) and [3, 7). 
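+        # Per-key tensors can also be pulled out; "f2" has batch size 3 and embedding dim 4.
+        print(pooled_embeddings.to_dict()["f2"].shape)
+        torch.Size([3, 4])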
""" def __init__( @@ -133,9 +182,6 @@ def __init__( self.embedding_bags: nn.ModuleDict = nn.ModuleDict() self._embedding_bag_configs = tables self._lengths_per_embedding: List[int] = [] - self._device: torch.device = ( - device if device is not None else torch.device("cpu") - ) table_names = set() for embedding_config in tables: @@ -155,30 +201,46 @@ def __init__( include_last_offset=True, dtype=dtype, ) + if device is None: + device = self.embedding_bags[embedding_config.name].weight.device + if not embedding_config.feature_names: embedding_config.feature_names = [embedding_config.name] self._lengths_per_embedding.extend( len(embedding_config.feature_names) * [embedding_config.embedding_dim] ) + self._device: torch.device = device or torch.device("cpu") self._embedding_names: List[str] = [ embedding for embeddings in get_embedding_names_by_table(tables) for embedding in embeddings ] self._feature_names: List[List[str]] = [table.feature_names for table in tables] + self.reset_parameters() - def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: + def forward( + self, + features: KeyedJaggedTensor, # can also take TensorDict as input + ) -> KeyedTensor: """ - Args: - features (KeyedJaggedTensor): KJT of form [F X B X L]. + Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor` + and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature. + Args: + features (KeyedJaggedTensor): Input KJT Returns: KeyedTensor """ - + flat_feature_names: List[str] = [] + features = maybe_td_to_kjt(features, None) + for names in self._feature_names: + flat_feature_names.extend(names) + inverse_indices = reorder_inverse_indices( + inverse_indices=features.inverse_indices_or_none(), + feature_names=flat_feature_names, + ) pooled_embeddings: List[torch.Tensor] = [] - feature_dict = features.to_dict() for i, embedding_bag in enumerate(self.embedding_bags.values()): for feature_name in self._feature_names[i]: @@ -189,23 +251,53 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: per_sample_weights=f.weights() if self._is_weighted else None, ).float() pooled_embeddings.append(res) - data = torch.cat(pooled_embeddings, dim=1) return KeyedTensor( keys=self._embedding_names, - values=data, + values=process_pooled_embeddings( + pooled_embeddings=pooled_embeddings, + inverse_indices=inverse_indices, + ), length_per_key=self._lengths_per_embedding, ) def is_weighted(self) -> bool: + """ + Returns: + bool: Whether the EmbeddingBagCollection is weighted. + """ return self._is_weighted def embedding_bag_configs(self) -> List[EmbeddingBagConfig]: + """ + Returns: + List[EmbeddingBagConfig]: The embedding bag configs. + """ return self._embedding_bag_configs @property def device(self) -> torch.device: + """ + Returns: + torch.device: The compute device. + """ return self._device + def reset_parameters(self) -> None: + """ + Reset the parameters of the EmbeddingBagCollection. Parameter values + are intiialized based on the `init_fn` of each EmbeddingBagConfig if it exists. 
+ """ + if (isinstance(self.device, torch.device) and self.device.type == "meta") or ( + isinstance(self.device, str) and self.device == "meta" + ): + return + # Initialize embedding bags weights with init_fn + for table_config in self._embedding_bag_configs: + assert table_config.init_fn is not None + param = self.embedding_bags[f"{table_config.name}"].weight + # pyre-ignore + table_config.init_fn(param) + class EmbeddingCollectionInterface(abc.ABC, nn.Module): """ @@ -242,20 +334,23 @@ class EmbeddingCollection(EmbeddingCollectionInterface): """ EmbeddingCollection represents a collection of non-pooled embeddings. - It processes sparse data in the form of `KeyedJaggedTensor` of the form [F X B X L] - where: + NOTE: + EmbeddingCollection is an unsharded module and is not performance optimized. + For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingCollection. - * F: features (keys) - * B: batch size - * L: length of sparse features (variable) + It is callable on arguments representing sparse data in the form of `KeyedJaggedTensor` with values of the shape + `(F, B, L[f][i])` where: - and outputs `Dict[feature (key), JaggedTensor]`. - Each `JaggedTensor` contains values of the form (B * L) X D - where: + * `F`: number of features (keys) + * `B`: batch size + * `L[f][i]`: length of sparse features (potentially distinct for each feature `f` and batch index `i`, that is, jagged) - * B: batch size - * L: length of sparse features (jagged) - * D: each feature's (key's) embedding dimension and lengths are of the form L + and outputs a `result` of type `Dict[Feature, JaggedTensor]`, + where `result[f]` is a `JaggedTensor` with shape `(EB[f], D[f])` where: + + * `EB[f]`: a "expanded batch size" for feature `f` equal to the sum of the lengths of its bag values, + that is, `sum([len(J[f][i]) for i in range(B)])`. + * `D[f]`: is the embedding dimension of feature `f`. Args: tables (List[EmbeddingConfig]): list of embedding tables. @@ -281,16 +376,29 @@ class EmbeddingCollection(EmbeddingCollectionInterface): features = KeyedJaggedTensor.from_offsets_sync( keys=["f1", "f2"], - values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), - offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + values=torch.tensor([0, 1, 2, # feature 'f1' + 3, 4, 5, 6, 7]), # feature 'f2' + # i = 1 i = 2 i = 3 <--- batch indices + offsets=torch.tensor([ + 0, 2, 2, # 'f1' bags are values[0:2], values[2:2], and values[2:3] + 3, 4, 5, 8]), # 'f2' bags are values[3:4], values[4:5], and values[5:8] ) + feature_embeddings = ec(features) print(feature_embeddings['f2'].values()) - tensor([[-0.2050, 0.5478, 0.6054], - [ 0.7352, 0.3210, -3.0399], - [ 0.1279, -0.1756, -0.4130], - [ 0.7519, -0.4341, -0.0499], - [ 0.9329, -1.0697, -0.8095]], grad_fn=) + tensor([ + # embedding for value 3 in f2 bag values[3:4]: + [-0.2050, 0.5478, 0.6054], + + # embedding for value 4 in f2 bag values[4:5]: + [ 0.7352, 0.3210, -3.0399], + + # embedding for values 5, 6, 7 in f2 bag values[5:8]: + [ 0.1279, -0.1756, -0.4130], + [ 0.7519, -0.4341, -0.0499], + [ 0.9329, -1.0697, -0.8095], + + ], grad_fn=) """ def __init__( # noqa C901 @@ -320,6 +428,8 @@ def __init__( # noqa C901 if self._embedding_dim != config.embedding_dim: raise ValueError( "All tables in a EmbeddingCollection are required to have same embedding dimension." 
+ + f" Violating case: {config.name}'s embedding_dim {config.embedding_dim} !=" + + f" {self._embedding_dim}" ) dtype = ( torch.float32 if config.data_type == DataType.FP32 else torch.float16 @@ -330,6 +440,9 @@ def __init__( # noqa C901 device=device, dtype=dtype, ) + if config.init_fn is not None: + config.init_fn(self.embeddings[config.name].weight) + if not config.feature_names: config.feature_names = [config.name] @@ -340,9 +453,12 @@ def __init__( # noqa C901 def forward( self, - features: KeyedJaggedTensor, + features: KeyedJaggedTensor, # can also take TensorDict as input ) -> Dict[str, JaggedTensor]: """ + Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor` + and returns a `Dict[str, JaggedTensor]`, which is the result of the individual embeddings for each feature. + Args: features (KeyedJaggedTensor): KJT of form [F X B X L]. @@ -350,6 +466,7 @@ def forward( Dict[str, JaggedTensor] """ + features = maybe_td_to_kjt(features, None) feature_embeddings: Dict[str, JaggedTensor] = {} jt_dict: Dict[str, JaggedTensor] = features.to_dict() for i, emb_module in enumerate(self.embeddings.values()): @@ -369,17 +486,54 @@ def forward( return feature_embeddings def need_indices(self) -> bool: + """ + Returns: + bool: Whether the EmbeddingCollection needs indices. + """ return self._need_indices def embedding_dim(self) -> int: + """ + Returns: + int: The embedding dimension. + """ return self._embedding_dim def embedding_configs(self) -> List[EmbeddingConfig]: + """ + Returns: + List[EmbeddingConfig]: The embedding configs. + """ return self._embedding_configs def embedding_names_by_table(self) -> List[List[str]]: + """ + Returns: + List[List[str]]: The embedding names by table. + """ return self._embedding_names_by_table @property def device(self) -> torch.device: + """ + Returns: + torch.device: The compute device. + """ return self._device + + def reset_parameters(self) -> None: + """ + Reset the parameters of the EmbeddingCollection. Parameter values + are intiialized based on the `init_fn` of each EmbeddingConfig if it exists. + """ + + if (isinstance(self.device, torch.device) and self.device.type == "meta") or ( + isinstance(self.device, str) and self.device == "meta" + ): + return + # Initialize embedding bags weights with init_fn + for table_config in self._embedding_configs: + assert table_config.init_fn is not None + param = self.embeddings[f"{table_config.name}"].weight + # pyre-ignore + table_config.init_fn(param) diff --git a/torchrec/modules/embedding_tower.py b/torchrec/modules/embedding_tower.py index 9cfd8667a..d9a1d182c 100644 --- a/torchrec/modules/embedding_tower.py +++ b/torchrec/modules/embedding_tower.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import List, Optional, Tuple diff --git a/torchrec/modules/feature_processor.py b/torchrec/modules/feature_processor.py index 91e144cee..79822b092 100644 --- a/torchrec/modules/feature_processor.py +++ b/torchrec/modules/feature_processor.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import abc from collections import OrderedDict from typing import Any, Dict, Iterator, List, Optional, Tuple @@ -31,6 +33,23 @@ def forward( pass +@torch.fx.wrap +def position_weighted_module_update_features( + features: Dict[str, JaggedTensor], + weighted_features: Dict[str, JaggedTensor], +) -> Dict[str, JaggedTensor]: + features.update(weighted_features) + return features + + +@torch.jit.script_if_tracing +@torch.fx.wrap +def offsets_to_range_traceble( + offsets: torch.Tensor, values: torch.Tensor +) -> torch.Tensor: + return torch.ops.fbgemm.offsets_range(offsets.long(), torch.numel(values)) + + # Will be deprecated soon, please use PositionWeightedProcessor, see full doc below class PositionWeightedModule(BaseFeatureProcessor): """ @@ -46,12 +65,21 @@ class PositionWeightedModule(BaseFeatureProcessor): def __init__( self, max_feature_lengths: Dict[str, int], + device: Optional[torch.device] = None, ) -> None: super().__init__() self.max_feature_lengths = max_feature_lengths self.position_weights: nn.ParameterDict = nn.ParameterDict() for key, length in max_feature_lengths.items(): - self.position_weights[key] = nn.Parameter(torch.empty([length]).fill_(1.0)) + self.position_weights[key] = nn.Parameter( + torch.empty([length], device=device) + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + with torch.no_grad(): + for key, _length in self.max_feature_lengths.items(): + self.position_weights[key].fill_(1.0) def forward( self, @@ -67,17 +95,17 @@ def forward( """ weighted_features: Dict[str, JaggedTensor] = {} - for key, pos_weight in self.position_weights.items(): - seq = torch.ops.fbgemm.offsets_range( - features[key].offsets().long(), torch.numel(features[key].values()) + for key, position_weight in self.position_weights.items(): + seq = offsets_to_range_traceble( + features[key].offsets(), features[key].values() ) weighted_features[key] = JaggedTensor( values=features[key].values(), lengths=features[key].lengths(), offsets=features[key].offsets(), - weights=torch.gather(pos_weight, dim=0, index=seq), + weights=torch.gather(position_weight, dim=0, index=seq), ) - return weighted_features + return position_weighted_module_update_features(features, weighted_features) class BaseGroupedFeatureProcessor(nn.Module): @@ -92,12 +120,6 @@ def forward( ) -> KeyedJaggedTensor: pass - def sparse_grad_parameter_names( - self, destination: Optional[List[str]] = None, prefix: str = "" - ) -> List[str]: - destination = [] if destination is None else destination - return destination - class PositionWeightedProcessor(BaseGroupedFeatureProcessor): """ @@ -198,6 +220,8 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: length_per_key = features.length_per_key() weights = features.weights_or_none() batch_size = features.stride() + vbe = features.variable_stride_per_key() + stride_per_key_per_rank = features.stride_per_key_per_rank() has_fp_id_list_feature = False has_normal_id_list_feature = False @@ -243,6 +267,9 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: length_per_key=length_per_key, offset_per_key=features.offset_per_key(), index_per_key=features._key_indices(), + stride_per_key_per_rank=( + stride_per_key_per_rank if vbe else None + ), ) # for unsharded or sharded non-pipeling else: @@ -252,6 +279,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: processed_features_lengths: List[torch.Tensor] = [] processed_features_values: List[torch.Tensor] = [] processed_features_weights: 
List[torch.Tensor] = [] + processed_features_batch_sizes = [] for feature_index, feature_name in enumerate(feature_names): if feature_name in self.max_feature_lengths: feature_value = feature_values[feature_index] @@ -265,11 +293,18 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: processed_features_lengths.append(feature_length) processed_features_values.append(feature_value) processed_features_weights.append(processed_weight) + if vbe: + processed_features_batch_sizes.append( + stride_per_key_per_rank[feature_index] + ) fp_features = KeyedJaggedTensor.from_lengths_sync( keys=processed_features_names, values=torch.cat(processed_features_values), lengths=torch.cat(processed_features_lengths), weights=torch.cat(processed_features_weights), + stride_per_key_per_rank=( + processed_features_batch_sizes if vbe else None + ), ) return fp_features # normal id_list feature @@ -297,9 +332,3 @@ def state_dict( param if keep_vars else param.detach() ) return destination - - def sparse_grad_parameter_names( - self, destination: Optional[List[str]] = None, prefix: str = "" - ) -> List[str]: - destination = [] if destination is None else destination - return destination diff --git a/torchrec/modules/feature_processor_.py b/torchrec/modules/feature_processor_.py new file mode 100644 index 000000000..24427cec3 --- /dev/null +++ b/torchrec/modules/feature_processor_.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +import abc +from typing import Dict, List, Optional + +import torch + +from torch import nn +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from torchrec.types import CopyMixIn + + +class FeatureProcessor(nn.Module): + """ + Abstract base class for feature processor. + + Args: + features (JaggedTensor]): feature representation + + Returns: + JaggedTensor: modified JT + + + Example:: + jt = JaggedTensor(...) + fp = FeatureProcessor(...) + fp_jt = FeatureProcessor(fp) + """ + + @abc.abstractmethod + def forward( + self, + features: JaggedTensor, + ) -> JaggedTensor: + """ + Args: + features (JaggedTensor]): feature representation + + Returns: + JaggedTensor: modified JT + """ + pass + + +class PositionWeightedModule(FeatureProcessor): + """ + Adds position weights to id list features. + + Args: + `max_length`, a.k.a truncation size, specifies the maximum number of ids + each sample has. For each feature, its position weight parameter size is + `max_length`. + """ + + def __init__( + self, max_feature_length: int, device: Optional[torch.device] = None + ) -> None: + super().__init__() + self.position_weight = nn.Parameter( + torch.empty([max_feature_length], device=device), + requires_grad=True, + ) + + self.reset_parameters() + + def reset_parameters(self) -> None: + with torch.no_grad(): + self.position_weight.fill_(1.0) + + def forward( + self, + features: JaggedTensor, + ) -> JaggedTensor: + """ + Args: + features (JaggedTensor]): feature representation + + Returns: + JaggedTensor: same as input features with `weights` field being populated. 
+ """ + + seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) + weighted_features = JaggedTensor( + values=features.values(), + lengths=features.lengths(), + offsets=features.offsets(), + weights=torch.gather(self.position_weight, dim=0, index=seq), + ) + return weighted_features + + +class FeatureProcessorsCollection(nn.Module): + """ + Abstract base class for feature processor. + + Args: + features (KeyedJaggedTensor]): feature representation + + Returns: + KeyedJaggedTensor: modified KJT + + + Example:: + kjt = JaggedTensor(...) + grouped_fp = FeatureProcessorsCollection(...) + fp_kjt = grouped_fp(kjt) + """ + + @abc.abstractmethod + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedJaggedTensor: + """ + Args: + features (JaggedTensor]): feature representation + + Returns: + JaggedTensor: modified JT + """ + pass + + +@torch.fx.wrap +def get_weights_list( + cat_seq: torch.Tensor, + features: KeyedJaggedTensor, + position_weights: Dict[str, nn.Parameter], +) -> Optional[torch.Tensor]: + weights_list = [] + seqs = torch.split(cat_seq, features.length_per_key()) + for key, seq in zip(features.keys(), seqs): + if key in position_weights.keys(): + weights_list.append(torch.gather(position_weights[key], dim=0, index=seq)) + else: + weights_list.append( + torch.ones(seq.shape[0], device=features.values().device) + ) + return torch.cat(weights_list) if weights_list else features.weights_or_none() + + +@torch.fx.wrap +def get_stride_per_key_per_rank(kjt: KeyedJaggedTensor) -> Optional[List[List[int]]]: + if not kjt.variable_stride_per_key(): + return None + return kjt.stride_per_key_per_rank() + + +class PositionWeightedModuleCollection(FeatureProcessorsCollection, CopyMixIn): + def __init__( + self, max_feature_lengths: Dict[str, int], device: Optional[torch.device] = None + ) -> None: + super().__init__() + self.max_feature_lengths = max_feature_lengths + for length in self.max_feature_lengths.values(): + if length <= 0: + raise + + self.position_weights: nn.ParameterDict = nn.ParameterDict() + # needed since nn.ParameterDict isn't torchscriptable (get_items) + self.position_weights_dict: Dict[str, nn.Parameter] = {} + + for key, length in max_feature_lengths.items(): + self.position_weights[key] = nn.Parameter( + torch.empty([length], device=device) + ) + + self.position_weights_dict[key] = self.position_weights[key] + + self.reset_parameters() + + def reset_parameters(self) -> None: + with torch.no_grad(): + for key, _length in self.max_feature_lengths.items(): + self.position_weights[key].fill_(1.0) + # Re-assign python dict to param dict in case of re-materialization + self.position_weights_dict[key] = self.position_weights[key] + + def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: + cat_seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) + + return KeyedJaggedTensor( + keys=features.keys(), + values=features.values(), + weights=get_weights_list(cat_seq, features, self.position_weights_dict), + lengths=features.lengths(), + offsets=features.offsets(), + stride=features.stride(), + length_per_key=features.length_per_key(), + stride_per_key_per_rank=get_stride_per_key_per_rank(features), + ) + + def copy(self, device: torch.device) -> nn.Module: + self.position_weights = self.position_weights.to(device=device) + for key in self.position_weights.keys(): + self.position_weights_dict[key] = self.position_weights[key] + + return self + + # Override to make sure 
position_weights and position_weights_dict are in sync + # pyre-ignore [2] + def _apply(self, *args, **kwargs) -> nn.Module: + super()._apply(*args, **kwargs) + for k, param in self.position_weights.items(): + self.position_weights_dict[k] = param + + return self diff --git a/torchrec/modules/fp_embedding_modules.py b/torchrec/modules/fp_embedding_modules.py new file mode 100644 index 000000000..2f1a3abdb --- /dev/null +++ b/torchrec/modules/fp_embedding_modules.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Dict, List, Set, Tuple, Union + +import torch +import torch.nn as nn +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.feature_processor_ import ( + FeatureProcessor, + FeatureProcessorsCollection, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + + +@torch.fx.wrap +def apply_feature_processors_to_kjt( + features: KeyedJaggedTensor, + feature_processors: Dict[str, nn.Module], +) -> KeyedJaggedTensor: + + processed_weights = [] + features_dict = features.to_dict() + + for key in features.keys(): + jt = features_dict[key] + if key in feature_processors: + fp_jt = feature_processors[key](jt) + processed_weights.append(fp_jt.weights()) + else: + processed_weights.append( + torch.ones(jt.values().shape[0], device=jt.values().device), + ) + + return KeyedJaggedTensor( + keys=features.keys(), + values=features.values(), + weights=( + torch.cat(processed_weights) + if processed_weights + else features.weights_or_none() + ), + lengths=features.lengths(), + offsets=features._offsets, + stride=features._stride, + length_per_key=features._length_per_key, + offset_per_key=features._offset_per_key, + index_per_key=features._index_per_key, + ) + + +class FeatureProcessorDictWrapper(FeatureProcessorsCollection): + def __init__(self, feature_processors: nn.ModuleDict) -> None: + super().__init__() + self._feature_processors = feature_processors + + def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: + return apply_feature_processors_to_kjt(features, self._feature_processors) + + +class FeatureProcessedEmbeddingBagCollection(nn.Module): + """ + FeatureProcessedEmbeddingBagCollection represents a EmbeddingBagCollection module and a set of feature processor modules. + The inputs into the FP-EBC will first be modified by the feature processor before being passed into the embedding bag collection. + + For details of input and output types, see EmbeddingBagCollection + + + Args: + embedding_bag_collection (EmbeddingBagCollection): ebc module + feature_processors (Dict[str, FeatureProcessor]): feature processors + Example:: + fp_ebc = FeatureProcessedEmbeddingBagCollection( + EmbeddingBagCollection(...), + { + "feature_1": FeatureProcessorModule(...), + "feature_2": FeatureProcessorModule2(...), + } + ) + + features = KeyedJaggedTensor( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + >>> fp_ebc(features).to_dict() + { + "feature_1": torch.Tensor(...) + "feature_2": torch.Tensor(...) 
+ } + """ + + def __init__( + self, + embedding_bag_collection: EmbeddingBagCollection, + feature_processors: Union[ + Dict[str, FeatureProcessor], FeatureProcessorsCollection + ], + ) -> None: + super().__init__() + self._embedding_bag_collection = embedding_bag_collection + self._feature_processors: Union[nn.ModuleDict, FeatureProcessorsCollection] + + if isinstance(feature_processors, FeatureProcessorsCollection): + self._feature_processors = feature_processors + else: + self._feature_processors = nn.ModuleDict(feature_processors) + + assert set( + sum( + [ + config.feature_names + for config in self._embedding_bag_collection.embedding_bag_configs() + ], + [], + ) + ) == set( + feature_processors.keys() + ), "Passed in feature processors do not match feature names of embedding bag" + + assert ( + embedding_bag_collection.is_weighted() + ), "EmbeddingBagCollection must accept weighted inputs for feature processor" + + feature_names_set: Set[str] = set() + for table_config in self._embedding_bag_collection.embedding_bag_configs(): + feature_names_set.update(table_config.feature_names) + self._feature_names: List[str] = list(feature_names_set) + + def split( + self, + ) -> Tuple[FeatureProcessorsCollection, EmbeddingBagCollection]: + if isinstance(self._feature_processors, nn.ModuleDict): + return ( + FeatureProcessorDictWrapper(self._feature_processors), + self._embedding_bag_collection, + ) + else: + assert isinstance(self._feature_processors, FeatureProcessorsCollection) + return self._feature_processors, self._embedding_bag_collection + + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedTensor: + """ + Args: + features (KeyedJaggedTensor): KJT of form [F X B X L]. + + Returns: + KeyedTensor + """ + values = [] + lengths = [] + weights = [] + + if isinstance(self._feature_processors, FeatureProcessorsCollection): + fp_features = self._feature_processors(features) + else: + features_dict = features.to_dict() + for key in self._feature_names: + jt = self._feature_processors[key](features_dict[key]) + values.append(jt.values()) + lengths.append(jt.lengths()) + weights.append(jt.weights()) + + fp_features = KeyedJaggedTensor( + keys=self._feature_names, + values=torch.cat(values), + lengths=torch.cat(lengths), + weights=torch.cat(weights), + ) + + return self._embedding_bag_collection(fp_features) diff --git a/torchrec/modules/fused_embedding_modules.py b/torchrec/modules/fused_embedding_modules.py index 6fac21f80..064fdea36 100644 --- a/torchrec/modules/fused_embedding_modules.py +++ b/torchrec/modules/fused_embedding_modules.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + #!/usr/bin/env python3 import copy @@ -16,7 +18,7 @@ import torch.nn as nn import torchrec.optim as trec_optim from fbgemm_gpu.split_embedding_configs import EmbOptimType -from fbgemm_gpu.split_table_batched_embeddings_ops import ( +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( ComputeDevice, EmbeddingLocation, SplitTableBatchedEmbeddingBagsCodegen, @@ -66,7 +68,7 @@ def __init__( # noqa C901 state: Dict[Any, Any] = {} param_group: Dict[str, Any] = { "params": [], - "lr": emb_module.optimizer_args.learning_rate, + "lr": emb_module.get_learning_rate(), } params: Dict[str, torch.Tensor] = {} @@ -197,7 +199,7 @@ def _init_parameters(self) -> None: assert len(self._num_embeddings) == len( self._emb_module.split_embedding_weights() ) - for (rows, emb_dim, weight_init_min, weight_init_max, param) in zip( + for rows, emb_dim, weight_init_min, weight_init_max, param in zip( self._rows, self._cols, self._weight_init_mins, @@ -313,7 +315,7 @@ class FusedEmbeddingBagCollection( name="t2", embedding_dim=8, num_embeddings=10, feature_names=["f2"] ) - ebc = FusedEmbeddingBagCollection(tables=[table_0, table_1], optimizer=torch.optim.SGD, optimizer_kwargs={"lr": .01}) + ebc = FusedEmbeddingBagCollection(tables=[table_0, table_1], optimizer_type=torch.optim.SGD, optimizer_kwargs={"lr": .01}) # 0 1 2 <-- batch # "f1" [0,1] None [2] @@ -354,6 +356,9 @@ def __init__( self._optimizer_type = optimizer_type self._optimizer_kwargs = optimizer_kwargs + self._device: torch.device = ( + device if device is not None else torch.device("cpu") + ) emb_optim_and_kwargs = convert_optimizer_type_and_kwargs( optimizer_type, optimizer_kwargs @@ -437,7 +442,10 @@ def __init__( self._key_to_tables.items(), self._emb_modules ): for embedding_config, weight in zip( - tables, emb_module.split_embedding_weights() + tables, + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + emb_module.split_embedding_weights(), + # torch._tensor.Tensor]` is not a function. ): self.embedding_bags[embedding_config.name] = torch.nn.Module() self.embedding_bags[embedding_config.name].register_parameter( @@ -498,6 +506,10 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: def _get_name(self) -> str: return "FusedEmbeddingBagCollection" + @property + def device(self) -> torch.device: + return self._device + def embedding_bag_configs(self) -> List[EmbeddingBagConfig]: return self._embedding_bag_configs @@ -641,6 +653,8 @@ def __init__( elif self._embedding_dim != table.embedding_dim: raise ValueError( "All tables in a EmbeddingCollection are required to have same embedding dimension." + + f" Violating case: {table}'s embedding_dim {table.embedding_dim} !=" + + f" {self._embedding_dim}" ) for feature in table.feature_names: if feature in seen_features: @@ -683,7 +697,10 @@ def __init__( self._key_to_tables.items(), self._emb_modules ): for embedding_config, weight in zip( - tables, emb_module.split_embedding_weights() + tables, + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + emb_module.split_embedding_weights(), + # torch._tensor.Tensor]` is not a function. ): self.embeddings[embedding_config.name] = torch.nn.Module() self.embeddings[embedding_config.name].register_parameter( diff --git a/torchrec/modules/itep_embedding_modules.py b/torchrec/modules/itep_embedding_modules.py new file mode 100644 index 000000000..e6fd15216 --- /dev/null +++ b/torchrec/modules/itep_embedding_modules.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +from typing import List + +import torch +import torch.nn as nn + +from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig + +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) +from torchrec.modules.itep_modules import GenericITEPModule + +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + + +class ITEPEmbeddingBagCollection(nn.Module): + """ + ITEPEmbeddingBagCollection represents a EmbeddingBagCollection module and an In-Training Embedding Pruning (ITEP) module. + The inputs into the ITEP-EBC will first be modified by the ITEP module before being passed into the embedding bag collection. + Args: + embedding_bag_collection (EmbeddingBagCollection): The EmbeddingBagCollection module to lookup embeddings. + itep_module (GenericITEPModule): A single ITEP module that modifies the input features. + Example: + itep_ebc = ITEPEmbeddingBagCollection( + embedding_bag_collection=ebc, + itep_module=itep_module + ) + Note: + The forward method modifies the input features using the ITEP module before passing them to the EmbeddingBagCollection. + It also increments an internal iteration counter each time it is called. + For details of input and output types, see EmbeddingBagCollection. + """ + + def __init__( + self, + embedding_bag_collection: EmbeddingBagCollection, + itep_module: GenericITEPModule, + ) -> None: + super().__init__() + self._embedding_bag_collection = embedding_bag_collection + self._itep_module = itep_module + # Iteration counter for ITEP. Pinning on CPU because used for condition checking and checkpointing. + self.register_buffer( + "_iter", + torch.tensor(0, dtype=torch.int64, device=torch.device("cpu")), + ) + + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedTensor: + """ + Forward pass for the ITEPEmbeddingBagCollection module. + The input features are first passed through the ITEP module, which modifies them. + The modified features are then passed to the EmbeddingBagCollection to get the pooled embeddings. + The internal iteration counter is incremented at each call. + Args: + features (KeyedJaggedTensor): The input features for the embedding lookup. + Returns: + KeyedTensor: The pooled embeddings from the EmbeddingBagCollection. + Note: + The iteration counter is incremented after each forward pass to keep track of the number of iterations. + """ + + # We need to explicitly move iter to CPU since it might be moved to GPU + # after __init__. This should be done once. + self._iter = self._iter.cpu() + + features = self._itep_module(features, self._iter.item()) + pooled_embeddings = self._embedding_bag_collection(features) + self._iter += 1 + + return pooled_embeddings + + def embedding_bag_configs(self) -> List[EmbeddingBagConfig]: + return self._embedding_bag_collection.embedding_bag_configs() + + +class ITEPEmbeddingCollection(nn.Module): + """ + ITEPEmbeddingCollection represents a non-pooled EmbeddingCollection module and an In-Training Embedding Pruning (ITEP) module. + The inputs into the ITEP-EC will first be modified by the ITEP module before being passed into the embedding collection. + Args: + embedding_collection (EmbeddingCollection): The EmbeddingCollection module to lookup embeddings. 
+ itep_module (GenericITEPModule): A single ITEP module that modifies the input features + Example: + itep_ebc = ITEPEmbeddingCollection( + embedding_collection=ec, + itep_module=itep_module + ) + Note: + The forward method modifies the input features using the ITEP module before passing them to the EmbeddingCollection. + It also increments an internal iteration counter each time it is called. + For details of input and output types, see EmbeddingCollection. + """ + + def __init__( + self, + embedding_collection: EmbeddingCollection, + itep_module: GenericITEPModule, + ) -> None: + super().__init__() + self._embedding_collection = embedding_collection + self._itep_module = itep_module + # Iteration counter for ITEP. Pinning on CPU because used for condition checking and checkpointing. + self.register_buffer( + "_iter", + torch.tensor(0, dtype=torch.int64, device=torch.device("cpu")), + ) + + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedTensor: + """ + Forward pass for the ITEPEmbeddingCollection module. + The input features are first passed through the ITEP module, which modifies them. + The modified features are then passed to the EmbeddingCollection to get the non-pooled embeddings. + The internal iteration counter is incremented at each call. + Args: + features (KeyedJaggedTensor): The input features for the embedding lookup. + Returns: + KeyedTensor: The non-pooled embeddings from the EmbeddingCollection. + Note: + The iteration counter is incremented after each forward pass to keep track of the number of iterations. + """ + + # We need to explicitly move iter to CPU since it might be moved to GPU + # after __init__. This should be done once. + self._iter = self._iter.cpu() + + features = self._itep_module(features, self._iter.item()) + embeddings = self._embedding_collection(features) + self._iter += 1 + + return embeddings + + def embedding_bag_configs(self) -> List[EmbeddingConfig]: + return self._embedding_collection.embedding_configs() diff --git a/torchrec/modules/itep_modules.py b/torchrec/modules/itep_modules.py new file mode 100644 index 000000000..8ffd4dbc0 --- /dev/null +++ b/torchrec/modules/itep_modules.py @@ -0,0 +1,800 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import copy +import logging +import math +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import distributed as dist, nn +from torch.distributed._shard.metadata import ShardMetadata +from torch.nn.modules.module import _IncompatibleKeys +from torch.nn.parallel import DistributedDataParallel +from torchrec.distributed.embedding_types import ShardedEmbeddingTable, ShardingType +from torchrec.distributed.types import Shard, ShardedTensor, ShardedTensorMetadata +from torchrec.modules.embedding_modules import reorder_inverse_indices +from torchrec.sparse.jagged_tensor import _pin_and_move, _to_offsets, KeyedJaggedTensor + +try: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:intraining_embedding_pruning_gpu" + ) +except OSError: + pass + +logger: logging.Logger = logging.getLogger(__name__) + + +class GenericITEPModule(nn.Module): + """ + A generic module for applying In-Training Embedding Pruning (ITEP). + This module can be hooked into the forward() of `EmbeddingBagCollection`. 
+ It will prune the embedding tables during training by applying a remapping transform + to the embedding lookup indices. + + Args: + table_name_to_unpruned_hash_sizes (Dict[str, int]): Map of table name to + unpruned hash size. + lookups (Optional[List[nn.Module]]): List of lookups in the EBC. Defaults to + `None`. + enable_pruning (Optional[bool]): Enable pruning or not. Defaults to `True`. + pruning_interval (Optional[int]): Pruning interval. Defaults to `1001`. + + NOTE: + The `lookups` argument is optional and is used in the sharded case. If not + provided, the module will skip initialization for the dummy module. + The `table_name_to_unpruned_hash_sizes` argument must not be empty. It is a map + of table names to their unpruned hash sizes. + + Example:: + itep_module = GenericITEPModule( + table_name_to_unpruned_hash_sizes={"table1": 1000, "table2": 2000}, + lookups=ShardedEmbeddingBagCollection._lookups, + enable_pruning=True, + pruning_interval=1001 + ) + """ + + def __init__( + self, + table_name_to_unpruned_hash_sizes: Dict[str, int], + lookups: Optional[List[nn.Module]] = None, + enable_pruning: bool = True, + pruning_interval: int = 1001, # Default pruning interval 1001 iterations + pg: Optional[dist.ProcessGroup] = None, + table_name_to_sharding_type: Optional[Dict[str, str]] = None, + ) -> None: + + super(GenericITEPModule, self).__init__() + + if not table_name_to_sharding_type: + table_name_to_sharding_type = {} + + # Construct in-training embedding pruning args + self.enable_pruning: bool = enable_pruning + self.rank_to_virtual_index_mapping: Dict[str, Dict[int, int]] = {} + self.pruning_interval: int = pruning_interval + self.lookups: Optional[List[nn.Module]] = None if not lookups else lookups + self.table_name_to_unpruned_hash_sizes: Dict[str, int] = ( + table_name_to_unpruned_hash_sizes + ) + self.table_name_to_sharding_type = table_name_to_sharding_type + + # Map each feature to a physical address_lookup/row_util buffer + self.feature_table_map: Dict[str, int] = {} + self.table_name_to_idx: Dict[str, int] = {} + self.buffer_offsets_list: List[int] = [] + self.idx_to_table_name: Dict[int, str] = {} + # Prevent multi-pruning, after moving iteration counter to outside. + self.last_pruned_iter = -1 + self.pg = pg + + if self.lookups is not None: + self.init_itep_state() + else: + logger.info( + "ITEP init: no lookups provided. Skipping init for dummy module." 
+ ) + + def print_itep_eviction_stats( + self, + pruned_indices_offsets: torch.Tensor, + pruned_indices_total_length: torch.Tensor, + cur_iter: int, + ) -> None: + table_name_to_eviction_ratio = {} + + num_buffers = len(self.buffer_offsets_list) - 1 + for buffer_idx in range(num_buffers): + pruned_start = pruned_indices_offsets[buffer_idx] + pruned_end = pruned_indices_offsets[buffer_idx + 1] + pruned_length = pruned_end - pruned_start + + if pruned_length > 0: + start = self.buffer_offsets_list[buffer_idx] + end = self.buffer_offsets_list[buffer_idx + 1] + buffer_length = end - start + assert buffer_length > 0 + eviction_ratio = pruned_length.item() / buffer_length + table_name_to_eviction_ratio[self.idx_to_table_name[buffer_idx]] = ( + eviction_ratio + ) + + # Sort the mapping by eviction ratio in descending order + sorted_mapping = dict( + sorted( + table_name_to_eviction_ratio.items(), + key=lambda item: item[1], + reverse=True, + ) + ) + # Print the sorted mapping + logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}") + + # Calculate percentage of indiced updated/evicted during ITEP iter + pruned_indices_ratio = ( + pruned_indices_total_length / self.buffer_offsets_list[-1] + if self.buffer_offsets_list[-1] > 0 + else 0 + ) + logger.info( + f"Performed ITEP in iter {cur_iter}, evicted {pruned_indices_total_length} ({pruned_indices_ratio:%}) indices." + ) + + def get_table_hash_sizes(self, table: ShardedEmbeddingTable) -> Tuple[int, int]: + unpruned_hash_size = table.num_embeddings + + if table.name in self.table_name_to_unpruned_hash_sizes: + unpruned_hash_size = self.table_name_to_unpruned_hash_sizes[table.name] + else: + # Tables are not pruned by ITEP if table.name not in table_name_to_unpruned_hash_sizes + unpruned_hash_size = table.num_embeddings + logger.info( + f"ITEP: table {table.name} not pruned, because table name is not present in table_name_to_unpruned_hash_sizes." + ) + + return (table.num_embeddings, unpruned_hash_size) + + def create_itep_buffers( + self, + buffer_size: int, + buffer_offsets: List[int], + table_names: List[str], + emb_sizes: List[int], + ) -> None: + """ + Registers ITEP specific buffers in a way that can be accessed by + `torch.ops.fbgemm.init_address_lookup` and can be individually checkpointed. + """ + # Buffers do not enter backward pass + with torch.no_grad(): + # Don't use register_buffer for buffer_offsets and emb_sizes because they + # may change as the sharding plan changes between preemption/resumption + # pyre-fixme[16]: `GenericITEPModule` has no attribute `buffer_offsets`. + self.buffer_offsets = torch.tensor( + buffer_offsets, dtype=torch.int64, device=self.current_device + ) + # pyre-fixme[16]: `GenericITEPModule` has no attribute `emb_sizes`. + self.emb_sizes = torch.tensor( + emb_sizes, dtype=torch.int64, device=self.current_device + ) + + # pyre-fixme[16]: `GenericITEPModule` has no attribute `address_lookup`. + self.address_lookup = torch.zeros( + buffer_size, dtype=torch.int64, device=self.current_device + ) + # pyre-fixme[16]: `GenericITEPModule` has no attribute `row_util`. + self.row_util = torch.zeros( + buffer_size, dtype=torch.float32, device=self.current_device + ) + + # Register buffers + for idx, table_name in enumerate(table_names): + self.register_buffer( + f"{table_name}_itp_address_lookup", + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ... 
+ self.address_lookup[buffer_offsets[idx] : buffer_offsets[idx + 1]], + ) + self.register_buffer( + f"{table_name}_itp_row_util", + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ... + self.row_util[buffer_offsets[idx] : buffer_offsets[idx + 1]], + ) + + def init_itep_state(self) -> None: + idx = 0 + buffer_size = 0 + # Record address_lookup/row_util buffer lengths and offsets for each feature + buffer_offsets: List[int] = [0] # number of buffers + 1 + table_names: List[str] = [] # number of buffers + 1 + emb_sizes: List[int] = [] # Store embedding table sizes + self.current_device = None + + # Iterate over all tables + # pyre-ignore + for lookup in self.lookups: + while isinstance(lookup, DistributedDataParallel): + lookup = lookup.module + for emb in lookup._emb_modules: + + emb_tables: List[ShardedEmbeddingTable] = emb._config.embedding_tables + for table in emb_tables: + # Skip if table was already added previously (if multiple shards assigned to same rank) + if table.name in self.table_name_to_idx: + continue + + ( + pruned_hash_size, + unpruned_hash_size, + ) = self.get_table_hash_sizes(table) + + # Skip tables that are not pruned, aka pruned_hash_size == unpruned_hash_size. + if pruned_hash_size == unpruned_hash_size: + continue + + logger.info( + f"ITEP: Pruning enabled for table {table.name} with features {table.feature_names}, pruned_hash_size {pruned_hash_size} vs unpruned_hash_size {unpruned_hash_size}" + ) + + # buffer size for address_lookup and row_util + buffer_size += unpruned_hash_size + buffer_offsets.append(buffer_size) + table_names.append(table.name) + emb_sizes.append(pruned_hash_size) + + # Create feature to table mappings + for feature_name in table.feature_names: + self.feature_table_map[feature_name] = idx + + # Create table_name to buffer idx mappings + self.table_name_to_idx[table.name] = idx + self.idx_to_table_name[idx] = table.name + idx += 1 + + # Check that all features have the same device + if ( + table.local_metadata is not None + and table.local_metadata.placement is not None + ): + if self.current_device is None: + self.current_device = ( + table.local_metadata.placement.device() + ) + else: + assert ( + self.current_device + == table.local_metadata.placement.device() + ), f"Device of table {table}: {table.local_metadata.placement.device()} does not match existing device: {self.current_device}" + + if self.current_device is None: + self.current_device = torch.device("cuda") + + self.buffer_offsets_list = buffer_offsets + + # Create buffers for address_lookup and row_util + self.create_itep_buffers( + buffer_size=buffer_size, + buffer_offsets=buffer_offsets, + table_names=table_names, + emb_sizes=emb_sizes, + ) + + logger.info( + f"ITEP: done init_state with feature_table_map {self.feature_table_map} and buffer_offsets {self.buffer_offsets_list}" + ) + + # initialize address_lookup + torch.ops.fbgemm.init_address_lookup( + self.address_lookup, + self.buffer_offsets, + self.emb_sizes, + ) + + def reset_weight_momentum( + self, + pruned_indices: torch.Tensor, + pruned_indices_offsets: torch.Tensor, + ) -> None: + if self.lookups is not None: + # pyre-ignore + for lookup in self.lookups: + while isinstance(lookup, DistributedDataParallel): + lookup = lookup.module + for emb in lookup._emb_modules: + emb_tables: List[ShardedEmbeddingTable] = ( + emb._config.embedding_tables + ) + + logical_idx = 0 + logical_table_ids = [] + buffer_ids = [] + for table in emb_tables: + name = table.name + if name in self.table_name_to_idx: + buffer_idx = 
self.table_name_to_idx[name] + start = pruned_indices_offsets[buffer_idx] + end = pruned_indices_offsets[buffer_idx + 1] + length = end - start + if length > 0: + logical_table_ids.append(logical_idx) + buffer_ids.append(buffer_idx) + logical_idx += table.num_features() + + if len(logical_table_ids) > 0: + emb.emb_module.reset_embedding_weight_momentum( + pruned_indices, + pruned_indices_offsets, + torch.tensor( + logical_table_ids, + dtype=torch.int32, + requires_grad=False, + ), + torch.tensor( + buffer_ids, dtype=torch.int32, requires_grad=False + ), + ) + + # Flush UVM cache after ITEP eviction to remove stale states + def flush_uvm_cache(self) -> None: + if self.lookups is not None: + # pyre-ignore + for lookup in self.lookups: + while isinstance(lookup, DistributedDataParallel): + lookup = lookup.module + for emb in lookup._emb_modules: + emb.emb_module.flush() + emb.emb_module.reset_cache_states() + + def get_remap_info(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: + keys = features.keys() + length_per_key = features.length_per_key() + offset_per_key = features.offset_per_key() + + buffer_idx = [] + feature_lengths = [] + feature_offsets = [] + for i in range(len(keys)): + key = keys[i] + if key not in self.feature_table_map: + continue + buffer_idx.append(self.feature_table_map[key]) + feature_lengths.append(length_per_key[i]) + feature_offsets.append(offset_per_key[i]) + + return [ + torch.tensor(buffer_idx, dtype=torch.int32, device=torch.device("cpu")), + torch.tensor( + feature_lengths, dtype=torch.int64, device=torch.device("cpu") + ), + torch.tensor( + feature_offsets, dtype=torch.int64, device=torch.device("cpu") + ), + ] + + def get_full_values_list(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: + inverse_indices = features.inverse_indices() + batch_size = inverse_indices[1].numel() // len(inverse_indices[0]) + keys = features.keys() + if not all(key in self.feature_table_map for key in keys): + keys = [key for key in keys if key in self.feature_table_map] + key_indices = [features._key_indices()[key] for key in keys] + features = features.permute(key_indices) + indices = ( + inverse_indices[1] + if keys == inverse_indices[0] + else reorder_inverse_indices(inverse_indices, keys) + ) + spk_tensor = _pin_and_move( + torch.tensor(features.stride_per_key()), features.device() + ) + offset_indices = ( + indices + _to_offsets(spk_tensor)[:-1].unsqueeze(-1) + ).flatten() + full_values, full_lengths = torch.ops.fbgemm.keyed_jagged_index_select_dim1( + features.values(), + features.lengths(), + features.offsets(), + offset_indices, + features.lengths().numel(), + ) + full_lpk = torch.sum(full_lengths.view(-1, batch_size), dim=1).tolist() + return list(torch.split(full_values, full_lpk)) + + def forward( + self, + sparse_features: KeyedJaggedTensor, + cur_iter: int, + ) -> KeyedJaggedTensor: + """ + Args: + sparse_features (KeyedJaggedTensor]): input embedding lookup indices to be + remapped. + cur_iter (int): iteration counter. + + Returns: + KeyedJaggedTensor: remapped KJT + + NOTE: + We use the same forward method for sharded and non-sharded case. 
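A hedged sketch of what get_remap_info() above computes, using an illustrative `feature_table_map`; only features that own an ITEP buffer contribute entries:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

def remap_info(features: KeyedJaggedTensor, feature_table_map: dict) -> list:
    buffer_idx, feature_lengths, feature_offsets = [], [], []
    for i, key in enumerate(features.keys()):
        if key not in feature_table_map:
            continue  # features without an ITEP buffer are passed through untouched
        buffer_idx.append(feature_table_map[key])
        feature_lengths.append(features.length_per_key()[i])
        feature_offsets.append(features.offset_per_key()[i])
    return [
        torch.tensor(buffer_idx, dtype=torch.int32),
        torch.tensor(feature_lengths, dtype=torch.int64),
        torch.tensor(feature_offsets, dtype=torch.int64),
    ]

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f0", "f1"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 1, 1, 1]),
)
print(remap_info(kjt, {"f1": 0}))  # only "f1" is mapped to a buffer here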
+ """ + + if not self.enable_pruning or self.lookups is None: + return sparse_features + + num_buffers = self.buffer_offsets.size(dim=0) - 1 + if num_buffers <= 0: + return sparse_features + + start_pruning: bool = ( + (cur_iter < 10 and (cur_iter + 1) % 3 == 0) + or (cur_iter < 100 and (cur_iter + 1) % 30 == 0) + or (cur_iter < 1000 and (cur_iter + 1) % 300 == 0) + or ((cur_iter + 1) % self.pruning_interval == 0) + ) + if start_pruning and self.training and self.last_pruned_iter != cur_iter: + # Pruning function outputs the indices that need weight/momentum reset + # The indices order is by physical buffer + ( + pruned_indices, + pruned_indices_offsets, + pruned_indices_total_length, + ) = torch.ops.fbgemm.prune_embedding_tables( + cur_iter, + self.pruning_interval, + self.address_lookup, + self.row_util, + self.buffer_offsets, + self.emb_sizes, + ) + # After pruning, reset weight and momentum of pruned indices + if pruned_indices_total_length > 0 and cur_iter > self.pruning_interval: + self.reset_weight_momentum(pruned_indices, pruned_indices_offsets) + + if pruned_indices_total_length > 0: + # Flush UVM cache after every ITEP eviction (every pruning_interval iterations) + self.flush_uvm_cache() + logger.info( + f"ITEP: trying to flush UVM after ITEP eviction, {cur_iter=}" + ) + + self.last_pruned_iter = cur_iter + + # Print eviction stats + self.print_itep_eviction_stats( + pruned_indices_offsets, pruned_indices_total_length, cur_iter + ) + + ( + buffer_idx, + feature_lengths, + feature_offsets, + ) = self.get_remap_info(sparse_features) + + update_utils: bool = ( + (cur_iter < 10) + or (cur_iter < 100 and (cur_iter + 1) % 19 == 0) + or ((cur_iter + 1) % 39 == 0) + ) + full_values_list = None + if update_utils and sparse_features.variable_stride_per_key(): + if sparse_features.inverse_indices_or_none() is not None: + # full util update mode require reconstructing original input indicies from VBE input + full_values_list = self.get_full_values_list(sparse_features) + else: + logger.info( + "Switching to deduped util updating mode due to features missing inverse indices. " + f"features {list(sparse_features.keys())=} with variable stride: {sparse_features.variable_stride_per_key()}" + ) + + remapped_values = torch.ops.fbgemm.remap_indices_update_utils( + cur_iter, + buffer_idx, + feature_lengths, + feature_offsets, + sparse_features.values(), + self.address_lookup, + self.row_util, + self.buffer_offsets, + full_values_list=full_values_list, + ) + + sparse_features._values = remapped_values + + return sparse_features + + +class RowwiseShardedITEPModule(GenericITEPModule): + def _get_local_metadata_idx(self, table: ShardedEmbeddingTable) -> int: + # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `shards_metadata` + for i, metadata in enumerate(table.global_metadata.shards_metadata): + if metadata == table.local_metadata: + return i + return -1 + + def _get_local_unpruned_hash_sizes_and_offsets( + self, table: ShardedEmbeddingTable, sharding_type: str + ) -> Tuple[List[int], List[int]]: + """ + Returns a tuples of 2 lists: local_unpruned_hash_sizes, local_offsets. They are used + to create itep buffers and set checkpoint local/global metadata. 
+ """ + # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `shards_metadata` + num_devices = len(table.global_metadata.shards_metadata) + + if sharding_type == ShardingType.TABLE_ROW_WISE.value: + i = 0 + for shard in table.global_metadata.shards_metadata: + if table.name not in self.rank_to_virtual_index_mapping: + self.rank_to_virtual_index_mapping[table.name] = {} + self.rank_to_virtual_index_mapping[table.name][ + shard.placement.rank() + ] = i + i += 1 + + global_hash_size = self.table_name_to_unpruned_hash_sizes[table.name] + + block_size: int = math.ceil(global_hash_size / num_devices) + last_rank: int = global_hash_size // block_size + # last_block_size: int = global_hash_size - block_size * last_rank + shard_sizes: List[int] = [] + shard_offsets: List[int] = [0] + + for rank in range(num_devices): + if ( + sharding_type == ShardingType.ROW_WISE.value + or sharding_type == ShardingType.TABLE_ROW_WISE.value + ): + if rank < last_rank: + local_row: int = block_size + elif rank == last_rank: + local_row: int = global_hash_size - block_size * last_rank + else: + local_row: int = 0 + else: + if rank <= last_rank: + local_row: int = block_size + else: + local_row: int = 0 + shard_sizes.append(local_row) + shard_offsets = [0] + + for i in range(num_devices - 1): + shard_offsets.append(shard_sizes[i] + shard_offsets[i]) + + return shard_sizes, shard_offsets + + def get_table_hash_sizes( + self, table: ShardedEmbeddingTable, num_gpus: int = 8 + ) -> Tuple[int, int]: + # calculate local unpruned and pruned hash sizes + local_rows = table.local_rows + assert ( + self.table_name_to_sharding_type is not None + and len(self.table_name_to_sharding_type) > 0 + ), "No sharding type to feature mapping found" + sharding_type = self.table_name_to_sharding_type[table.name] + + if table.name in self.table_name_to_unpruned_hash_sizes: + ( + local_unpruned_shard_sizes, + _, + ) = self._get_local_unpruned_hash_sizes_and_offsets(table, sharding_type) + + if sharding_type == ShardingType.ROW_WISE.value: + + local_unpruned_rows = local_unpruned_shard_sizes[ + # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `placement` + table.local_metadata.placement.rank() + ] + else: + + local_unpruned_rows = local_unpruned_shard_sizes[ + self.rank_to_virtual_index_mapping[table.name][ + table.local_metadata.placement.rank() + ] + ] + + else: + # Tables are not pruned by ITEP if table.name not in table_name_to_unpruned_hash_sizes + local_unpruned_rows = local_rows + logger.info( + f"ITEP: table {table.name} not pruned, because table name is not present in table_name_to_unpruned_hash_sizes." 
+ ) + + return (local_rows, local_unpruned_rows) + + def get_buffer_param(self, buffer_params: torch.Tensor, idx: int) -> torch.Tensor: + start = self.buffer_offsets_list[idx] + end = self.buffer_offsets_list[idx + 1] + length = end - start + return buffer_params.detach()[start:end].view(length, 1) + + def get_itp_state_dict( + self, + embedding_tables: List[ShardedEmbeddingTable], + params: torch.Tensor, + pg: Optional[dist.ProcessGroup] = None, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + suffix: str = "", + dtype: torch.dtype = torch.float32, + ) -> Dict[str, Any]: + def get_param(params: torch.Tensor, idx: int) -> torch.Tensor: + start = self.buffer_offsets_list[idx] + end = self.buffer_offsets_list[idx + 1] + length = end - start + return params.detach()[start:end].view(length) + + def get_key_from_table_name_and_suffix( + table_name: str, prefix: str, buffer_suffix: str + ) -> str: + return prefix + f"{table_name}{buffer_suffix}" + + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + + ckp_tables: List[str] = [] + skipped_tables: List[str] = [] + + assert ( + self.table_name_to_sharding_type is not None + and len(self.table_name_to_sharding_type) > 0 + ), "No sharding type to feature mapping found" + + for table in embedding_tables: + # Skip tables that are not wrapped with itep + if table.name not in self.table_name_to_idx.keys(): + skipped_tables.append(table.name) + continue + ckp_tables.append(table.name) + # Create buffer key and param slice + key = get_key_from_table_name_and_suffix(table.name, prefix, suffix) + param_idx = self.table_name_to_idx[table.name] + buffer_param: torch.Tensor = get_param(params, param_idx) + sharding_type = self.table_name_to_sharding_type[table.name] # pyre-ignore + + # For inference there is no pg, all tensors are local + if table.global_metadata is not None and pg is not None: + # Get unpruned global and local hashsizes + ( + unpruned_row_sizes, + unpruned_row_offsets, + ) = self._get_local_unpruned_hash_sizes_and_offsets( + table, sharding_type + ) + global_unpruned_hash_size = self.table_name_to_unpruned_hash_sizes[ + table.name + ] + # Build global shards metadata + global_shards_metadata: List[ShardMetadata] = [] + for i, table_global_metadata in enumerate( + # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `shards_metadata` + table.global_metadata.shards_metadata + ): + shard_sizes = [unpruned_row_sizes[i]] + shard_offsets = [unpruned_row_offsets[i]] + placement = copy.deepcopy(table_global_metadata.placement) + global_shards_metadata.append( + ShardMetadata( + shard_sizes=shard_sizes, + shard_offsets=shard_offsets, + placement=placement, + ) + ) + itp_global_metadata = ShardedTensorMetadata( + shards_metadata=global_shards_metadata, + size=torch.Size([global_unpruned_hash_size]), + ) + itp_global_metadata.tensor_properties.dtype = dtype + itp_global_metadata.tensor_properties.requires_grad = False + # Build local shard metadata + local_idx = self._get_local_metadata_idx(table) + itp_local_medadata = global_shards_metadata[local_idx] + + destination[key] = ( + ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=[Shard(buffer_param, itp_local_medadata)], + sharded_tensor_metadata=itp_global_metadata, + process_group=pg, + ) + ) + else: + destination[key] = buffer_param + + logger.info( + f"ITEP: get_itp_state_dict for {suffix}), got {ckp_tables}, skippped {skipped_tables}" + ) + return destination + + # 
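A hedged sketch of the row-wise shard math in _get_local_unpruned_hash_sizes_and_offsets() above (ROW_WISE/TABLE_ROW_WISE branch); `rw_shard_sizes_and_offsets` is an illustrative helper name:

import math

def rw_shard_sizes_and_offsets(global_hash_size: int, num_devices: int):
    # Ranks before `last_rank` get a full block, `last_rank` gets the remainder,
    # and any later ranks get empty shards.
    block_size = math.ceil(global_hash_size / num_devices)
    last_rank = global_hash_size // block_size
    sizes = []
    for rank in range(num_devices):
        if rank < last_rank:
            sizes.append(block_size)
        elif rank == last_rank:
            sizes.append(global_hash_size - block_size * last_rank)
        else:
            sizes.append(0)
    offsets = [0]
    for i in range(num_devices - 1):
        offsets.append(offsets[i] + sizes[i])
    return sizes, offsets

# e.g. 10 unpruned rows over 4 devices -> sizes [3, 3, 3, 1], offsets [0, 3, 6, 9]
print(rw_shard_sizes_and_offsets(10, 4))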
pyre-fixme[14]: `load_state_dict` overrides method defined in `Module` inconsistently. + def load_state_dict( + self, + state_dict: "OrderedDict[str, torch.Tensor]", + strict: bool = True, + ) -> _IncompatibleKeys: + missing_keys = [] + loaded_keys = [] + unexpected_keys = list(state_dict.keys()) + for key, dst_param in self.state_dict().items(): + if key in state_dict: + src_param = state_dict[key] + if isinstance(dst_param, ShardedTensor): + assert isinstance(src_param, ShardedTensor) + assert len(dst_param.local_shards()) == len( + src_param.local_shards() + ) + for dst_local_shard, src_local_shard in zip( + dst_param.local_shards(), src_param.local_shards() + ): + assert ( + dst_local_shard.metadata.shard_offsets + == src_local_shard.metadata.shard_offsets + ) + assert ( + dst_local_shard.metadata.shard_sizes + == src_local_shard.metadata.shard_sizes + ) + + dst_local_shard.tensor.detach().copy_(src_local_shard.tensor) + else: + assert isinstance(src_param, torch.Tensor) and isinstance( + dst_param, torch.Tensor + ) + dst_param.detach().copy_(src_param) + unexpected_keys.remove(key) + loaded_keys.append(key) + else: + missing_keys.append(key) + + logger.info( + f"ITEP: load_state_dict, loaded {loaded_keys}, missed {missing_keys}, , unexpected {unexpected_keys}" + ) + return _IncompatibleKeys(missing_keys, unexpected_keys) + + # pyre-fixme[14]: `state_dict` overrides method defined in `nn.modules.module.Module` inconsistently. + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + if self.lookups is not None: + # pyre-ignore [16] + for lookup in self.lookups: + list_of_tables: List[ShardedEmbeddingTable] = [] + for emb_config in lookup.grouped_configs: + list_of_tables.extend(emb_config.embedding_tables) + + destination = self.get_itp_state_dict( + list_of_tables, + self.address_lookup, # pyre-ignore + self.pg, + destination, + prefix, + suffix="_itp_address_lookup", + dtype=torch.int64, + ) + destination = self.get_itp_state_dict( + list_of_tables, + self.row_util, # pyre-ignore + self.pg, + destination, + prefix, + suffix="_itp_row_util", + dtype=torch.float32, + ) + return destination diff --git a/torchrec/modules/keyed_jagged_tensor_pool.py b/torchrec/modules/keyed_jagged_tensor_pool.py new file mode 100644 index 000000000..419232e1f --- /dev/null +++ b/torchrec/modules/keyed_jagged_tensor_pool.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
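A plain-tensor sketch of the key-matching logic in RowwiseShardedITEPModule.load_state_dict() above; the real method additionally matches ShardedTensor shard offsets and sizes before copying. `copy_matching` is an illustrative name:

import torch
from torch.nn.modules.module import _IncompatibleKeys

def copy_matching(dst_state: dict, src_state: dict) -> _IncompatibleKeys:
    missing, unexpected = [], list(src_state.keys())
    for key, dst_param in dst_state.items():
        if key in src_state:
            # Copy in place so registered buffers keep their identity.
            dst_param.detach().copy_(src_state[key])
            unexpected.remove(key)
        else:
            missing.append(key)
    return _IncompatibleKeys(missing, unexpected)

dst = {"t_itp_row_util": torch.zeros(4)}
src = {"t_itp_row_util": torch.ones(4), "stale_key": torch.zeros(1)}
print(copy_matching(dst, src))  # missing_keys=[], unexpected_keys=['stale_key']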
+ +# pyre-strict + +from typing import Any, Dict, List, Optional + +import torch +from torchrec.modules.object_pool import ObjectPool +from torchrec.modules.object_pool_lookups import ( + KeyedJaggedTensorPoolLookup, + TensorJaggedIndexSelectLookup, + UVMCachingInt64Lookup, +) +from torchrec.modules.utils import deterministic_dedup, jagged_index_select_with_empty +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor + + +@torch.fx.wrap +def _fx_assert_device(ids: torch.Tensor, device: torch.device) -> None: + assert ids.device == device + assert ids.dtype in [torch.int32, torch.int64] + + +@torch.fx.wrap +def _fx_wrap_lookup( + ids: torch.Tensor, + keys: List[str], + feature_max_lengths: Dict[str, int], + is_weighted: bool, + values_dtype: torch.dtype, + device: torch.device, + lookup: TensorJaggedIndexSelectLookup, # This type enforement is a hack to make it work with torch.jit.script + weigth_dtype: Optional[torch.dtype] = None, +) -> KeyedJaggedTensor: + jt_lookup: JaggedTensor = lookup.lookup(ids) + + row_major_to_feature_major_permute = ( + torch.arange((ids.shape[0] * len(feature_max_lengths)), device=device) + .view(-1, len(feature_max_lengths)) + .t() + .flatten() + ) + + lengths = jt_lookup.lengths().flatten()[row_major_to_feature_major_permute] + output_offsets = torch.ops.fbgemm.asynchronous_inclusive_cumsum(lengths) + values = jagged_index_select_with_empty( + jt_lookup.values().flatten().unsqueeze(-1), + row_major_to_feature_major_permute, + jt_lookup.offsets().flatten()[1:], + output_offsets, + ) + values, lengths = values.flatten(), lengths.flatten() + + weights = torch.jit.annotate(Optional[torch.Tensor], None) + if jt_lookup.weights_or_none() is not None: + weights = jagged_index_select_with_empty( + jt_lookup.weights().flatten().unsqueeze(-1), + row_major_to_feature_major_permute, + jt_lookup.offsets().flatten()[1:], + output_offsets, + ) + weights = weights.flatten() + + return KeyedJaggedTensor.from_lengths_sync( + keys=keys, + values=values, + lengths=lengths, + weights=weights, + ) + + +class KeyedJaggedTensorPool(ObjectPool[KeyedJaggedTensor]): + """ + KeyedJaggedTensorPool represents a collection of KeyedJaggedTensor (KJT) + with an index over the batch (non-jagged) dimension. For example, if a KJT + has 2 features "Feature0" and "Feature1" with stride of 2 (i.e. batch dim = 2), + KeyedJaggedTensorPool allows associating an index with batch 0 of "Feature0" and + "Feature1". Example: + + # "Feature0" "Feature1" < dim_0 + # batch_0 [V0,V1] None <-- associated with index 2 + # batch_1 [V3] [V4] <-- associated with index 0 + # ^ + # dim_1 + + This is useful when one needs to associate entity IDs to sparse features with + jagged dimension, for example during hard negative sampling. + + Args: + pool_size (int): total number of batches that can be stored in the pool + feature_max_lengths (Dict[str,int]): Mapping from feature name in KJT + to the maximum size of the jagged slices for the feature. + is_weighted (bool): whether KJT values have weights that need to be stored. 
+ device (Optional[torch.device]): default device + enable_uvm (bool): if set to true, the pool will be allocated on UVM + + Call Args: + ids: 1D torch.Tensor of ids to look up + + Returns: + KeyedJaggedTensor with uniform stride of ids.size(0) + + Example:: + + feature_max_lengths = {"feature0": 2, "feature1": 3} + + kjt_pool = KeyedJaggedTensorPool( + pool_size=10, + feature_max_lengths=feature_max_lengths, + values_dtype=torch.float, + ) + + # Update + kjt_pool.update( + ids=torch.tensor([1,0,2]), # Assign different indices along batch dim + values=kjt, + ) + + # Lookup + lookup_kjt = kjt_pool.lookup(ids=torch.Tensor([2,0])) + + print(lookup_kjt) + # KeyedJaggedTensor({ + # "feature0": [[v2], [v0, v1]] + # "feature1": [[v5,v6,v7], [v4]] + # }) + + """ + + def __init__( + self, + pool_size: int, + feature_max_lengths: Dict[str, int], + values_dtype: torch.dtype = torch.int64, + is_weighted: bool = False, + device: Optional[torch.device] = None, + enable_uvm: bool = False, + ) -> None: + super().__init__() + self._pool_size = pool_size + self._feature_max_lengths: Dict[str, int] = feature_max_lengths + # pyre-fixme[4]: Attribute must be annotated. + self._total_lengths = sum(self._feature_max_lengths.values()) + self._values_dtype = values_dtype + self._is_weighted = is_weighted + # pyre-fixme[4]: Attribute must be annotated. + self._device = device if device is not None else torch.device("meta") + self._enable_uvm = enable_uvm + + # pyre-fixme[4]: Attribute must be annotated. + self._permute_feature = None + self.register_buffer( + "_feature_max_lengths_t", + torch.tensor( + list(feature_max_lengths.values()), + dtype=torch.int32, + device=self._device, + ), + persistent=False, + ) + + # pyre-fixme[4]: Attribute must be annotated. + self._keys = list(self._feature_max_lengths.keys()) + # pyre-ignore + self._lookup: KeyedJaggedTensorPoolLookup = None + if self._enable_uvm and values_dtype == torch.int64: + self._lookup = UVMCachingInt64Lookup( + pool_size, feature_max_lengths, is_weighted, self._device + ) + else: + self._lookup = TensorJaggedIndexSelectLookup( + pool_size, + values_dtype, + feature_max_lengths, + is_weighted, + self._device, + ) + + if self._lookup is None: + raise ValueError( + f"Cannot create lookup for {self._enable_uvm=} {self._values_dtype}" + ) + + for fqn, tensor in self._lookup.states_to_register(): + self.register_buffer( + fqn, + tensor, + ) + + def _load_from_state_dict( + self, + state_dict: Dict[str, torch.Tensor], + prefix: str, + local_metadata: Dict[str, Any], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + True, + missing_keys, + unexpected_keys, + error_msgs, + ) + + # pyre-fixme[16]: `KeyedJaggedTensorPoolLookup` has no attribute `_values`. 
+ self._lookup._values = state_dict[prefix + "values"] + self._lookup._key_lengths = state_dict[prefix + "key_lengths"] + + @property + def pool_size(self) -> int: + return self._pool_size + + @property + def feature_max_lengths(self) -> Dict[str, int]: + return self._feature_max_lengths + + @property + def values_dtype(self) -> torch.dtype: + return self._values_dtype + + @property + def is_weighted(self) -> bool: + return self._is_weighted + + def lookup(self, ids: torch.Tensor) -> KeyedJaggedTensor: + _fx_assert_device(ids, self._device) + return _fx_wrap_lookup( + ids, + self._keys, + self._feature_max_lengths, + self._is_weighted, + self._values_dtype, + self._device, + self._lookup, + self._weights.dtype if self._is_weighted else None, + ) + + def _update_preproc(self, values: KeyedJaggedTensor) -> KeyedJaggedTensor: + """ + 2 steps: + 1. Permute/filter KJT keys to be the same as in feature_max_lengths + 2. Ensure the max_lengths of input is within the feature_max_lengths + """ + if self._permute_feature is None: + self._permute_feature = [] + for feature in self._feature_max_lengths.keys(): + for j, kjt_feature in enumerate(values.keys()): + if feature == kjt_feature: + self._permute_feature.append(j) + + valid_input = values.permute(self._permute_feature) + max_elements, _max_indices = ( + valid_input.lengths().reshape(len(self._keys), -1).max(dim=1) + ) + + assert torch.all( + max_elements <= self._feature_max_lengths_t + ).item(), "input KJT has a feature that exceeds specified max lengths" + + return valid_input + + def update(self, ids: torch.Tensor, values: KeyedJaggedTensor) -> None: + _fx_assert_device(ids, self._device) + + kjt = self._update_preproc(values) + assert kjt.values().dtype == self._values_dtype + + # If duplicate ids are passed in for update, only the last one is kept + deduped_ids, dedup_permutation = deterministic_dedup(ids) + arange_idx = torch.arange( + values.stride() * len(self._keys), device=self._device + ) + feature_major_to_row_major_permute = (arange_idx.view(len(self._keys), -1).t())[ + dedup_permutation, : + ].flatten() + + row_major_lengths = kjt.lengths()[feature_major_to_row_major_permute] + row_major_offsets = torch.ops.fbgemm.asynchronous_inclusive_cumsum( + row_major_lengths + ) + row_major_values = jagged_index_select_with_empty( + kjt.values().unsqueeze(-1), + feature_major_to_row_major_permute, + kjt.offsets()[1:], + row_major_offsets, + ) + + row_major_values = row_major_values.flatten() + + row_major_lengths = row_major_lengths.flatten() + + row_major_weights = None + if self._is_weighted: + row_major_weights = jagged_index_select_with_empty( + kjt.weights().unsqueeze(-1), + feature_major_to_row_major_permute, + kjt.offsets()[1:], + row_major_offsets, + ) + row_major_weights = row_major_weights.flatten() + + self._lookup.update( + deduped_ids, + JaggedTensor( + values=row_major_values, + lengths=row_major_lengths.flatten(), + weights=row_major_weights, + ), + ) diff --git a/torchrec/modules/lazy_extension.py b/torchrec/modules/lazy_extension.py index 5be6764a3..195c5c3ba 100644 --- a/torchrec/modules/lazy_extension.py +++ b/torchrec/modules/lazy_extension.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
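A small sketch of the feature-major to row-major permutation built in KeyedJaggedTensorPool.update() above, with illustrative sizes; the pool stores one row per id, so per-feature lengths and values arriving in feature-major order are re-ordered so that all features of one batch row are contiguous (the dedup permutation then keeps only the last update per id):

import torch

num_keys, stride = 2, 3  # e.g. features ["f0", "f1"], batch size 3
arange_idx = torch.arange(stride * num_keys)
# feature-major layout: [f0_b0, f0_b1, f0_b2, f1_b0, f1_b1, f1_b2]
row_major_permute = arange_idx.view(num_keys, -1).t().flatten()
print(row_major_permute)  # tensor([0, 3, 1, 4, 2, 5]) -> [f0_b0, f1_b0, f0_b1, ...]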
+# pyre-strict + import functools import inspect from typing import Any, Callable @@ -28,14 +30,12 @@ def _apply_functions_after_first_forward( ) -> None: _functions_to_lazy_apply = getattr(module, "_functions_to_lazy_apply", None) if _functions_to_lazy_apply is not None: - # pyre-fixme[29]: - # `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self, - # torch.Tensor)], typing.Iterator[typing.Any]], torch.Tensor], torch.Tensor, - # torch.nn.modules.module.Module]` is not a function. + # pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a + # function. for fn in _functions_to_lazy_apply: module.apply(fn) delattr(module, "_functions_to_lazy_apply") - # pyre-ignore[16] + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `remove`. module._lazy_apply_hook.remove() delattr(module, "_lazy_apply_hook") @@ -83,22 +83,21 @@ def init_weights(m): """ if not hasattr(module, "_functions_to_lazy_apply"): - # pyre-ignore[16] + # pyre-fixme[16]: `Module` has no attribute `_functions_to_lazy_apply`. module._functions_to_lazy_apply = [] if not hasattr(module, "_lazy_apply_hook"): - # pyre-ignore[16] + # pyre-fixme[16]: `Module` has no attribute `_lazy_apply_hook`. module._lazy_apply_hook = module.register_forward_hook( _apply_functions_after_first_forward ) - # pyre-ignore[16] + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `append`. module._functions_to_lazy_apply.append(fn) return module class _LazyExtensionProtocol(_LazyProtocol): # pyre-ignore[2,3] - def _call_impl(self, *input, **kwargs): - ... + def _call_impl(self, *input, **kwargs): ... class LazyModuleExtensionMixin(LazyModuleMixin): @@ -165,9 +164,10 @@ def init_weights(m): # fmt: off # pyre-ignore[2, 47] - def _infer_parameters(self: _LazyExtensionProtocol, module, input, kwargs) -> None: - r"""Infers the size and initializes the parameters according to the - provided input batch. + # `LazyModuleMixin` inconsistently. + def _infer_parameters(self: _LazyExtensionProtocol, module, args, kwargs) -> None: + r"""Infers the size and initializes the parameters according to the provided input batch. + Given a module that contains parameters that were declared inferrable using :class:`torch.nn.parameter.ParameterMode.Infer`, runs a forward pass in the complete module using the provided input to initialize all the parameters @@ -175,9 +175,10 @@ def _infer_parameters(self: _LazyExtensionProtocol, module, input, kwargs) -> No The module is set into evaluation mode before running the forward pass in order to avoid saving statistics or calculating gradients """ - module.initialize_parameters(*input, **kwargs) + kwargs = kwargs if kwargs else {} + module.initialize_parameters(*args, **kwargs) if module.has_uninitialized_params(): - raise RuntimeError('module {} has not been fully initialized'.format(self._get_name())) + raise RuntimeError(f'module {self._get_name()} has not been fully initialized') module._initialize_hook.remove() module._load_hook.remove() delattr(module, '_initialize_hook') @@ -256,4 +257,5 @@ def _call_impl(self, *input, **kwargs): # noqa: C901 # fmt: on # pyre-ignore[4] + # pyre-fixme[15]: `__call__` overrides attribute defined in `type` inconsistently. 
__call__: Callable[..., Any] = _call_impl diff --git a/torchrec/modules/mc_embedding_modules.py b/torchrec/modules/mc_embedding_modules.py new file mode 100644 index 000000000..6e7850dba --- /dev/null +++ b/torchrec/modules/mc_embedding_modules.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +from typing import cast, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) +from torchrec.modules.mc_modules import ManagedCollisionCollection +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor + + +def evict( + evictions: Dict[str, Optional[torch.Tensor]], + ebc: nn.Module, +) -> None: + # TODO: write function + return + + +class BaseManagedCollisionEmbeddingCollection(nn.Module): + """ + BaseManagedCollisionEmbeddingCollection represents a EC/EBC module and a set of managed collision modules. + The inputs into the MC-EC/EBC will first be modified by the managed collision module before being passed into the embedding collection. + + Args: + embedding_module: EmbeddingCollection to lookup embeddings + managed_collision_modules: Dict of managed collision modules + return_remapped_features (bool): whether to return remapped input features + in addition to embeddings + allow_in_place_embed_weight_update(bool): Enables in-place update of embedding + weights on eviction. When enabled, this flag allows updates to embedding + weights without modifying the autograd graph. + + """ + + def __init__( + self, + embedding_module: Union[EmbeddingBagCollection, EmbeddingCollection], + managed_collision_collection: ManagedCollisionCollection, + return_remapped_features: bool = False, + allow_in_place_embed_weight_update: bool = False, + ) -> None: + super().__init__() + self._managed_collision_collection = managed_collision_collection + self._return_remapped_features = return_remapped_features + self._allow_in_place_embed_weight_update = allow_in_place_embed_weight_update + self._embedding_module: Union[EmbeddingBagCollection, EmbeddingCollection] = ( + embedding_module + ) + + if isinstance(embedding_module, EmbeddingBagCollection): + assert ( + # pyre-fixme[29]: `Union[(self: EmbeddingBagCollection) -> + # list[EmbeddingBagConfig], Module, Tensor]` is not a function. + self._embedding_module.embedding_bag_configs() + == self._managed_collision_collection.embedding_configs() + ), "Embedding Bag Collection and Managed Collision Collection must contain the Embedding Configs" + + else: + assert ( + # pyre-fixme[29]: `Union[(self: EmbeddingCollection) -> + # list[EmbeddingConfig], Module, Tensor]` is not a function. 
+ self._embedding_module.embedding_configs() + == self._managed_collision_collection.embedding_configs() + ), "Embedding Collection and Managed Collision Collection must contain the Embedding Configs" + + def forward( + self, + features: KeyedJaggedTensor, + ) -> Tuple[ + Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor] + ]: + + features = self._managed_collision_collection(features) + + embedding_res = self._embedding_module(features) + + evict(self._managed_collision_collection.evict(), self._embedding_module) + + if not self._return_remapped_features: + return embedding_res, None + return embedding_res, features + + +class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollection): + """ + ManagedCollisionEmbeddingCollection represents a EmbeddingCollection module and a set of managed collision modules. + The inputs into the MC-EC will first be modified by the managed collision module before being passed into the embedding collection. + + For details of input and output types, see EmbeddingCollection + + Args: + embedding_collection: EmbeddingCollection to lookup embeddings + managed_collision_collection: Dict of managed collision modules + return_remapped_features (bool): whether to return remapped input features + in addition to embeddings + allow_in_place_embed_weight_update(bool): enable in place update of embedding + weights on evict. This flag when enabled will allow update embedding + weights without modifying of autograd graph. + + """ + + def __init__( + self, + embedding_collection: EmbeddingCollection, + managed_collision_collection: ManagedCollisionCollection, + return_remapped_features: bool = False, + allow_in_place_embed_weight_update: bool = False, + ) -> None: + super().__init__( + embedding_collection, + managed_collision_collection, + return_remapped_features, + allow_in_place_embed_weight_update, + ) + + # For consistency with embedding bag collection + @property + def _embedding_collection(self) -> EmbeddingCollection: + return cast(EmbeddingCollection, self._embedding_module) + + +class ManagedCollisionEmbeddingBagCollection(BaseManagedCollisionEmbeddingCollection): + """ + ManagedCollisionEmbeddingBagCollection represents a EmbeddingBagCollection module and a set of managed collision modules. + The inputs into the MC-EBC will first be modified by the managed collision module before being passed into the embedding bag collection. + + For details of input and output types, see EmbeddingBagCollection + + Args: + embedding_module: EmbeddingBagCollection to lookup embeddings + managed_collision_modules: Dict of managed collision modules + return_remapped_features (bool): whether to return remapped input features + in addition to embeddings + allow_in_place_embed_weight_update(bool): Enables in-place update of embedding + weights on eviction. When enabled, this flag allows updates to embedding + weights without modifying the autograd graph. 
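A hedged sketch of the MC-EC/EBC forward flow described above; `mc_ebc_forward`, `mcc`, and `ebc` are illustrative stand-ins for the module's internals (an already-constructed ManagedCollisionCollection and EmbeddingBagCollection), not the actual API:

def mc_ebc_forward(features, mcc, ebc, return_remapped: bool = False):
    remapped = mcc(features)   # ids remapped into [0, num_embeddings)
    pooled = ebc(remapped)     # ordinary embedding (bag) lookup on remapped ids
    _ = mcc.evict()            # per-table eviction candidates, handled by evict()
    return (pooled, remapped if return_remapped else None)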
+ + + """ + + def __init__( + self, + embedding_bag_collection: EmbeddingBagCollection, + managed_collision_collection: ManagedCollisionCollection, + return_remapped_features: bool = False, + allow_in_place_embed_weight_update: bool = False, + ) -> None: + super().__init__( + embedding_bag_collection, + managed_collision_collection, + return_remapped_features, + allow_in_place_embed_weight_update, + ) + + # For backwards compat, as references existed in tests + @property + def _embedding_bag_collection(self) -> EmbeddingBagCollection: + return cast(EmbeddingBagCollection, self._embedding_module) diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py new file mode 100644 index 000000000..a472b88e1 --- /dev/null +++ b/torchrec/modules/mc_modules.py @@ -0,0 +1,1424 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +import abc +from logging import getLogger, Logger +from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union + +import torch + +from torch import nn +from torchrec.modules.embedding_configs import BaseEmbeddingConfig +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor + + +logger: Logger = getLogger(__name__) + + +@torch.fx.wrap +def apply_mc_method_to_jt_dict( + mc_module: nn.Module, + method: str, + features_dict: Dict[str, JaggedTensor], +) -> Dict[str, JaggedTensor]: + """ + Applies an MC method to a dictionary of JaggedTensors, returning the updated dictionary with same ordering + """ + attr = getattr(mc_module, method) + return attr(features_dict) + + +@torch.fx.wrap +def _update( + base: Optional[Dict[str, JaggedTensor]], delta: Dict[str, JaggedTensor] +) -> Dict[str, JaggedTensor]: + res: Dict[str, JaggedTensor] = {} + if base is not None: + for k, v in base.items(): + res[k] = v + for k, v in delta.items(): + res[k] = v + return res + + +@torch.fx.wrap +def _cat_jagged_values(jd: Dict[str, JaggedTensor]) -> torch.Tensor: + return torch.cat([jt.values() for jt in jd.values()]) + + +# TODO: keep the old implementation for backward compatibility and will remove it later +@torch.fx.wrap +def _mcc_lazy_init( + features: KeyedJaggedTensor, + feature_names: List[str], + features_order: List[int], + created_feature_order: bool, +) -> Tuple[KeyedJaggedTensor, bool, List[int]]: # features_order + input_feature_names: List[str] = features.keys() + if not created_feature_order: + for f in feature_names: + features_order.append(input_feature_names.index(f)) + + if features_order == list(range(len(input_feature_names))): + features_order = torch.jit.annotate(List[int], []) + created_feature_order = True + + if len(features_order) > 0: + features = features.permute( + features_order, + ) + + return (features, created_feature_order, features_order) + + +@torch.fx.wrap +def _mcc_lazy_init_inplace( + features: KeyedJaggedTensor, + feature_names: List[str], + features_order: List[int], + created_feature_order: List[bool], +) -> KeyedJaggedTensor: + input_feature_names: List[str] = features.keys() + if not created_feature_order or not created_feature_order[0]: + for f in feature_names: + features_order.append(input_feature_names.index(f)) + + if features_order == list(range(len(input_feature_names))): + features_order.clear() + + if len(created_feature_order) > 0: + created_feature_order[0] 
= True + else: + created_feature_order.append(True) + + if len(features_order) > 0: + features = features.permute( + features_order, + ) + + return features + + +@torch.fx.wrap +def _get_length_per_key(kjt: KeyedJaggedTensor) -> torch.Tensor: + return torch.tensor(kjt.length_per_key()) + + +@torch.no_grad() +def dynamic_threshold_filter( + id_counts: torch.Tensor, + threshold_skew_multiplier: float = 10.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Threshold is total_count / num_ids * threshold_skew_multiplier. An id is + added if its count is strictly greater than the threshold. + """ + + num_ids = id_counts.numel() + total_count = id_counts.sum() + + BASE_THRESHOLD = 1 / num_ids + threshold_mass = BASE_THRESHOLD * threshold_skew_multiplier + + threshold = threshold_mass * total_count + threshold_mask = id_counts > threshold + + return threshold_mask, threshold + + +@torch.no_grad() +def average_threshold_filter( + id_counts: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Threshold is average of id_counts. An id is added if its count is strictly + greater than the mean. + """ + if id_counts.dtype != torch.float: + id_counts = id_counts.float() + threshold = id_counts.mean() + threshold_mask = id_counts > threshold + + return threshold_mask, threshold + + +@torch.no_grad() +def probabilistic_threshold_filter( + id_counts: torch.Tensor, + per_id_probability: float = 0.01, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Each id has probability per_id_probability of being added. For example, + if per_id_probability is 0.01 and an id appears 100 times, then it has a 60% + of being added. More precisely, the id score is 1 - (1 - per_id_probability) ^ id_count, + and for a randomly generated threshold, the id score is the chance of it being added. + """ + probability = torch.full_like(id_counts, 1 - per_id_probability, dtype=torch.float) + id_scores = 1 - torch.pow(probability, id_counts) + + threshold: torch.Tensor = torch.rand(id_counts.size(), device=id_counts.device) + threshold_mask = id_scores > threshold + + return threshold_mask, threshold + + +class ManagedCollisionModule(nn.Module): + """ + Abstract base class for ManagedCollisionModule. + Maps input ids to range [0, max_output_id). + + Args: + max_output_id (int): Max output value of remapped ids. + input_hash_size (int): Max value of input range i.e. [0, input_hash_size) + remapping_range_start_index (int): Relative start index of remapping range + device (torch.device): default compute device. + + Example:: + jt = JaggedTensor(...) + mcm = ManagedCollisionModule(...) + mcm_jt = mcm(fp) + """ + + def __init__( + self, + device: torch.device, + output_segments: List[int], + skip_state_validation: bool = False, + ) -> None: + super().__init__() + self._device = device + + if skip_state_validation: + logger.warning( + "Skipping validation on ManagedCollisionModule. 
This module may not be Reshard-able as a result" + ) + return + + # limited to max of 1024 RW shards + assert ( + len(output_segments) <= 1025 + ), "ManagedCollisionModule limited to 1024 shards" + + self.register_buffer( + "_output_segments_tensor", + torch.tensor( + output_segments + [-1] * (1025 - len(output_segments)), + dtype=torch.int64, + device=self.device, + ), + ) + self.register_buffer( + "_current_iter_tensor", + torch.tensor( + [0], + dtype=torch.int64, + device=self.device, + ), + ) + + def _load_state_dict_post_hook( + module: "ManagedCollisionModule", + incompatible_keys: torch.nn.modules.module._IncompatibleKeys, + ) -> None: + module.validate_state() + + self.register_load_state_dict_post_hook(_load_state_dict_post_hook) + + @abc.abstractmethod + def preprocess( + self, + features: Dict[str, JaggedTensor], + ) -> Dict[str, JaggedTensor]: + pass + + @property + def device(self) -> torch.device: + return self._device + + @abc.abstractmethod + def evict(self) -> Optional[torch.Tensor]: + """ + Returns None if no eviction should be done this iteration. Otherwise, return ids of slots to reset. + On eviction, this module should reset its state for those slots, with the assumptionn that the downstream module + will handle this properly. + """ + pass + + @abc.abstractmethod + def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]: + pass + + @abc.abstractmethod + def profile(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]: + pass + + @abc.abstractmethod + def forward( + self, + features: Dict[str, JaggedTensor], + ) -> Dict[str, JaggedTensor]: + pass + + @abc.abstractmethod + def output_size(self) -> int: + """ + Returns numerical range of output, for validation vs. downstream embedding lookups + """ + pass + + @abc.abstractmethod + def input_size(self) -> int: + """ + Returns numerical range of input, for sharding info + """ + pass + + @abc.abstractmethod + def buckets(self) -> int: + """ + Returns number of uniform buckets, relevant to resharding + """ + pass + + @abc.abstractmethod + def validate_state(self) -> None: + """ + Validates that the state of the module after loading from checkpoint + """ + pass + + @abc.abstractmethod + def open_slots(self) -> torch.Tensor: + """ + Returns number of unused slots in managed collision module + """ + pass + + @abc.abstractmethod + def rebuild_with_output_id_range( + self, + output_id_range: Tuple[int, int], + output_segments: List[int], + device: Optional[torch.device] = None, + ) -> "ManagedCollisionModule": + """ + Used for creating local MC modules for RW sharding + """ + pass + + +class ManagedCollisionCollection(nn.Module): + """ + ManagedCollisionCollection represents a collection of managed collision modules. + The inputs passed to the MCC will be remapped by the managed collision modules + and returned. 
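A small numeric check of the threshold filters defined above, with illustrative counts; with per_id_probability=0.01, an id seen 100 times gets score 1 - 0.99**100, i.e. roughly a 63% chance of passing the random threshold on a given update:

import torch

counts = torch.tensor([1, 10, 100], dtype=torch.float)
scores = 1 - torch.pow(torch.full_like(counts, 0.99), counts)
print(scores)  # approximately tensor([0.0100, 0.0956, 0.6340])

# dynamic_threshold_filter with threshold_skew_multiplier=10 keeps ids whose count
# strictly exceeds 10x the mean count: here mean=37, threshold=370, so none pass.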
+ Args: + managed_collision_modules (Dict[str, ManagedCollisionModule]): Dict of managed collision modules + embedding_confgs (List[BaseEmbeddingConfig]): List of embedding configs, for each table with a managed collsion module + """ + + _table_to_features: Dict[str, List[str]] + _features_order: List[int] + _created_feature_order: List[bool] # use list for inplace update in leaf function + + def __init__( + self, + managed_collision_modules: Dict[str, ManagedCollisionModule], + embedding_configs: Sequence[BaseEmbeddingConfig], + need_preprocess: bool = True, + ) -> None: + super().__init__() + self._managed_collision_modules = nn.ModuleDict(managed_collision_modules) + self._embedding_configs = embedding_configs + self.need_preprocess = need_preprocess + self._feature_to_table: Dict[str, str] = { + feature: config.name + for config in embedding_configs + for feature in config.feature_names + } + self._table_to_features: Dict[str, List[str]] = { + config.name: config.feature_names for config in embedding_configs + } + + self._table_feature_splits: List[int] = [ + len(features) for features in self._table_to_features.values() + ] + + table_to_config = {config.name: config for config in embedding_configs} + + for name, config in table_to_config.items(): + if name not in managed_collision_modules: + raise ValueError( + f"Table {name} is not present in managed_collision_modules" + ) + assert ( + managed_collision_modules[name].output_size() == config.num_embeddings + ), ( + f"max_output_id in managed collision module for {name} " + f"must match {config.num_embeddings}" + ) + self._feature_names: List[str] = [ + feature for config in embedding_configs for feature in config.feature_names + ] + self._created_feature_order: List[bool] = [False] + self._features_order = [] + + def _create_feature_order( + self, + input_feature_names: List[str], + device: torch.device, + ) -> None: + features_order: List[int] = [] + for f in self._feature_names: + features_order.append(input_feature_names.index(f)) + + if features_order != list(range(len(features_order))): + self._features_order = features_order + + def embedding_configs(self) -> Sequence[BaseEmbeddingConfig]: + return self._embedding_configs + + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedJaggedTensor: + features = _mcc_lazy_init_inplace( + features, + self._feature_names, + self._features_order, + self._created_feature_order, + ) + + feature_splits: List[KeyedJaggedTensor] = features.split( + self._table_feature_splits + ) + + output: Optional[Dict[str, JaggedTensor]] = None + for i, (table, mc_module) in enumerate(self._managed_collision_modules.items()): + kjt: KeyedJaggedTensor = feature_splits[i] + mc_input: Dict[str, JaggedTensor] = { + table: JaggedTensor( + values=kjt.values(), + lengths=kjt.lengths(), + weights=_get_length_per_key(kjt), + ) + } + mc_input = mc_module(mc_input) + output = _update(output, mc_input) + + assert output is not None + values: torch.Tensor = _cat_jagged_values(output) + return KeyedJaggedTensor( + keys=features.keys(), + values=values, + lengths=features.lengths(), + weights=features.weights_or_none(), + ) + + def evict(self) -> Dict[str, Optional[torch.Tensor]]: + evictions: Dict[str, Optional[torch.Tensor]] = {} + for ( + table, + managed_collision_module, + ) in self._managed_collision_modules.items(): + evictions[table] = managed_collision_module.evict() + return evictions + + def open_slots(self) -> Dict[str, torch.Tensor]: + open_slots: Dict[str, torch.Tensor] = {} + for ( + table, + 
managed_collision_module, + ) in self._managed_collision_modules.items(): + open_slots[table] = managed_collision_module.open_slots() + return open_slots + + +class MCHEvictionPolicyMetadataInfo(NamedTuple): + metadata_name: str + is_mch_metadata: bool + is_history_metadata: bool + + +class MCHEvictionPolicy(abc.ABC): + def __init__( + self, + metadata_info: List[MCHEvictionPolicyMetadataInfo], + threshold_filtering_func: Optional[ + Callable[[torch.Tensor], Tuple[torch.Tensor, Union[float, torch.Tensor]]] + ] = None, # experimental + ) -> None: + """ + threshold_filtering_func (Optional[Callable]): function used to filter incoming ids before update/eviction. experimental feature. + [input: Tensor] the function takes as input a 1-d tensor of unique id counts. + [output1: Tensor] the function returns a boolean_mask or index array of corresponding elements in the input tensor that pass the filter. + [output2: float, Tensor] the function returns the threshold that will be used to filter ids before update/eviction. all values <= this value will be filtered out. + + """ + self._metadata_info = metadata_info + self._threshold_filtering_func = threshold_filtering_func + + @property + @abc.abstractmethod + def metadata_info(self) -> List[MCHEvictionPolicyMetadataInfo]: + pass + + @abc.abstractmethod + def record_history_metadata( + self, + current_iter: int, + incoming_ids: torch.Tensor, + history_metadata: Dict[str, torch.Tensor], + ) -> None: + """ + Args: + current_iter (int): current iteration + incoming_ids (torch.Tensor): incoming ids + history_metadata (Dict[str, torch.Tensor]): history metadata dict + + Compute and record metadata based on incoming ids + for the implemented eviction policy. + """ + pass + + @abc.abstractmethod + def coalesce_history_metadata( + self, + current_iter: int, + history_metadata: Dict[str, torch.Tensor], + unique_ids_counts: torch.Tensor, + unique_inverse_mapping: torch.Tensor, + additional_ids: Optional[torch.Tensor] = None, + threshold_mask: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + """ + Args: + history_metadata (Dict[str, torch.Tensor]): history metadata dict + additional_ids (torch.Tensor): additional ids to be used as part of history + unique_inverse_mapping (torch.Tensor): torch.unique inverse mapping generated from + torch.cat[history_accumulator, additional_ids]. used to map history metadata tensor + indices to their coalesced tensor indices. + + Coalesce metadata history buffers and return dict of processed metadata tensors. + """ + pass + + @abc.abstractmethod + def update_metadata_and_generate_eviction_scores( + self, + current_iter: int, + mch_size: int, + coalesced_history_argsort_mapping: torch.Tensor, + coalesced_history_sorted_unique_ids_counts: torch.Tensor, + coalesced_history_mch_matching_elements_mask: torch.Tensor, + coalesced_history_mch_matching_indices: torch.Tensor, + mch_metadata: Dict[str, torch.Tensor], + coalesced_history_metadata: Dict[str, torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + + + Returns Tuple of (evicted_indices, selected_new_indices) where: + evicted_indices are indices in the mch map to be evicted, and + selected_new_indices are the indices of the ids in the coalesced + history that are to be added to the mch. 
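A hedged numeric illustration of the eviction/selection bookkeeping described above and implemented by _compute_selected_eviction_and_replacement_indices() below; `select` is an illustrative reimplementation of the same argsort logic:

import torch

def select(pivot: int, scores: torch.Tensor):
    # First `pivot` entries are existing slots, the rest are incoming candidates.
    order = torch.argsort(scores, descending=True, stable=True)
    selected_new = order[:pivot][order[:pivot] >= pivot] - pivot
    evicted = order[pivot:][order[pivot:] < pivot]
    return evicted, selected_new

# 3 existing slots with scores [5, 1, 9] and 2 candidates with scores [7, 0]:
# candidate 0 (score 7) bumps slot 1 (score 1) out of the top 3.
evicted, selected = select(3, torch.tensor([5.0, 1.0, 9.0, 7.0, 0.0]))
print(evicted, selected)  # tensor([1]) tensor([0])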
+ """ + pass + + def _compute_selected_eviction_and_replacement_indices( + self, + pivot: int, + eviction_scores: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # NOTE these are like indices + argsorted_eviction_scores = torch.argsort( + eviction_scores, descending=True, stable=True + ) + + # indices with values >= zch_size in the top zch_size scores correspond + # to new incoming ids to be added to zch + selected_new_ids_mask = argsorted_eviction_scores[:pivot] >= pivot + # indices with values < zch_size outside the top zch_size scores correspond + # to existing zch ids to be evicted + evicted_ids_mask = argsorted_eviction_scores[pivot:] < pivot + evicted_indices = argsorted_eviction_scores[pivot:][evicted_ids_mask] + selected_new_indices = ( + argsorted_eviction_scores[:pivot][selected_new_ids_mask] - pivot + ) + + return evicted_indices, selected_new_indices + + +class LFU_EvictionPolicy(MCHEvictionPolicy): + def __init__( + self, + threshold_filtering_func: Optional[ + Callable[[torch.Tensor], Tuple[torch.Tensor, Union[float, torch.Tensor]]] + ] = None, # experimental + ) -> None: + super().__init__( + metadata_info=[ + MCHEvictionPolicyMetadataInfo( + metadata_name="counts", + is_mch_metadata=True, + is_history_metadata=False, + ), + ], + threshold_filtering_func=threshold_filtering_func, + ) + + @property + def metadata_info(self) -> List[MCHEvictionPolicyMetadataInfo]: + return self._metadata_info + + def record_history_metadata( + self, + current_iter: int, + incoming_ids: torch.Tensor, + history_metadata: Dict[str, torch.Tensor], + ) -> None: + # no-op; no history buffers + pass + + def coalesce_history_metadata( + self, + current_iter: int, + history_metadata: Dict[str, torch.Tensor], + unique_ids_counts: torch.Tensor, + unique_inverse_mapping: torch.Tensor, + additional_ids: Optional[torch.Tensor] = None, + threshold_mask: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + # no-op; no history buffers + return {} + + def update_metadata_and_generate_eviction_scores( + self, + current_iter: int, + mch_size: int, + coalesced_history_argsort_mapping: torch.Tensor, + coalesced_history_sorted_unique_ids_counts: torch.Tensor, + coalesced_history_mch_matching_elements_mask: torch.Tensor, + coalesced_history_mch_matching_indices: torch.Tensor, + mch_metadata: Dict[str, torch.Tensor], + coalesced_history_metadata: Dict[str, torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + mch_counts = mch_metadata["counts"] + # update metadata for matching ids + mch_counts[ + coalesced_history_mch_matching_indices + ] += coalesced_history_sorted_unique_ids_counts[ + coalesced_history_mch_matching_elements_mask + ] + + # incoming non-matching ids + new_sorted_uniq_ids_counts = coalesced_history_sorted_unique_ids_counts[ + ~coalesced_history_mch_matching_elements_mask + ] + + # TODO: find cleaner way to avoid last element of zch + + mch_counts[mch_size - 1] = torch.iinfo(torch.int64).max + + merged_counts = torch.cat( + [ + mch_counts, + new_sorted_uniq_ids_counts, + ] + ) + # calculate evicted and replacement indices + ( + evicted_indices, + selected_new_indices, + ) = self._compute_selected_eviction_and_replacement_indices( + mch_size, + merged_counts, + ) + + # update metadata for evicted ids + mch_counts[evicted_indices] = new_sorted_uniq_ids_counts[selected_new_indices] + + return evicted_indices, selected_new_indices + + +class LRU_EvictionPolicy(MCHEvictionPolicy): + def __init__( + self, + decay_exponent: float = 1.0, + threshold_filtering_func: Optional[ + 
Callable[[torch.Tensor], Tuple[torch.Tensor, Union[float, torch.Tensor]]] + ] = None, # experimental + ) -> None: + super().__init__( + metadata_info=[ + MCHEvictionPolicyMetadataInfo( + metadata_name="last_access_iter", + is_mch_metadata=True, + is_history_metadata=True, + ), + ], + threshold_filtering_func=threshold_filtering_func, + ) + self._decay_exponent = decay_exponent + + @property + def metadata_info(self) -> List[MCHEvictionPolicyMetadataInfo]: + return self._metadata_info + + def record_history_metadata( + self, + current_iter: int, + incoming_ids: torch.Tensor, + history_metadata: Dict[str, torch.Tensor], + ) -> None: + history_last_access_iter = history_metadata["last_access_iter"] + history_last_access_iter[:] = current_iter + + def coalesce_history_metadata( + self, + current_iter: int, + history_metadata: Dict[str, torch.Tensor], + unique_ids_counts: torch.Tensor, + unique_inverse_mapping: torch.Tensor, + additional_ids: Optional[torch.Tensor] = None, + threshold_mask: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + + coalesced_history_metadata: Dict[str, torch.Tensor] = {} + history_last_access_iter = history_metadata["last_access_iter"] + if additional_ids is not None: + history_last_access_iter = torch.cat( + [ + history_last_access_iter, + torch.full_like(additional_ids, current_iter), + ] + ) + coalesced_history_metadata["last_access_iter"] = torch.zeros_like( + unique_ids_counts + ).scatter_reduce_( + 0, + unique_inverse_mapping, + history_last_access_iter, + reduce="amax", + include_self=False, + ) + if threshold_mask is not None: + coalesced_history_metadata["last_access_iter"] = coalesced_history_metadata[ + "last_access_iter" + ][threshold_mask] + return coalesced_history_metadata + + def update_metadata_and_generate_eviction_scores( + self, + current_iter: int, + mch_size: int, + coalesced_history_argsort_mapping: torch.Tensor, + coalesced_history_sorted_unique_ids_counts: torch.Tensor, + coalesced_history_mch_matching_elements_mask: torch.Tensor, + coalesced_history_mch_matching_indices: torch.Tensor, + mch_metadata: Dict[str, torch.Tensor], + coalesced_history_metadata: Dict[str, torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + mch_last_access_iter = mch_metadata["last_access_iter"] + + # sort coalesced history metadata + coalesced_history_metadata["last_access_iter"].copy_( + coalesced_history_metadata["last_access_iter"][ + coalesced_history_argsort_mapping + ] + ) + coalesced_history_sorted_uniq_ids_last_access_iter = coalesced_history_metadata[ + "last_access_iter" + ] + + # update metadata for matching ids + mch_last_access_iter[coalesced_history_mch_matching_indices] = ( + coalesced_history_sorted_uniq_ids_last_access_iter[ + coalesced_history_mch_matching_elements_mask + ] + ) + + # incoming non-matching ids + new_sorted_uniq_ids_last_access = ( + coalesced_history_sorted_uniq_ids_last_access_iter[ + ~coalesced_history_mch_matching_elements_mask + ] + ) + + # TODO: find cleaner way to avoid last element of zch + mch_last_access_iter[mch_size - 1] = current_iter + merged_access_iter = torch.cat( + [ + mch_last_access_iter, + new_sorted_uniq_ids_last_access, + ] + ) + # lower scores are evicted first. 
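+        # Editor's illustrative sketch (hypothetical numbers, not part of the
+        # original change): with current_iter=10, decay_exponent=1.0 and
+        # merged_access_iter=[10, 7, 3], the expression below evaluates to
+        # scores [-1, -4, -8], so the id last touched at iteration 3 becomes
+        # the strongest eviction candidate.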
+ merged_eviction_scores = torch.neg( + torch.pow( + current_iter - merged_access_iter + 1, + self._decay_exponent, + ) + ) + + # calculate evicted and replacement indices + ( + evicted_indices, + selected_new_indices, + ) = self._compute_selected_eviction_and_replacement_indices( + mch_size, + merged_eviction_scores, + ) + + mch_last_access_iter[evicted_indices] = new_sorted_uniq_ids_last_access[ + selected_new_indices + ] + + return evicted_indices, selected_new_indices + + +class DistanceLFU_EvictionPolicy(MCHEvictionPolicy): + def __init__( + self, + decay_exponent: float = 1.0, + threshold_filtering_func: Optional[ + Callable[[torch.Tensor], Tuple[torch.Tensor, Union[float, torch.Tensor]]] + ] = None, # experimental + ) -> None: + super().__init__( + metadata_info=[ + MCHEvictionPolicyMetadataInfo( + metadata_name="counts", + is_mch_metadata=True, + is_history_metadata=False, + ), + MCHEvictionPolicyMetadataInfo( + metadata_name="last_access_iter", + is_mch_metadata=True, + is_history_metadata=True, + ), + ], + threshold_filtering_func=threshold_filtering_func, + ) + self._decay_exponent = decay_exponent + + @property + def metadata_info(self) -> List[MCHEvictionPolicyMetadataInfo]: + return self._metadata_info + + def record_history_metadata( + self, + current_iter: int, + incoming_ids: torch.Tensor, + history_metadata: Dict[str, torch.Tensor], + ) -> None: + history_last_access_iter = history_metadata["last_access_iter"] + history_last_access_iter[:] = current_iter + + def coalesce_history_metadata( + self, + current_iter: int, + history_metadata: Dict[str, torch.Tensor], + unique_ids_counts: torch.Tensor, + unique_inverse_mapping: torch.Tensor, + additional_ids: Optional[torch.Tensor] = None, + threshold_mask: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + coalesced_history_metadata: Dict[str, torch.Tensor] = {} + history_last_access_iter = history_metadata["last_access_iter"] + if additional_ids is not None: + history_last_access_iter = torch.cat( + [ + history_last_access_iter, + torch.full_like(additional_ids, current_iter), + ] + ) + coalesced_history_metadata["last_access_iter"] = torch.zeros_like( + unique_ids_counts + ).scatter_reduce_( + 0, + unique_inverse_mapping, + history_last_access_iter, + reduce="amax", + include_self=False, + ) + if threshold_mask is not None: + coalesced_history_metadata["last_access_iter"] = coalesced_history_metadata[ + "last_access_iter" + ][threshold_mask] + return coalesced_history_metadata + + def update_metadata_and_generate_eviction_scores( + self, + current_iter: int, + mch_size: int, + coalesced_history_argsort_mapping: torch.Tensor, + coalesced_history_sorted_unique_ids_counts: torch.Tensor, + coalesced_history_mch_matching_elements_mask: torch.Tensor, + coalesced_history_mch_matching_indices: torch.Tensor, + mch_metadata: Dict[str, torch.Tensor], + coalesced_history_metadata: Dict[str, torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + mch_counts = mch_metadata["counts"] + mch_last_access_iter = mch_metadata["last_access_iter"] + + # sort coalesced history metadata + coalesced_history_metadata["last_access_iter"].copy_( + coalesced_history_metadata["last_access_iter"][ + coalesced_history_argsort_mapping + ] + ) + coalesced_history_sorted_uniq_ids_last_access_iter = coalesced_history_metadata[ + "last_access_iter" + ] + + # update metadata for matching ids + mch_counts[ + coalesced_history_mch_matching_indices + ] += coalesced_history_sorted_unique_ids_counts[ + coalesced_history_mch_matching_elements_mask + ] 
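+        # Editor's note (illustrative, hypothetical numbers): the DistanceLFU
+        # score computed further down in this method is
+        # count / (current_iter - last_access_iter + 1) ** decay_exponent,
+        # e.g. an id with count=6 last seen 2 iterations ago scores 6 / 3 = 2.0
+        # with decay_exponent=1.0, so ids that are both frequent and recent
+        # are retained.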
+ mch_last_access_iter[coalesced_history_mch_matching_indices] = ( + coalesced_history_sorted_uniq_ids_last_access_iter[ + coalesced_history_mch_matching_elements_mask + ] + ) + + # incoming non-matching ids + new_sorted_uniq_ids_counts = coalesced_history_sorted_unique_ids_counts[ + ~coalesced_history_mch_matching_elements_mask + ] + new_sorted_uniq_ids_last_access = ( + coalesced_history_sorted_uniq_ids_last_access_iter[ + ~coalesced_history_mch_matching_elements_mask + ] + ) + + # TODO: find cleaner way to avoid last element of zch + mch_counts[mch_size - 1] = torch.iinfo(torch.int64).max + mch_last_access_iter[mch_size - 1] = current_iter + + merged_counts = torch.cat( + [ + mch_counts, + new_sorted_uniq_ids_counts, + ] + ) + merged_access_iter = torch.cat( + [ + mch_last_access_iter, + new_sorted_uniq_ids_last_access, + ] + ) + merged_weighted_distance = torch.pow( + current_iter - merged_access_iter + 1, + self._decay_exponent, + ) + # merged eviction scores are the eviction scores calculated for the + # tensor torch.cat[_mch_sorted_raw_ids, frequency_sorted_uniq_ids[~matching_eles]] + # lower scores are evicted first. + merged_eviction_scores = torch.div(merged_counts, merged_weighted_distance) + + # calculate evicted and replacement indices + ( + evicted_indices, + selected_new_indices, + ) = self._compute_selected_eviction_and_replacement_indices( + mch_size, + merged_eviction_scores, + ) + + # update metadata for evicted ids + mch_counts[evicted_indices] = new_sorted_uniq_ids_counts[selected_new_indices] + + mch_last_access_iter[evicted_indices] = new_sorted_uniq_ids_last_access[ + selected_new_indices + ] + + return evicted_indices, selected_new_indices + + +@torch.fx.wrap +def _mch_remap( + features: Dict[str, JaggedTensor], + mch_sorted_raw_ids: torch.Tensor, + mch_remapped_ids_mapping: torch.Tensor, + zch_index: int, +) -> Dict[str, JaggedTensor]: + """Remap feature ids to zch ids, TODO: create a custom kernel""" + remapped_features: Dict[str, JaggedTensor] = {} + for name, feature in features.items(): + values = feature.values() + remapped_ids = torch.empty_like(values) + + # compute overlap between incoming IDs and remapping table + searched_indices = torch.searchsorted(mch_sorted_raw_ids[:-1], values) + retrieved_indices = mch_sorted_raw_ids[searched_indices] + # identify matching inputs IDs + matching_indices = retrieved_indices == values + # update output with remapped matching IDs + remapped_ids[matching_indices] = mch_remapped_ids_mapping[ + searched_indices[matching_indices] + ] + # default embedding for non-matching ids + remapped_ids[~matching_indices] = zch_index + + remapped_features[name] = JaggedTensor( + values=remapped_ids, + lengths=feature.lengths(), + offsets=feature.offsets(), + weights=feature.weights_or_none(), + ) + return remapped_features + + +class MCHManagedCollisionModule(ManagedCollisionModule): + """ + ZCH managed collision module + + Args: + zch_size (int): range of output ids, within [output_size_offset, output_size_offset + zch_size - 1) + device (torch.device): device on which this module will be executed + eviction_policy (eviction policy): eviction policy to be used + eviction_interval (int): interval of eviction policy is triggered + input_hash_size (int): input feature id range, will be passed to input_hash_func as second arg + input_hash_func (Optional[Callable]): function used to generate hashes for input features. 
This function is typically used to drive uniform distribution over range same or greater than input data + mch_size (Optional[int]): DEPRECIATED - size of residual output (ie. legacy MCH), experimental feature. Ids are internally shifted by output_size_offset + zch_output_range + mch_hash_func (Optional[Callable]): DEPRECIATED - function used to generate hashes for residual feature. will hash down to mch_size. + output_global_offset (int): offset of the output id for output range, typically only used in sharding applications. + """ + + def __init__( + self, + zch_size: int, + device: torch.device, + eviction_policy: MCHEvictionPolicy, + eviction_interval: int, + input_hash_size: int = (2**63) - 1, + input_hash_func: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = None, + mch_size: Optional[int] = None, + mch_hash_func: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = None, + name: Optional[str] = None, + output_global_offset: int = 0, # typically not provided by user + output_segments: Optional[List[int]] = None, # typically not provided by user + buckets: int = 1, + ) -> None: + if output_segments is None: + output_segments = [output_global_offset, output_global_offset + zch_size] + super().__init__( + device=device, + output_segments=output_segments, + ) + if mch_size is not None or mch_hash_func is not None: + logger.warning( + "co-locating a hash table for missing ids is depreciated (ie. mch_size, mch_hash_func), values will be ignored" + ) + self._init_output_segments_tensor: torch.Tensor = self._output_segments_tensor + self._name = name + self._input_history_buffer_size: int = -1 + self._input_hash_size = input_hash_size + self._zch_size: int = zch_size + assert self._zch_size > 0, "zch_size must be > 0" + self._output_global_offset: int = output_global_offset + self._input_hash_func = input_hash_func + + self._eviction_interval = eviction_interval + assert self._eviction_interval > 0, "eviction_interval must be > 1" + self._eviction_policy = eviction_policy + + self._current_iter: int = -1 + self._buckets = buckets + self._init_buffers() + + ## ------ history info ------ + self._mch_metadata: Dict[str, torch.Tensor] = {} + self._history_metadata: Dict[str, torch.Tensor] = {} + self._init_metadata_buffers() + self._current_history_buffer_offset: int = 0 + + self._evicted: bool = False + self._last_eviction_iter: int = -1 + + def _init_buffers(self) -> None: + self.register_buffer( + "_mch_sorted_raw_ids", + torch.full( + (self._zch_size,), + torch.iinfo(torch.int64).max, + dtype=torch.int64, + device=self.device, + ), + ) + self.register_buffer( + "_mch_slots", + torch.tensor( + [(self._zch_size - 1)], + dtype=torch.int64, + device=self.device, + ), + persistent=False, + ) + self.register_buffer( + "_delimiter", + torch.tensor( + [torch.iinfo(torch.int64).max], dtype=torch.int64, device=self.device + ), + persistent=False, + ) + self.register_buffer( + "_mch_remapped_ids_mapping", + torch.arange( + start=self._output_global_offset, + end=self._output_global_offset + self._zch_size, + dtype=torch.int64, + device=self.device, + ), + ) + + self._evicted_emb_indices: torch.Tensor = torch.empty((1,), device=self.device) + + def _init_metadata_buffers(self) -> None: + eviction_metadata_info = self._eviction_policy.metadata_info + for metadata in eviction_metadata_info: + metadata_name, is_mch_metadata, is_history_metadata = metadata + # mch_metadata + if is_mch_metadata: + buffer_name = "_mch_" + metadata_name + self.register_buffer( + buffer_name, + torch.zeros( + 
(self._zch_size,), + dtype=torch.int64, + device=self.device, + ), + ) + self._mch_metadata[metadata_name] = getattr(self, buffer_name) + + def _init_history_buffers(self, features: Dict[str, JaggedTensor]) -> None: + input_batch_value_size_cumsum = 0 + for _, feature in features.items(): + input_batch_value_size_cumsum += feature.values().numel() + self._input_history_buffer_size = int( + input_batch_value_size_cumsum * self._eviction_interval * 1.25 + ) + # pyre-fixme[16]: `MCHManagedCollisionModule` has no attribute + # `_history_accumulator`. + self._history_accumulator: torch.Tensor = torch.empty( + self._input_history_buffer_size, + dtype=torch.int64, + device=self.device, + ) + eviction_metadata_info = self._eviction_policy.metadata_info + for metadata in eviction_metadata_info: + metadata_name, is_mch_metadata, is_history_metadata = metadata + # history_metadata + if is_history_metadata: + buffer_name = "_history_" + metadata_name + self.register_buffer( + buffer_name, + torch.zeros( + self._input_history_buffer_size, + dtype=torch.int64, + device=self.device, + ), + persistent=False, + ) + self._history_metadata[metadata_name] = getattr(self, buffer_name) + + def preprocess(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]: + if self._input_hash_func is None: + return features + preprocessed_features: Dict[str, JaggedTensor] = {} + for name, feature in features.items(): + preprocessed_features[name] = JaggedTensor( + # pyre-ignore [29] + values=self._input_hash_func(feature.values(), self._input_hash_size), + lengths=feature.lengths(), + offsets=feature.offsets(), + weights=feature.weights_or_none(), + ) + return preprocessed_features + + def reset_inference_mode(self) -> None: + self._evicted = False + self._last_eviction_iter = -1 + + @torch.no_grad() + def _match_indices( + self, sorted_sequence: torch.Tensor, search_values: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + searched_indices = torch.searchsorted(sorted_sequence[:-1], search_values) + retrieved_ids = sorted_sequence[searched_indices] + matching_eles = retrieved_ids == search_values + matched_indices = searched_indices[matching_eles] + return (matching_eles, matched_indices) + + @torch.no_grad() + def _sort_mch_buffers(self) -> None: + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Module, + # Tensor]`. + argsorted_sorted_raw_ids = torch.argsort(self._mch_sorted_raw_ids, stable=True) + # pyre-fixme[29]: `Union[(self: TensorBase, src: Tensor, non_blocking: bool + # = ...) -> Tensor, Module, Tensor]` is not a function. + self._mch_sorted_raw_ids.copy_( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS... + self._mch_sorted_raw_ids[argsorted_sorted_raw_ids] + ) + # pyre-fixme[29]: `Union[(self: TensorBase, src: Tensor, non_blocking: bool + # = ...) -> Tensor, Module, Tensor]` is not a function. + self._mch_remapped_ids_mapping.copy_( + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS... 
+ self._mch_remapped_ids_mapping[argsorted_sorted_raw_ids] + ) + for mch_metadata_buffer in self._mch_metadata.values(): + mch_metadata_buffer.copy_(mch_metadata_buffer[argsorted_sorted_raw_ids]) + + @torch.no_grad() + def _update_and_evict( + self, + uniq_ids: torch.Tensor, + uniq_ids_counts: torch.Tensor, + uniq_ids_metadata: Dict[str, torch.Tensor], + ) -> None: + argsorted_uniq_ids_counts = torch.argsort( + uniq_ids_counts, descending=True, stable=True + ) + frequency_sorted_uniq_ids = uniq_ids[argsorted_uniq_ids_counts] + frequency_sorted_uniq_ids_counts = uniq_ids_counts[argsorted_uniq_ids_counts] + + matching_eles, matched_indices = self._match_indices( + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Union[Module, Tensor]`. + self._mch_sorted_raw_ids, + frequency_sorted_uniq_ids, + ) + + new_frequency_sorted_uniq_ids = frequency_sorted_uniq_ids[~matching_eles] + + # evicted_indices are indices in the mch map to be evicted, and + # selected_new_indices are the indices of the ids in the coalesced + # history that are to be added to the mch. + ( + evicted_indices, + selected_new_indices, + ) = self._eviction_policy.update_metadata_and_generate_eviction_scores( + self._current_iter, + self._zch_size, + argsorted_uniq_ids_counts, + frequency_sorted_uniq_ids_counts, + matching_eles, + matched_indices, + self._mch_metadata, + uniq_ids_metadata, + ) + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... + self._mch_sorted_raw_ids[evicted_indices] = new_frequency_sorted_uniq_ids[ + selected_new_indices + ] + + # NOTE evicted ids for emb reset + # if evicted flag is already set, then existing evicted ids havent been + # consumed by evict(). append new evicted ids to the list + if self._evicted: + self._evicted_emb_indices = torch.unique( + torch.cat( + [ + self._evicted_emb_indices, + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[No... + self._mch_remapped_ids_mapping[evicted_indices], + ] + ) + ) + else: + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS... + self._evicted_emb_indices = self._mch_remapped_ids_mapping[evicted_indices] + self._evicted = True + + # re-sort for next search + self._sort_mch_buffers() + + @torch.no_grad() + def _coalesce_history(self) -> None: + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... 
+ current_history_accumulator = self._history_accumulator[ + : self._current_history_buffer_offset + ] + uniq_ids, uniq_inverse_mapping, uniq_ids_counts = torch.unique( + current_history_accumulator, + return_inverse=True, + return_counts=True, + ) + if self._eviction_policy._threshold_filtering_func is not None: + threshold_mask, threshold = self._eviction_policy._threshold_filtering_func( + uniq_ids_counts + ) + else: + threshold_mask = None + + coalesced_eviction_history_metadata = ( + self._eviction_policy.coalesce_history_metadata( + self._current_iter, + { + metadata_name: metadata_buffer[ + : self._current_history_buffer_offset + ] + for metadata_name, metadata_buffer in self._history_metadata.items() + }, + uniq_ids_counts, + uniq_inverse_mapping, + threshold_mask=threshold_mask, + ) + ) + if threshold_mask is not None: + uniq_ids = uniq_ids[threshold_mask] + uniq_ids_counts = uniq_ids_counts[threshold_mask] + self._update_and_evict( + uniq_ids, uniq_ids_counts, coalesced_eviction_history_metadata + ) + # reset buffer offset + self._current_history_buffer_offset = 0 + + def profile( + self, + features: Dict[str, JaggedTensor], + ) -> Dict[str, JaggedTensor]: + if not self.training: + return features + + if self._current_iter == -1: + self._current_iter = int(self._current_iter_tensor.item()) + self._last_eviction_iter = self._current_iter + self._current_iter += 1 + self._current_iter_tensor.data += 1 + + # init history buffers if needed + if self._input_history_buffer_size == -1: + self._init_history_buffers(features) + + for _, feature in features.items(): + values = feature.values() + free_elements = ( + self._input_history_buffer_size - self._current_history_buffer_offset + ) + values = values[:free_elements] + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS... + self._history_accumulator[ + self._current_history_buffer_offset : self._current_history_buffer_offset + + values.shape[0] + ] = values + self._eviction_policy.record_history_metadata( + self._current_iter, + values, + { + metadata_name: metadata_buffer[ + self._current_history_buffer_offset : self._current_history_buffer_offset + + values.shape[0] + ] + for metadata_name, metadata_buffer in self._history_metadata.items() + }, + ) + self._current_history_buffer_offset += values.shape[0] + + # coalesce history / evict + if self._current_iter - self._last_eviction_iter == self._eviction_interval: + self._coalesce_history() + self._last_eviction_iter = self._current_iter + + return features + + def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]: + return _mch_remap( + features, + self._mch_sorted_raw_ids, + self._mch_remapped_ids_mapping, + self._output_global_offset + self._zch_size - 1, + ) + + def forward( + self, + features: Dict[str, JaggedTensor], + ) -> Dict[str, JaggedTensor]: + """ + Args: + feature (JaggedTensor]): feature representation + Returns: + Dict[str, JaggedTensor]: modified JT + """ + + features = self.preprocess(features) + features = self.profile(features) + return self.remap(features) + + def output_size(self) -> int: + return self._zch_size + + def buckets(self) -> int: + return self._buckets + + def input_size(self) -> int: + return self._input_hash_size + + def open_slots(self) -> torch.Tensor: + # pyre-fixme[29]: `Union[(self: TensorBase, other: Any) -> Tensor, Module, + # Tensor]` is not a function. + return self._mch_slots - torch.searchsorted( + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Union[Module, Tensor]`. 
+ # pyre-fixme[6]: For 2nd argument expected `Tensor` but got + # `Union[Module, Tensor]`. + self._mch_sorted_raw_ids, + # pyre-fixme[6]: For 2nd argument expected `Tensor` but got + # `Union[Module, Tensor]`. + self._delimiter, + ) + + @torch.no_grad() + def evict(self) -> Optional[torch.Tensor]: + if self._evicted: + self._evicted = False + return self._evicted_emb_indices + else: + return None + + def validate_state(self) -> None: + start = self._output_global_offset + end = start + self._zch_size + assert ( + start in self._output_segments_tensor + and end in self._output_segments_tensor + ), f"shard within range [{start}, {end}] cannot be built out of segements {self._output_segments_tensor}" + + # update output segments and resort + self._output_segments_tensor = self._init_output_segments_tensor + self._sort_mch_buffers() + + def rebuild_with_output_id_range( + self, + output_id_range: Tuple[int, int], + output_segments: List[int], + device: Optional[torch.device] = None, + ) -> "MCHManagedCollisionModule": + + new_zch_size = output_id_range[1] - output_id_range[0] + + return type(self)( + name=self._name, + zch_size=new_zch_size, + device=device or self.device, + eviction_policy=self._eviction_policy, + eviction_interval=self._eviction_interval, + input_hash_size=self._input_hash_size, + input_hash_func=self._input_hash_func, + output_global_offset=output_id_range[0], + output_segments=output_segments, + buckets=len(output_segments) - 1, + ) diff --git a/torchrec/modules/mlp.py b/torchrec/modules/mlp.py index c369b24c3..c9a672bb2 100644 --- a/torchrec/modules/mlp.py +++ b/torchrec/modules/mlp.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Callable, List, Optional, Union import torch @@ -50,13 +52,18 @@ def __init__( Callable[[torch.Tensor], torch.Tensor], ] = torch.relu, device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}") self._out_size = out_size self._in_size = in_size self._linear: nn.Linear = nn.Linear( - self._in_size, self._out_size, bias=bias, device=device + self._in_size, + self._out_size, + bias=bias, + device=device, + dtype=dtype, ) self._activation_fn: Callable[[torch.Tensor], torch.Tensor] = activation @@ -120,6 +127,7 @@ def __init__( Callable[[torch.Tensor], torch.Tensor], ] = torch.relu, device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() @@ -137,6 +145,7 @@ def __init__( bias=bias, activation=extract_module_or_tensor_callable(activation), device=device, + dtype=dtype, ) for i in range(len(layer_sizes)) ] diff --git a/torchrec/modules/object_pool.py b/torchrec/modules/object_pool.py new file mode 100644 index 000000000..00963f35d --- /dev/null +++ b/torchrec/modules/object_pool.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
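+# Editor's aside: an illustrative usage sketch (kept as a comment) for the
+# MCHManagedCollisionModule introduced earlier in this diff. The import path
+# and all values below are assumptions for illustration only.
+#
+#   import torch
+#   from torchrec.modules.mc_modules import (
+#       LFU_EvictionPolicy,
+#       MCHManagedCollisionModule,
+#   )
+#
+#   # 1000-slot ZCH table; eviction runs every 2 training iterations
+#   mc_module = MCHManagedCollisionModule(
+#       zch_size=1000,
+#       device=torch.device("cpu"),
+#       eviction_policy=LFU_EvictionPolicy(),
+#       eviction_interval=2,
+#   )
+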
+ +# pyre-strict + +import abc +from typing import Generic, TypeVar + +import torch + +T = TypeVar("T") + + +class ObjectPool(abc.ABC, torch.nn.Module, Generic[T]): + """ + Interface for TensorPool and KeyedJaggedTensorPool + + Defines methods for lookup, update and obtaining pool size + """ + + @abc.abstractmethod + def lookup(self, ids: torch.Tensor) -> T: + pass + + @abc.abstractmethod + def update(self, ids: torch.Tensor, values: T) -> None: + pass + + @abc.abstractproperty + def pool_size(self) -> int: + pass diff --git a/torchrec/modules/object_pool_lookups.py b/torchrec/modules/object_pool_lookups.py new file mode 100644 index 000000000..b30358f19 --- /dev/null +++ b/torchrec/modules/object_pool_lookups.py @@ -0,0 +1,762 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import abc +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import torch + +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + ComputeDevice, + EmbeddingLocation, + PoolingMode, + SparseType, + SplitTableBatchedEmbeddingBagsCodegen, +) +from torch.autograd.profiler import record_function +from torchrec.modules.utils import jagged_index_select_with_empty +from torchrec.sparse.jagged_tensor import JaggedTensor + + +torch.fx.wrap("jagged_index_select_with_empty") + + +class KeyedJaggedTensorPoolLookup(abc.ABC, torch.nn.Module): + """ + Abstract base class for KeyedJaggedTensor pool lookups + + Implementations of this class should define methods for + - lookup using ids + - update values associated with ids + - returning states that should be saved + and loaded in state_dict() + + Args: + pool_size (int): size of the pool + feature_max_lengths (Dict[str,int]): Dict mapping feature name to max length that + its values can have for any given batch. The underlying storage representation + for the KJT pool is currently a padded 2D tensor, so this information is + needed. + is_weighted (bool): Boolean indicating whether or not the KJTs will have weights + that need to be stored separately. + device (torch.device): device that KJTs should be placed on + + Example: + Other classes should inherit from this class and implement the + abstract methods. 
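+
+    Illustrative sketch (hypothetical values, added for clarity): with
+    feature_max_lengths={"f1": 2, "f2": 3}, each pool row reserves
+    2 + 3 = 5 padded value slots, while _key_lengths records the actual
+    per-feature lengths so lookups can strip the padding again.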
+ """ + + _pool_size: int + _feature_max_lengths: Dict[str, int] + _is_weighted: bool + _total_lengths: int + _total_lengths_t: torch.Tensor + _key_lengths: torch.Tensor + _jagged_lengths: torch.Tensor + _jagged_offsets: torch.Tensor + _device: torch.device + + def __init__( + self, + pool_size: int, + feature_max_lengths: Dict[str, int], + is_weighted: bool, + device: torch.device, + ) -> None: + super().__init__() + self._pool_size = pool_size + self._feature_max_lengths = feature_max_lengths + self._device = device + self._total_lengths = sum(self._feature_max_lengths.values()) + self._total_lengths_t = torch.tensor( + [self._total_lengths], device=device, dtype=torch.int32 + ) + self._is_weighted = is_weighted + + self._key_lengths = torch.zeros( + (self._pool_size, len(self._feature_max_lengths)), + dtype=torch.int32, + device=self._device, + ) + + lengths, offsets = self._infer_jagged_lengths_inclusive_offsets() + self._jagged_lengths = lengths + self._jagged_offsets = offsets + + def _infer_jagged_lengths_inclusive_offsets( + self, + ) -> Tuple[torch.Tensor, torch.Tensor]: + lengths_sum = self._key_lengths.sum(dim=1) + padded_lengths = self._total_lengths_t - lengths_sum + jagged_lengths = torch.stack([lengths_sum, padded_lengths], dim=1).flatten() + return ( + jagged_lengths, + torch.ops.fbgemm.asynchronous_inclusive_cumsum(jagged_lengths), + ) + + def _load_from_state_dict( + self, + state_dict: Dict[str, torch.Tensor], + prefix: str, + local_metadata: Dict[str, Any], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + """ + Override _load_from_state_dict in torch.nn.Module. + """ + torch.nn.Module._load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + lengths, offsets = self._infer_jagged_lengths_inclusive_offsets() + self._jagged_lengths = lengths + self._jagged_offsets = offsets + + @abc.abstractmethod + def lookup(self, ids: torch.Tensor) -> JaggedTensor: + pass + + @abc.abstractmethod + def update(self, ids: torch.Tensor, values: JaggedTensor) -> None: + # assume that at this point there are no duplicate ids, and all preproc is done by KJTPool + pass + + def forward(self, ids: torch.Tensor) -> JaggedTensor: + """ + Forward performs a lookup using the given ids + + Args: + ids (torch.Tensor): Tensor of IDs to lookup + + Returns: + JaggedTensor: JaggedTensor containing the merged + values, lengths and weights associated with the ids + for all the features of the KJT pool. 
+ """ + return self.lookup(ids) + + @abc.abstractmethod + def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]: + pass + + +class TensorJaggedIndexSelectLookup(KeyedJaggedTensorPoolLookup): + _values_dtype: torch.dtype + _values: torch.Tensor + _weights: torch.Tensor + + def __init__( + self, + pool_size: int, + values_dtype: torch.dtype, + feature_max_lengths: Dict[str, int], + is_weighted: bool, + device: torch.device, + ) -> None: + super().__init__( + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + device=device, + is_weighted=is_weighted, + ) + + self._values_dtype = values_dtype + + self._values = torch.zeros( + (self._pool_size, self._total_lengths), + dtype=self._values_dtype, + device=self._device, + ) + + if self._is_weighted: + self._weights = torch.zeros( + (self._pool_size, self._total_lengths), + dtype=torch.float, + device=self._device, + ) + else: + # to appease torchscript + self._weights = torch.empty((0,), dtype=torch.float, device=self._device) + + def lookup(self, ids: torch.Tensor) -> JaggedTensor: + """ + Example: + memory layout is + values = [ + [1], [2, 2], 0, 0, 0,0 + [11,11],[12,12,12],0,0 + ] + lengths = [ + 1,2 + 2,3 + ] + + We can consider this as a jagged tensor with + [ + [1,2,2], [0,0,0,0], + [11,11,12,12,12],[0,0] + ] + where we can combine all values together, and all padded values together. + The index to select into is then 2*ids (doubled because of padding index). + + jagged_index_select will let us retrieve + [1,2,2,11,11,12,12,12], that we can then massage into + [ + [1], [2,2] + [11,11] [12,12,12] + ] + + Later (not in this method), we turn this into appropriate KJT format, + using jagged index select to transpose into + [ + [1] [11, 11] + [2,2] [12,12,12] + ] + """ + + with record_function("## KJTPool Lookup ##"): + key_lengths_for_ids = self._key_lengths[ids] + lookup_indices = 2 * ids + lengths = self._jagged_lengths[lookup_indices] + offsets = torch.ops.fbgemm.asynchronous_inclusive_cumsum(lengths) + values = jagged_index_select_with_empty( + self._values.flatten().unsqueeze(-1), + lookup_indices, + self._jagged_offsets, + offsets, + ) + weights = torch.jit.annotate(Optional[torch.Tensor], None) + if self._is_weighted: + weights = jagged_index_select_with_empty( + self._weights.flatten().unsqueeze(-1), + lookup_indices, + self._jagged_offsets, + offsets, + ) + + return JaggedTensor( + values=values, weights=weights, lengths=key_lengths_for_ids.flatten() + ) + + def update(self, ids: torch.Tensor, values: JaggedTensor) -> None: + + with record_function("## TensorPool update ##"): + key_lengths = ( + # pyre-ignore + values.lengths() + .view(-1, len(self._feature_max_lengths)) + .sum(axis=1) + ) + key_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(key_lengths) + + padded_values = torch.ops.fbgemm.jagged_to_padded_dense( + values.values(), + [key_offsets], + [self._total_lengths], + 0, + ) + + self._values[ids] = padded_values.to(self._values.dtype) + self._key_lengths[ids] = ( + values.lengths() + .view(-1, len(self._feature_max_lengths)) + .to(self._key_lengths.dtype) + ) + + if values.weights_or_none() is not None: + padded_weights = torch.ops.fbgemm.jagged_to_padded_dense( + values.weights(), + [key_offsets], + [self._total_lengths], + 0, + ) + self._weights[ids] = padded_weights + + lengths, offsets = self._infer_jagged_lengths_inclusive_offsets() + self._jagged_lengths = lengths + self._jagged_offsets = offsets + + def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]: + yield "values", 
self._values + yield "key_lengths", self._key_lengths + if self._is_weighted: + yield "weights", self._weights + + +class UVMCachingInt64Lookup(KeyedJaggedTensorPoolLookup): + def __init__( + self, + pool_size: int, + feature_max_lengths: Dict[str, int], + is_weighted: bool, + device: torch.device, + ) -> None: + super().__init__( + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + device=device, + is_weighted=is_weighted, + ) + + # memory layout will be + # [f1 upper bits][f2 upper bits][upper bits paddings][f1 lower bits][f2 lower bits][lower bits paddings] + + # TBE requires dim to be divisible by 4 + self._bit_dims: int = ((self._total_lengths + 4 - 1) // 4) * 4 + + self._bit_dims_t: torch.Tensor = torch.tensor( + [self._bit_dims], dtype=torch.int32, device=self._device + ) + + self._tbe = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + pool_size, + 2 * self._bit_dims, + EmbeddingLocation.MANAGED, + ComputeDevice.CUDA, + ), + ], + pooling_mode=PoolingMode.NONE, + device=device, + ) + self._tbe_state: torch.Tensor = ( + self._tbe.split_embedding_weights()[0].flatten().view(pool_size, -1) + ) + + if self._is_weighted: + # pyre-ignore + self._tbe_weights = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + pool_size, + self._bit_dims, + ( + EmbeddingLocation.MANAGED_CACHING + if device != torch.device("meta") + else EmbeddingLocation.MANAGED + ), + ComputeDevice.CUDA, + ), + ], + pooling_mode=PoolingMode.NONE, + device=device, + ) + self._tbe_weights_state: torch.Tensor = ( + self._tbe_weights.split_embedding_weights()[0] + .flatten() + .view(pool_size, -1) + ) + + def lookup(self, ids: torch.Tensor) -> JaggedTensor: + with record_function("## UVMCachingInt64Lookup lookup ##"): + output = self._tbe( + indices=ids, + offsets=torch.tensor([0, ids.shape[0]], device=self._device), + ) + + output_int_split = output.view(torch.int32).split( + [self._bit_dims, self._bit_dims], dim=1 + ) + output_int_upper = output_int_split[0].to(torch.int64) << 32 + output_int_lower = output_int_split[1].to(torch.int64) & 0xFFFFFFFF + + kjt_dense_values = output_int_upper | output_int_lower + + key_lengths_for_ids = self._key_lengths[ids] + lengths_sum = key_lengths_for_ids.sum(dim=1) + + padded_lengths = self._bit_dims_t - lengths_sum + # TODO: pre-compute this on class init + jagged_lengths = torch.stack( + [ + lengths_sum, + padded_lengths, + ], + dim=1, + ).flatten() + + lookup_indices = torch.arange(0, ids.shape[0] * 2, 2, device=self._device) + output_lengths = jagged_lengths[lookup_indices] + values = jagged_index_select_with_empty( + kjt_dense_values.flatten().unsqueeze(-1), + lookup_indices, + torch.ops.fbgemm.asynchronous_inclusive_cumsum(jagged_lengths), + torch.ops.fbgemm.asynchronous_inclusive_cumsum(output_lengths), + ) + + return JaggedTensor( + values=values.flatten(), + lengths=key_lengths_for_ids.flatten(), + ) + + def update(self, ids: torch.Tensor, values: JaggedTensor) -> None: + with record_function("## UVMCachingInt64Lookup update ##"): + key_lengths = ( + # pyre-ignore + values.lengths() + .view(-1, len(self._feature_max_lengths)) + .sum(axis=1) + ) + key_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(key_lengths) + padded_values = torch.ops.fbgemm.jagged_to_padded_dense( + values.values(), + [key_offsets], + [self._bit_dims], + 0, + ) + + values_upper_bits = (padded_values >> 32).to(torch.int32) + values_lower_bits = (padded_values & 0xFFFFFFFF).to(torch.int32) + + state = torch.cat([values_upper_bits, values_lower_bits], 
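+                # Editor's illustrative note (hypothetical value, not part of
+                # the original change): an int64 id such as 0x1234ABCD5678EF01
+                # is stored as upper bits 0x1234ABCD and lower bits 0x5678EF01;
+                # lookup() reassembles it as (upper << 32) | (lower & 0xFFFFFFFF).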
dim=1).view( + torch.float32 + ) + + self._tbe_state[ids] = state + + self._key_lengths[ids] = ( + values.lengths() + .view(-1, len(self._feature_max_lengths)) + .to(self._key_lengths.dtype) + ) + + def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]: + yield "values_upper_and_lower_bits", self._tbe_state + if self._is_weighted: + yield "weights", self._tbe_weights_state + + +class UVMCachingInt32Lookup(KeyedJaggedTensorPoolLookup): + def __init__( + self, + pool_size: int, + feature_max_lengths: Dict[str, int], + is_weighted: bool, + device: torch.device, + ) -> None: + super().__init__( + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + device=device, + is_weighted=is_weighted, + ) + + # memory layout will be + # f1 f2 + # [f1 bits] [f2 bits] padding + # TBE requires dim to be divisible by 4. + self._bit_dims: int = ((self._total_lengths + 4 - 1) // 4) * 4 + self._bit_dims_t: torch.Tensor = torch.tensor( + [self._bit_dims], dtype=torch.int32, device=self._device + ) + + self._tbe = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + pool_size, + self._bit_dims, + ( + EmbeddingLocation.MANAGED_CACHING + if self._device.type != "meta" + else EmbeddingLocation.DEVICE + ), + ComputeDevice.CUDA, + ), + ], + pooling_mode=PoolingMode.NONE, + device=device, + ) + + self._tbe_state: torch.Tensor = ( + self._tbe.split_embedding_weights()[0].flatten().view(pool_size, -1) + ) + + if self._is_weighted: + # pyre-ignore + self._tbe_weights = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + pool_size, + self._bit_dims, + ( + EmbeddingLocation.MANAGED_CACHING + if self._device.type != "meta" + else EmbeddingLocation.DEVICE + ), + ComputeDevice.CUDA, + ), + ], + pooling_mode=PoolingMode.NONE, + device=device, + ) + self._tbe_weights_state: torch.Tensor = ( + self._tbe_weights.split_embedding_weights()[0] + .flatten() + .view(pool_size, -1) + ) + + def lookup(self, ids: torch.Tensor) -> JaggedTensor: + with record_function("## UVMCachingInt32Lookup lookup ##"): + output = self._tbe( + indices=ids, + offsets=torch.tensor([0, ids.shape[0]], device=self._device), + ) + + kjt_dense_values = output.view(torch.int32) + + key_lengths_for_ids = self._key_lengths[ids] + lengths_sum = key_lengths_for_ids.sum(dim=1) + + padded_lengths = self._bit_dims_t - lengths_sum + jagged_lengths = torch.stack( + [ + lengths_sum, + padded_lengths, + ], + dim=1, + ).flatten() + + lookup_ids = 2 * torch.arange(ids.shape[0], device=self._device) + output_lengths = jagged_lengths[lookup_ids] + values = jagged_index_select_with_empty( + kjt_dense_values.flatten().unsqueeze(-1), + lookup_ids, + torch.ops.fbgemm.asynchronous_inclusive_cumsum(jagged_lengths), + torch.ops.fbgemm.asynchronous_inclusive_cumsum(output_lengths), + ) + + return JaggedTensor( + values=values.flatten(), + lengths=key_lengths_for_ids.flatten(), + ) + + def update(self, ids: torch.Tensor, values: JaggedTensor) -> None: + with record_function("## UVMCachingInt32Lookup update##"): + key_lengths = ( + # pyre-ignore + values.lengths() + .view(-1, len(self._feature_max_lengths)) + .sum(axis=1) + ) + key_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(key_lengths) + state = torch.ops.fbgemm.jagged_to_padded_dense( + values.values(), + [key_offsets], + [self._bit_dims], + 0, + ).view(torch.float32) + + self._tbe_state[ids] = state + + self._key_lengths[ids] = ( + values.lengths() + .view(-1, len(self._feature_max_lengths)) + .to(self._key_lengths.dtype) + ) + + def states_to_register(self) -> 
Iterator[Tuple[str, torch.Tensor]]: + yield "values", self._tbe_state + if self._is_weighted: + yield "weights", self._tbe_weights_state + + +class TensorPoolLookup(abc.ABC, torch.nn.Module): + """ + Abstract base class for tensor pool lookups + + Implementations of this class should define methods for + - lookup using ids + - update values associated with ids + - returning states that should be saved + and loaded in state_dict() + - setting state from loaded values + + Args: + pool_size (int): size of the pool + dim (int): dimension of the tensors in the pool + dtype (torch.dtype): dtype of the tensors in the pool + device (torch.device): device of the tensors in the pool + + Example: + Other classes should inherit this base class and implement the + abstract methods. + """ + + def __init__( + self, + pool_size: int, + dim: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + super().__init__() + self._pool_size = pool_size + self._dim = dim + self._dtype = dtype + self._device = device + + @abc.abstractmethod + def lookup(self, ids: torch.Tensor) -> torch.Tensor: + pass + + @abc.abstractmethod + def update(self, ids: torch.Tensor, values: torch.Tensor) -> None: + # assume that at this point there are no duplicate ids, and all preproc is done by TensorPool + pass + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + """ + Forward performs a lookup using the given ids + + Args: + ids (torch.Tensor): Tensor of IDs to lookup + + Returns: + torch.Tensor: Tensor of values associated with the given ids + """ + return self.lookup(ids) + + @abc.abstractmethod + def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]: + pass + + @abc.abstractmethod + def set_state(self, loaded_values: torch.Tensor) -> None: + pass + + +class TensorLookup(TensorPoolLookup): + def __init__( + self, + pool_size: int, + dim: int, + dtype: torch.dtype, + device: torch.device, + enable_uvm: bool = False, + ) -> None: + super().__init__( + pool_size=pool_size, + dim=dim, + dtype=dtype, + device=device, + ) + + self._enable_uvm = enable_uvm + self._pool: torch.Tensor = ( + torch.zeros( + (self._pool_size, self._dim), + out=torch.ops.fbgemm.new_unified_tensor( + torch.zeros( + (self._pool_size, self._dim), + device=device, + dtype=dtype, + ), + [self._pool_size * self._dim], + False, + ), + ) + if self._enable_uvm + else torch.zeros( + (self._pool_size, self._dim), + dtype=self._dtype, + device=self._device, + ) + ) + + def lookup(self, ids: torch.Tensor) -> torch.Tensor: + torch._assert( + ids.device.type == self._device.type, + "ids.device.type does not match self._device.type", + ) + with record_function("## TensorPool Lookup ##"): + ret = self._pool[ids] + return ret + + def update(self, ids: torch.Tensor, values: torch.Tensor) -> None: + with record_function("## TensorPool update ##"): + self._pool[ids] = values + + def set_state(self, loaded_values: torch.Tensor) -> None: + self._pool.copy_(loaded_values) + + def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]: + yield "_pool", self._pool + + +class UVMCachingFloatLookup(TensorPoolLookup): + def __init__( + self, + pool_size: int, + dim: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + + super().__init__( + pool_size=pool_size, + dim=dim, + dtype=dtype, + device=device, + ) + + sparse_type = SparseType.from_dtype(self._dtype) + + self._tbe = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + self._pool_size, + self._dim, + ( + EmbeddingLocation.MANAGED_CACHING + if self._device.type != "meta" 
+ else EmbeddingLocation.DEVICE + ), + ComputeDevice.CUDA, + ), + ], + pooling_mode=PoolingMode.NONE, + device=device, + weights_precision=sparse_type, + output_dtype=sparse_type, + ) + + self._tbe_state: torch.Tensor = ( + self._tbe.split_embedding_weights()[0].flatten().view(pool_size, -1) + ) + + def lookup(self, ids: torch.Tensor) -> torch.Tensor: + torch._assert( + ids.device.type == self._device.type, + "ids.device.type does not match self._device.type", + ) + with record_function("## UVMCachingFloatLookup lookup ##"): + output = self._tbe( + indices=ids, + offsets=torch.tensor([0, ids.shape[0]], device=self._device), + ) + return output + + def update(self, ids: torch.Tensor, values: torch.Tensor) -> None: + with record_function("## UVMCachingFloatLookup update ##"): + self._tbe_state[ids] = values + + def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]: + yield "_pool", self._tbe_state + + def set_state(self, loaded_values: torch.Tensor) -> None: + self._tbe_state.copy_(loaded_values) diff --git a/torchrec/modules/regroup.py b/torchrec/modules/regroup.py new file mode 100644 index 000000000..c704e9f90 --- /dev/null +++ b/torchrec/modules/regroup.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +#!/usr/bin/env python3 + +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torchrec.sparse.jagged_tensor import ( + _desugar_keyed_tensors, + _kt_regroup_arguments, + KeyedTensor, +) +from torchrec.types import CacheMixin + + +@torch.fx.wrap +def _get_kts_values(kts: List[KeyedTensor]) -> List[torch.Tensor]: + return [kt.values() for kt in kts] + + +@torch.fx.wrap +def _permuted_values( + kts: List[KeyedTensor], remap: List[Tuple[int, str]], dim: int +) -> torch.Tensor: + embedding_dicts = [kt.to_dict() for kt in kts] + values = [embedding_dicts[idx][key] for (idx, key) in remap] + return torch.cat(values, dim=dim) + + +@torch.fx.wrap +def module_init(module: "KTRegroupAsDict", keyed_tensors: List[KeyedTensor]) -> None: + assert len(keyed_tensors) > 0, "Empty list provided" + assert all( + kt.device() == keyed_tensors[0].device() for kt in keyed_tensors + ), "All inputs should be on the same device." + assert all( + kt.key_dim() == keyed_tensors[0].key_dim() for kt in keyed_tensors + ), "All inputs should have the same key_dim" + module._dim = keyed_tensors[0].key_dim() + + if module._dim == 1: + module._init_fbgemm_regroup(keyed_tensors) + else: + module._init_regroup(keyed_tensors) + module._is_inited = True + + +class PermuteMultiEmbedding(torch.nn.Module): + """ + Module to handle cached tensors and running FBGEMM + op for KT. This separate module allows fx tracing through + all the logic in KTRegroupAsDict while keeping what's necessary + for exposing set_device and allowing tensors to be moved to + the appropriate device during model processing. 
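+
+    For example (editor's illustrative note, hypothetical device string):
+    after tracing or scripting, a caller can move the cached permutation
+    tensors with set_device("cuda:0") before running the first batch on GPU.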
+ + Args: + groups (List[List[str]]): Groups from KTRegroupAsDict + + """ + + def __init__(self, groups: List[List[str]]) -> None: + super().__init__() + self._groups = groups + self.register_buffer("_permutes", torch.empty(0), persistent=False) + self.register_buffer("_in_shapes", torch.empty(0), persistent=False) + self.register_buffer("_out_shapes", torch.empty(0), persistent=False) + self._out_lengths: Optional[List[int]] = None + + def init_tensors( + self, + permute: torch.Tensor, + in_shapes: torch.Tensor, + out_shapes: torch.Tensor, + out_lengths: List[int], + ) -> None: + # no need to pin_memory() or to(..., non_blocking=True) since occurs only once + self._permutes = permute + self._in_shapes = in_shapes + self._out_shapes = out_shapes + self._out_lengths = out_lengths + + @torch.jit.export + def set_device(self, device: str) -> None: + self._permutes = self._permutes.to(device) + self._in_shapes = self._in_shapes.to(device) + self._out_shapes = self._out_shapes.to(device) + + def forward(self, values: List[torch.Tensor]) -> List[torch.Tensor]: + return torch.ops.fbgemm.permute_multi_embedding( + values, + self._permutes, + self._in_shapes, + self._out_shapes, + self._out_lengths, + ) + + +def _to_tensor_dict( + keys: List[str], values: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]] +) -> Dict[str, torch.Tensor]: + return {key: values[i] for i, key in enumerate(keys)} + + +class KTRegroupAsDict(torch.nn.Module, CacheMixin): + """ + KTRegroupAsDict is a nn.Module that mirrors beahvior of static method KeyedTensor.regroup_as_dict() + + The advantage of using this module it caches the regrouping logic after first batch. + + Args: + groups (List[List[str]]): features per output group + keys (List[str]): key of each output group + + Example:: + + keys = ['object', 'user'] + groups = [['f1', 'f2'], ['f3']] + regroup_module = KTRegroupAsDict(groups, keys) + + + tensor_list = [torch.randn(2, 4), torch.randn(2, 8), torch.randn(2, 2)] + kts = [KeyedTensor.from_tensor_list(['f1', 'f2', 'f3' ], tensor_list)] + out = regroup_module(kts) + + """ + + def __init__(self, groups: List[List[str]], keys: List[str]) -> None: + super().__init__() + torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}") + assert len(groups) == len(keys), "Groups and keys should have same length" + self._groups = groups + self._keys = keys + self._is_inited = False + + # cached values populated on first forward call + self._dim: int = 1 + self._use_fbgemm_regroup: bool = False + self._splits: List[int] = [] + self._idx_key_pairs: List[Tuple[int, str]] = [] + self._permute_pooled_embs_impl = PermuteMultiEmbedding(groups) + + def _init_fbgemm_regroup(self, kts: List[KeyedTensor]) -> None: + self._use_fbgemm_regroup = True + keys, lengths, values = _desugar_keyed_tensors(kts) + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + values[0], + keys, + lengths, + self._groups, + ) + # no need to pin_memory() or to(..., non_blocking=True) since occurs only once + self._permute_pooled_embs_impl.init_tensors( + permutes, + in_shapes, + out_shapes, + out_lengths, + ) + + def _init_regroup(self, kts: List[KeyedTensor]) -> None: + lengths = [kt.length_per_key() for kt in kts] + indices = [kt._key_indices() for kt in kts] + + key_to_idx: dict[str, int] = {} + for i, kt in enumerate(kts): + for key in kt.keys(): + if key in key_to_idx: + raise RuntimeError( + f"Duplicate key {key} found in KeyedTensors, undefined behavior" + ) + key_to_idx[key] = i + + splits: List[int] = [] + 
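+        # Editor's illustrative note (hypothetical shapes, mirroring the class
+        # docstring example): for groups=[["f1", "f2"], ["f3"]] and a single
+        # input KeyedTensor with per-key lengths {f1: 4, f2: 8, f3: 2}, the loop
+        # below yields idx_key_pairs=[(0, "f1"), (0, "f2"), (0, "f3")] and
+        # splits=[12, 2], which later drive torch.split on the permuted values.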
idx_key_pairs: List[Tuple[int, str]] = [] + for group in self._groups: + group_length = 0 + for name in group: + idx_key_pairs.append((key_to_idx[name], name)) + group_length += lengths[key_to_idx[name]][ + indices[key_to_idx[name]][name] + ] + splits.append(group_length) + + self._splits = splits + self._idx_key_pairs = idx_key_pairs + + def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]: + if not self._is_inited: + module_init(self, keyed_tensors) + + if self._use_fbgemm_regroup: + values = _get_kts_values(keyed_tensors) + permuted_values = self._permute_pooled_embs_impl(values) + return _to_tensor_dict(self._keys, permuted_values) + else: + permuted_values = _permuted_values( + keyed_tensors, self._idx_key_pairs, self._dim + ) + splitted_values = torch.split(permuted_values, self._splits, dim=self._dim) + return _to_tensor_dict(self._keys, splitted_values) + + def clear_cache(self) -> None: + self._is_inited = False diff --git a/torchrec/modules/tensor_pool.py b/torchrec/modules/tensor_pool.py new file mode 100644 index 000000000..fcb14feb2 --- /dev/null +++ b/torchrec/modules/tensor_pool.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Optional + +import torch +from torchrec.modules.object_pool import ObjectPool +from torchrec.modules.utils import deterministic_dedup + + +@torch.fx.wrap +def _fx_assert_device(ids: torch.Tensor, device: torch.device) -> None: + assert ids.device == device + assert ids.dtype in [torch.int32, torch.int64] + + +@torch.fx.wrap +def _fx_assert_pool_size(ids: torch.Tensor, pool_size: int) -> None: + assert torch.all(ids < pool_size).item() + + +class TensorPool(ObjectPool[torch.Tensor]): + """ + TensorPool represents a collection of torch.Tensor with uniform dimension. + It is effectively a 2D tensor of size [pool_size, dim], where each [1,dim] row + tensor is associated with an unique index which can be set up with update(). + Each row tensor making up the tensor pool can be quried by its index with lookup(). + + Args: + pool_size (int): total number of rows of tensors in the pool + dim (int): dimension that each tensor in the pool + dtype (torch.dtype): dtype of the tensors in the pool + device (Optional[torch.device]): default device + loaded_values (Optional[torch.Tensor]): pre-defined values to initialize the pool + enable_uvm (bool): if set to true, the pool will be allocated on UVM + + Call Args: + ids: 1D torch.Tensor of ids to look up + + Returns: + torch.Tensor of shape [ids.size(0), dim] + + Example:: + + dense_pool = TensorPool( + pool_size=10, + dim=2, + dtype=torch.float + ) + + # Update + ids = torch.Tensor([1, 9]) + update_values = torch.Tensor([[1.0, 2.0],[3.0,4.0]]) + dense_pool.update(ids=ids, values=update_values) + + # Lookup + lookup_values = dense_pool.lookup(ids=ids) + + print(lookup_values) + # tensor([[1., 2.], + # [3., 4.]]) + """ + + def __init__( + self, + pool_size: int, + dim: int, + dtype: torch.dtype, + device: Optional[torch.device] = None, + loaded_values: Optional[torch.Tensor] = None, + enable_uvm: bool = False, + ) -> None: + super().__init__() + self._pool_size = pool_size + self._dtype = dtype + # pyre-fixme[4]: Attribute must be annotated. 
+ self._device = device if device is not None else torch.device("meta") + self._dim = dim + self._enable_uvm = enable_uvm + # TODO enable multiple lookup on unsharded module + + self.register_buffer( + "_pool", + torch.zeros( + (self._pool_size, self._dim), + dtype=self._dtype, + device=self._device, + ), + ) + if loaded_values is not None: + self._pool = loaded_values + + @property + def pool_size(self) -> int: + return self._pool_size + + @property + def dim(self) -> int: + return self._dim + + @property + def dtype(self) -> torch.dtype: + return self._dtype + + @property + def device(self) -> torch.device: + torch._assert(self._device is not None, "self._device should already be set") + return self._device + + @property + def pool(self) -> torch.Tensor: + return self._pool + + def lookup(self, ids: torch.Tensor) -> torch.Tensor: + _fx_assert_device(ids, self._device) + _fx_assert_pool_size(ids, self._pool_size) + return self._pool[ids] + + def update(self, ids: torch.Tensor, values: torch.Tensor) -> None: + assert values.dim() == 2 + assert values.size(1) == self._dim + assert values.dtype == self._dtype + assert values.device == self._device, f"{values.device} != {self._device}" + _fx_assert_device(ids, self._device) + _fx_assert_pool_size(ids, self._pool_size) + + # If duplicate ids are passed in for update, only the last one is kept + deduped_ids, dedup_permutation = deterministic_dedup(ids) + self._pool[deduped_ids] = values[dedup_permutation] + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + return self.lookup(ids) diff --git a/torchrec/modules/tests/test_activation.py b/torchrec/modules/tests/test_activation.py index 42e76f9b5..0b084d5c2 100644 --- a/torchrec/modules/tests/test_activation.py +++ b/torchrec/modules/tests/test_activation.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest import torch diff --git a/torchrec/modules/tests/test_code_quality.py b/torchrec/modules/tests/test_code_quality.py index 621f0b37d..858f9d020 100644 --- a/torchrec/modules/tests/test_code_quality.py +++ b/torchrec/modules/tests/test_code_quality.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import inspect import sys import unittest diff --git a/torchrec/modules/tests/test_crossnet.py b/torchrec/modules/tests/test_crossnet.py index 9686bbe3f..2732431ad 100644 --- a/torchrec/modules/tests/test_crossnet.py +++ b/torchrec/modules/tests/test_crossnet.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest import torch @@ -16,6 +18,7 @@ VectorCrossNet, ) + # unit test for Full Rank CrossNet: CrossNet class TestCrossNet(unittest.TestCase): def test_cross_net_numercial_forward(self) -> None: diff --git a/torchrec/modules/tests/test_deepfm.py b/torchrec/modules/tests/test_deepfm.py index f5cef493a..f2998bbea 100644 --- a/torchrec/modules/tests/test_deepfm.py +++ b/torchrec/modules/tests/test_deepfm.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import unittest import torch diff --git a/torchrec/modules/tests/test_embedding_modules.py b/torchrec/modules/tests/test_embedding_modules.py index c35404f1a..5c4c9f281 100644 --- a/torchrec/modules/tests/test_embedding_modules.py +++ b/torchrec/modules/tests/test_embedding_modules.py @@ -5,7 +5,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest +from functools import partial import torch import torch.fx @@ -21,10 +24,18 @@ class EmbeddingBagCollectionTest(unittest.TestCase): def test_unweighted(self) -> None: eb1_config = EmbeddingBagConfig( - name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + name="t1", + embedding_dim=3, + num_embeddings=10, + feature_names=["f1"], + init_fn=partial(torch.nn.init.normal_, mean=0.0, std=1.5), ) eb2_config = EmbeddingBagConfig( - name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] + name="t2", + embedding_dim=4, + num_embeddings=10, + feature_names=["f2"], + init_fn=partial(torch.nn.init.uniform_, a=-0.036, b=0.036), ) ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) @@ -117,6 +128,38 @@ def test_weighted(self) -> None: self.assertEqual(pooled_embeddings.keys(), ["f1", "f3", "f2"]) self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 6, 10]) + def test_forward_with_meta_device(self) -> None: + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=4, + num_embeddings=10, + feature_names=["f2"], + ) + ebc = EmbeddingBagCollection( + tables=[eb1_config, eb2_config], + is_weighted=True, + device=torch.device("meta"), + ) + + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f3", "f2"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 3, 4, 7], device="meta"), + offsets=torch.tensor([0, 2, 4, 6, 8, 10, 12], device="meta"), + weights=torch.tensor( + [0.1, 0.2, 0.4, 0.5, 0.4, 0.3, 0.2, 0.9, 0.1, 0.3, 0.4, 0.7], + device="meta", + ), + ) + + pooled_embeddings = ebc(features) + self.assertEqual(pooled_embeddings.values().size(), (2, 10)) + self.assertEqual(pooled_embeddings.keys(), ["f1", "f3", "f2"]) + self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 6, 10]) + self.assertEqual(pooled_embeddings.values().device, torch.device("meta")) + def test_fx(self) -> None: eb1_config = EmbeddingBagConfig( name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"] @@ -167,8 +210,22 @@ def test_device(self) -> None: config = EmbeddingBagConfig( name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] ) + + # test with device from input ebc = EmbeddingBagCollection(tables=[config], device=torch.device("meta")) self.assertEqual(torch.device("meta"), ebc.embedding_bags["t1"].weight.device) + self.assertEqual(torch.device("meta"), ebc.device) + + # test with device from context manager + with torch.device("meta"): + ebc = EmbeddingBagCollection(tables=[config]) + self.assertEqual(torch.device("meta"), ebc.embedding_bags["t1"].weight.device) + self.assertEqual(torch.device("meta"), ebc.device) + + # test default device is cpu + ebc = EmbeddingBagCollection(tables=[config]) + self.assertEqual(torch.device("cpu"), ebc.embedding_bags["t1"].weight.device) + self.assertEqual(torch.device("cpu"), ebc.device) class EmbeddingCollectionTest(unittest.TestCase): @@ -292,7 +349,7 @@ def test_device(self) -> None: tables=[config], 
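The device tests above cover passing device explicitly, relying on a torch.device context manager, and falling back to CPU. The context-manager case is plain PyTorch (2.x) behavior and works for any nn.Module; a small illustration:

    import torch

    # Parameters created inside the context inherit its default device.
    with torch.device("meta"):
        layer = torch.nn.Linear(4, 2)

    print(layer.weight.is_meta)     # True
    print(layer.weight.device)      # meta

    # Outside the context, construction uses the usual default device (CPU).
    layer_cpu = torch.nn.Linear(4, 2)
    print(layer_cpu.weight.device)  # cpu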
device=torch.device("meta"), ) - self.assertEquals(torch.device("meta"), ec.embeddings["t1"].weight.device) + self.assertEqual(torch.device("meta"), ec.embeddings["t1"].weight.device) def test_duplicate_config_name_fails(self) -> None: e1_config = EmbeddingConfig( diff --git a/torchrec/modules/tests/test_feature_processor.py b/torchrec/modules/tests/test_feature_processor.py index acab3b0a2..451a104ac 100644 --- a/torchrec/modules/tests/test_feature_processor.py +++ b/torchrec/modules/tests/test_feature_processor.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest import torch diff --git a/torchrec/modules/tests/test_feature_processor_.py b/torchrec/modules/tests/test_feature_processor_.py new file mode 100644 index 000000000..508ca0012 --- /dev/null +++ b/torchrec/modules/tests/test_feature_processor_.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +import torch +from torchrec.distributed.utils import copy_to_device, init_parameters + +from torchrec.fx.tracer import symbolic_trace +from torchrec.modules.feature_processor_ import ( + PositionWeightedModule, + PositionWeightedModuleCollection, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class PositionWeightedModuleTest(unittest.TestCase): + def test_populate_weights(self) -> None: + pw = PositionWeightedModule(max_feature_length=10) + + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + features = features.to_dict() + + jt = features["f1"] + weighted_features = pw(jt) + + self.assertEqual(weighted_features.weights().size(), (3,)) + + pw_f1_ref = torch.gather( + pw.state_dict()["position_weight"], 0, torch.tensor([0, 1, 0]) + ) + + pw_f1 = weighted_features.weights().detach() + self.assertTrue(torch.allclose(pw_f1_ref, pw_f1)) + + position_weighted_module_gm = symbolic_trace(pw) + position_weighted_module_gm_script = torch.jit.script( + position_weighted_module_gm + ) + + weighted_features_gm_script = position_weighted_module_gm_script(jt) + torch.testing.assert_close( + weighted_features.values(), weighted_features_gm_script.values() + ) + torch.testing.assert_close( + weighted_features.lengths(), weighted_features_gm_script.lengths() + ) + + # TODO: this test is not being run + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "Not enough GPUs, this test requires at least one GPU", + ) + def test_rematerialize_from_meta(self) -> None: + pw = PositionWeightedModule(max_feature_length=10, device=torch.device("meta")) + self.assertTrue(pw.position_weight.is_meta) + + # Re-materialize on cuda + init_parameters(pw, torch.device("cuda")) + self.assertTrue(not pw.position_weight.is_meta) + torch.testing.assert_close( + pw.position_weight, torch.ones_like(pw.position_weight) + ) + + +class PositionWeightedCollectionModuleTest(unittest.TestCase): + def test_populate_weights(self) -> None: + position_weighted_module_collection = PositionWeightedModuleCollection( + {"f1": 10, "f2": 10} + ) + + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 
[3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + fp_kjt = position_weighted_module_collection(features) + + position_weighted_module_collection_gm = symbolic_trace( + position_weighted_module_collection + ) + position_weighted_module_collection_gm_script = torch.jit.script( + position_weighted_module_collection_gm + ) + fp_kjt_gm_script = position_weighted_module_collection_gm_script(features) + + torch.testing.assert_close(fp_kjt.values(), fp_kjt_gm_script.values()) + torch.testing.assert_close(fp_kjt.lengths(), fp_kjt_gm_script.lengths()) + torch.testing.assert_close( + fp_kjt.length_per_key(), fp_kjt_gm_script.length_per_key() + ) + + empty_kjt = KeyedJaggedTensor.from_lengths_sync( + keys=[], + values=torch.tensor([], dtype=torch.int32), + lengths=torch.tensor([], dtype=torch.int32), + ) + + empty_fp_kjt = position_weighted_module_collection(empty_kjt) + empty_fp_kjt_gm_script = position_weighted_module_collection_gm_script( + empty_kjt + ) + + torch.testing.assert_close( + empty_fp_kjt.values(), empty_fp_kjt_gm_script.values() + ) + torch.testing.assert_close( + empty_fp_kjt.lengths(), empty_fp_kjt_gm_script.lengths() + ) + torch.testing.assert_close( + empty_fp_kjt.length_per_key(), empty_fp_kjt_gm_script.length_per_key() + ) + + # TODO: this test is not being run + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "Not enough GPUs, this test requires at least one GPU", + ) + def test_rematerialize_from_meta(self) -> None: + pwmc = PositionWeightedModuleCollection( + max_feature_lengths={"f1": 10, "f2": 10}, + device=torch.device("meta"), + ) + self.assertTrue(all(param.is_meta for param in pwmc.position_weights.values())) + + # Re-materialize on cuda + init_parameters(pwmc, torch.device("cuda")) + for key, param in pwmc.position_weights.items(): + self.assertTrue(not param.is_meta) + self.assertTrue(pwmc.position_weights_dict[key] is param) + torch.testing.assert_close(param, torch.ones_like(param)) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "Not enough GPUs available", + ) + def test_copy(self) -> None: + pwmc = PositionWeightedModuleCollection( + max_feature_lengths={"f1": 10, "f2": 10}, + device=torch.device("cpu"), + ) + + self.assertTrue( + all(param.device.type == "cpu" for param in pwmc.position_weights.values()) + ) + self.assertTrue( + all( + param.device.type == "cpu" + for param in pwmc.position_weights_dict.values() + ) + ) + + res = copy_to_device( + pwmc, current_device=torch.device("cpu"), to_device=torch.device("meta") + ) + + # pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is + # not a function. + self.assertTrue(all(param.is_meta for param in res.position_weights.values())) + self.assertTrue( + # pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` + # is not a function. 
+ all(param.is_meta for param in res.position_weights_dict.values()) + ) + + def test_to(self) -> None: + pwmc = PositionWeightedModuleCollection( + max_feature_lengths={"f1": 10, "f2": 10}, + device=torch.device("cpu"), + ) + + self.assertTrue( + all(param.device.type == "cpu" for param in pwmc.position_weights.values()) + ) + self.assertTrue( + all( + param.device.type == "cpu" + for param in pwmc.position_weights_dict.values() + ) + ) + + pwmc.to("meta") + + self.assertTrue(all(param.is_meta for param in pwmc.position_weights.values())) + self.assertTrue( + all(param.is_meta for param in pwmc.position_weights_dict.values()) + ) diff --git a/torchrec/modules/tests/test_fp_embedding_modules.py b/torchrec/modules/tests/test_fp_embedding_modules.py new file mode 100644 index 000000000..ccb6d7175 --- /dev/null +++ b/torchrec/modules/tests/test_fp_embedding_modules.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import cast + +import torch + +from torchrec.fx.tracer import symbolic_trace +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.feature_processor_ import ( + FeatureProcessor, + PositionWeightedModule, + PositionWeightedModuleCollection, +) +from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class PositionWeightedModuleEmbeddingBagCollectionTest(unittest.TestCase): + + def generate_fp_ebc(self) -> FeatureProcessedEmbeddingBagCollection: + ebc = EmbeddingBagCollection( + tables=[ + EmbeddingBagConfig( + name="t1", embedding_dim=8, num_embeddings=16, feature_names=["f1"] + ), + EmbeddingBagConfig( + name="t2", embedding_dim=8, num_embeddings=16, feature_names=["f2"] + ), + ], + is_weighted=True, + ) + feature_processors = { + "f1": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=10)), + "f2": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=5)), + } + return FeatureProcessedEmbeddingBagCollection(ebc, feature_processors) + + def test_position_weighted_module_ebc(self) -> None: + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + fp_ebc = self.generate_fp_ebc() + + pooled_embeddings = fp_ebc(features) + self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) + self.assertEqual(pooled_embeddings.values().size(), (3, 16)) + self.assertEqual(pooled_embeddings.offset_per_key(), [0, 8, 16]) + + fp_ebc_gm_script = torch.jit.script(symbolic_trace(fp_ebc)) + pooled_embeddings_gm_script = fp_ebc_gm_script(features) + + torch.testing.assert_close( + pooled_embeddings_gm_script.values(), pooled_embeddings.values() + ) + + torch.testing.assert_close( + pooled_embeddings_gm_script.offset_per_key(), + pooled_embeddings.offset_per_key(), + ) + + def test_position_weighted_module_ebc_with_excessive_features(self) -> None: + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # 2 [8] None None + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + 
values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8, 9, 9, 9]), + ) + + fp_ebc = self.generate_fp_ebc() + + pooled_embeddings = fp_ebc(features) + self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) + self.assertEqual(pooled_embeddings.values().size(), (3, 16)) + self.assertEqual(pooled_embeddings.offset_per_key(), [0, 8, 16]) + + # Test split method, FP then EBC + fp, ebc = fp_ebc.split() + fp_kjt = fp(features) + pooled_embeddings_split = ebc(fp_kjt) + + self.assertEqual(pooled_embeddings_split.keys(), ["f1", "f2"]) + self.assertEqual(pooled_embeddings_split.values().size(), (3, 16)) + self.assertEqual(pooled_embeddings_split.offset_per_key(), [0, 8, 16]) + + +class PositionWeightedModuleCollectionEmbeddingBagCollectionTest(unittest.TestCase): + def generate_fp_ebc(self) -> FeatureProcessedEmbeddingBagCollection: + ebc = EmbeddingBagCollection( + tables=[ + EmbeddingBagConfig( + name="t1", embedding_dim=8, num_embeddings=16, feature_names=["f1"] + ), + EmbeddingBagConfig( + name="t2", embedding_dim=8, num_embeddings=16, feature_names=["f2"] + ), + ], + is_weighted=True, + ) + + return FeatureProcessedEmbeddingBagCollection( + ebc, PositionWeightedModuleCollection({"f1": 10, "f2": 10}) + ) + + def test_position_weighted_collection_module_ebc(self) -> None: + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + fp_ebc = self.generate_fp_ebc() + + pooled_embeddings = fp_ebc(features) + self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) + self.assertEqual(pooled_embeddings.values().size(), (3, 16)) + self.assertEqual(pooled_embeddings.offset_per_key(), [0, 8, 16]) + + fp_ebc_gm_script = torch.jit.script(symbolic_trace(fp_ebc)) + pooled_embeddings_gm_script = fp_ebc_gm_script(features) + + torch.testing.assert_close( + pooled_embeddings_gm_script.values(), pooled_embeddings.values() + ) + + torch.testing.assert_close( + pooled_embeddings_gm_script.offset_per_key(), + pooled_embeddings.offset_per_key(), + ) + + # Test split method, FP then EBC + fp, ebc = fp_ebc.split() + fp_kjt = fp(features) + pooled_embeddings_split = ebc(fp_kjt) + + self.assertEqual(pooled_embeddings_split.keys(), ["f1", "f2"]) + self.assertEqual(pooled_embeddings_split.values().size(), (3, 16)) + self.assertEqual(pooled_embeddings_split.offset_per_key(), [0, 8, 16]) diff --git a/torchrec/modules/tests/test_fused_embedding_modules.py b/torchrec/modules/tests/test_fused_embedding_modules.py index 5c8b68d87..b16a14d56 100644 --- a/torchrec/modules/tests/test_fused_embedding_modules.py +++ b/torchrec/modules/tests/test_fused_embedding_modules.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
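Several of the tests above construct inputs with KeyedJaggedTensor.from_offsets_sync and describe them with the small batch diagrams in the comments. The mapping between those diagrams, values, and offsets can be verified with plain torch; a sketch for the recurring ["f1", "f2"] x 3-batch example:

    import torch

    # f1 rows: [0, 1], [], [2]      f2 rows: [3], [4], [5, 6, 7]
    values = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
    offsets = torch.tensor([0, 2, 2, 3, 4, 5, 8])

    lengths = offsets[1:] - offsets[:-1]
    print(lengths)  # tensor([2, 0, 1, 1, 1, 3])

    # Bag i is values[offsets[i] : offsets[i + 1]]
    bags = [values[offsets[i] : offsets[i + 1]].tolist() for i in range(len(lengths))]
    print(bags)     # [[0, 1], [], [2], [3], [4], [5, 6, 7]]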
+# pyre-strict + import unittest from collections import OrderedDict @@ -15,9 +17,8 @@ import torch import torch.fx import torchrec -from fbgemm_gpu.split_table_batched_embeddings_ops import EmbeddingLocation +from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation from hypothesis import given, settings -from torchrec.fx import symbolic_trace from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig from torchrec.modules.embedding_modules import ( EmbeddingBagCollection, @@ -573,6 +574,7 @@ def test_optimizer_fusion( ).to(device) opt = optimizer_type(ebc.parameters(), **optimizer_kwargs) + # pyre-ignore def run_one_training_step() -> None: fused_pooled_embeddings = fused_ebc(features) @@ -860,7 +862,7 @@ def test_forward_with_state_dict( ).to(device) sequential_embeddings = ec(features) - self.assertEquals( + self.assertEqual( set(sequential_embeddings.keys()), {"f1", "f1_1", "f_shared@t1", "f2", "f_shared@t2", "f3"}, ) @@ -964,6 +966,7 @@ def test_optimizer_fusion( ).to(device) opt = optimizer_type(ec.parameters(), **optimizer_kwargs) + # pyre-ignore def run_one_training_step() -> None: fused_embeddings = fused_ec(features) diff --git a/torchrec/modules/tests/test_itep_embedding_modules.py b/torchrec/modules/tests/test_itep_embedding_modules.py new file mode 100644 index 000000000..3e9bc0801 --- /dev/null +++ b/torchrec/modules/tests/test_itep_embedding_modules.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import copy +import random +import unittest +from typing import Dict, List +from unittest.mock import MagicMock, patch + +import torch +from torchrec import KeyedJaggedTensor +from torchrec.distributed.embedding_types import ShardedEmbeddingTable +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection + +from torchrec.modules.itep_embedding_modules import ITEPEmbeddingBagCollection +from torchrec.modules.itep_modules import GenericITEPModule + +MOCK_NS: str = "torchrec.modules.itep_modules" + + +class TestITEPEmbeddingBagCollection(unittest.TestCase): + # Setting up the environment for the tests. + def setUp(self) -> None: + # Embedding bag configurations for testing + embedding_bag_config1 = EmbeddingBagConfig( + name="table1", + embedding_dim=4, + num_embeddings=50, + feature_names=["feature1"], + ) + embedding_bag_config2 = EmbeddingBagConfig( + name="table2", + embedding_dim=4, + num_embeddings=40, + feature_names=["feature2"], + ) + unpruned_hash_size_1, unpruned_hash_size_2 = (100, 80) + self._table_name_to_pruned_hash_sizes = {"table1": 50, "table2": 40} + self._table_name_to_unpruned_hash_sizes = { + "table1": unpruned_hash_size_1, + "table2": unpruned_hash_size_2, + } + self._feature_name_to_unpruned_hash_sizes = { + "feature1": unpruned_hash_size_1, + "feature2": unpruned_hash_size_2, + } + self._batch_size = 8 + + # Util function for creating sharded embedding tables from embedding bag configurations. 
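The ITEP tests further down build an expected address-lookup buffer per table (see generate_expected_address_lookup_buffer below): row indices under the pruned size map to themselves and everything above folds back to row 0. A compact plain-torch version of that layout, with hypothetical sizes chosen only for illustration:

    import torch

    unpruned_hash_size = 10  # hypothetical values for illustration
    pruned_hash_size = 6

    idx = torch.arange(unpruned_hash_size)
    address_lookup = torch.where(idx < pruned_hash_size, idx, torch.zeros_like(idx))
    print(address_lookup)  # tensor([0, 1, 2, 3, 4, 5, 0, 0, 0, 0])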
+ def embedding_bag_config_to_sharded_table( + config: EmbeddingBagConfig, + ) -> ShardedEmbeddingTable: + return ShardedEmbeddingTable( + name=config.name, + embedding_dim=config.embedding_dim, + num_embeddings=config.num_embeddings, + feature_names=config.feature_names, + ) + + sharded_et1 = embedding_bag_config_to_sharded_table(embedding_bag_config1) + sharded_et2 = embedding_bag_config_to_sharded_table(embedding_bag_config2) + + # Create test ebc + self._embedding_bag_collection = EmbeddingBagCollection( + tables=[ + embedding_bag_config1, + embedding_bag_config2, + ], + device=torch.device("cuda"), + ) + + # Create a mock object for tbe lookups + self._mock_list_emb_tables = [ + sharded_et1, + sharded_et2, + ] + self._mock_lookups = [MagicMock()] + self._mock_lookups[0]._emb_modules = [MagicMock()] + self._mock_lookups[0]._emb_modules[0]._config = MagicMock() + self._mock_lookups[0]._emb_modules[ + 0 + ]._config.embedding_tables = self._mock_list_emb_tables + + def generate_input_kjt_cuda( + self, feature_name_to_unpruned_hash_sizes: Dict[str, int], use_vbe: bool = False + ) -> KeyedJaggedTensor: + keys = [] + values = [] + lengths = [] + cuda_device = torch.device("cuda") + + # Input KJT uses unpruned hash size (same as sigrid hash), and feature names + for key, unpruned_hash_size in feature_name_to_unpruned_hash_sizes.items(): + value = [] + length = [] + for _ in range(self._batch_size): + L = random.randint(0, 8) + for _ in range(L): + index = random.randint(0, unpruned_hash_size - 1) + value.append(index) + length.append(L) + keys.append(key) + values += value + lengths += length + + # generate kjt + if use_vbe: + inverse_indices_list = [] + inverse_indices = None + num_keys = len(keys) + deduped_batch_size = len(lengths) // num_keys + # Fix the number of samples after duplicate to 2x the number of + # deduplicated ones + full_batch_size = deduped_batch_size * 2 + stride_per_key_per_rank = [] + + for _ in range(num_keys): + stride_per_key_per_rank.append([deduped_batch_size]) + # Generate random inverse indices for each key + keyed_inverse_indices = torch.randint( + low=0, + high=deduped_batch_size, + size=(full_batch_size,), + dtype=torch.int32, + device=cuda_device, + ) + inverse_indices_list.append(keyed_inverse_indices) + inverse_indices = ( + keys, + torch.stack(inverse_indices_list), + ) + + input_kjt_cuda = KeyedJaggedTensor.from_lengths_sync( + keys=keys, + values=torch.tensor( + copy.deepcopy(values), + dtype=torch.int32, + device=cuda_device, + ), + lengths=torch.tensor( + copy.deepcopy(lengths), + dtype=torch.int32, + device=cuda_device, + ), + stride_per_key_per_rank=stride_per_key_per_rank, + inverse_indices=inverse_indices, + ) + else: + input_kjt_cuda = KeyedJaggedTensor.from_lengths_sync( + keys=keys, + values=torch.tensor( + copy.deepcopy(values), + dtype=torch.int32, + device=cuda_device, + ), + lengths=torch.tensor( + copy.deepcopy(lengths), + dtype=torch.int32, + device=cuda_device, + ), + ) + + return input_kjt_cuda + + def generate_expected_address_lookup_buffer( + self, + list_et: List[ShardedEmbeddingTable], + table_name_to_unpruned_hash_sizes: Dict[str, int], + table_name_to_pruned_hash_sizes: Dict[str, int], + ) -> torch.Tensor: + + address_lookup = [] + for et in list_et: + table_name = et.name + unpruned_hash_size = table_name_to_unpruned_hash_sizes[table_name] + pruned_hash_size = table_name_to_pruned_hash_sizes[table_name] + for idx in range(unpruned_hash_size): + if idx < pruned_hash_size: + address_lookup.append(idx) + else: + 
address_lookup.append(0) + + return torch.tensor(address_lookup, dtype=torch.int64) + + # pyre-ignores + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Skip when not enough GPUs available", + ) + def test_init_itep_module(self) -> None: + itep_module = GenericITEPModule( + table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes, + lookups=self._mock_lookups, + enable_pruning=True, + pruning_interval=5, + ) + + # Check the address lookup and row util after initialization + expected_address_lookup = self.generate_expected_address_lookup_buffer( + self._mock_list_emb_tables, + self._table_name_to_unpruned_hash_sizes, + self._table_name_to_pruned_hash_sizes, + ) + expetec_row_util = torch.zeros( + expected_address_lookup.shape, dtype=torch.float32 + ) + torch.testing.assert_close( + expected_address_lookup, + itep_module.address_lookup.cpu(), + atol=0, + rtol=0, + equal_nan=True, + ) + torch.testing.assert_close( + expetec_row_util, + itep_module.row_util.cpu(), + atol=1.0e-5, + rtol=1.0e-5, + equal_nan=True, + ) + + # pyre-ignores + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Skip when not enough GPUs available", + ) + def test_init_itep_module_without_pruned_table(self) -> None: + itep_module = GenericITEPModule( + table_name_to_unpruned_hash_sizes={}, + lookups=self._mock_lookups, + enable_pruning=True, + pruning_interval=5, + ) + + self.assertEqual(itep_module.address_lookup.cpu().shape, torch.Size([0])) + self.assertEqual(itep_module.row_util.cpu().shape, torch.Size([0])) + + # pyre-ignores + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Skip when not enough GPUs available", + ) + def test_train_forward(self) -> None: + itep_module = GenericITEPModule( + table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes, + lookups=self._mock_lookups, + enable_pruning=True, + pruning_interval=500, + ) + + itep_ebc = ITEPEmbeddingBagCollection( + embedding_bag_collection=self._embedding_bag_collection, + itep_module=itep_module, + ) + + # Test forward 2000 times + for _ in range(2000): + input_kjt = self.generate_input_kjt_cuda( + self._feature_name_to_unpruned_hash_sizes + ) + _ = itep_ebc(input_kjt) + + # pyre-ignores + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Skip when not enough GPUs available", + ) + def test_train_forward_vbe(self) -> None: + itep_module = GenericITEPModule( + table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes, + lookups=self._mock_lookups, + enable_pruning=True, + pruning_interval=500, + ) + + itep_ebc = ITEPEmbeddingBagCollection( + embedding_bag_collection=self._embedding_bag_collection, + itep_module=itep_module, + ) + + # Test forward 2000 times + for _ in range(5): + input_kjt = self.generate_input_kjt_cuda( + self._feature_name_to_unpruned_hash_sizes, use_vbe=True + ) + _ = itep_ebc(input_kjt) + + # pyre-ignores + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Skip when not enough GPUs available", + ) + # Mock out reset_weight_momentum to count calls + @patch(f"{MOCK_NS}.GenericITEPModule.reset_weight_momentum") + def test_check_pruning_schedule( + self, + mock_reset_weight_momentum: MagicMock, + ) -> None: + random.seed(1) + itep_module = GenericITEPModule( + table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes, + lookups=self._mock_lookups, + enable_pruning=True, + pruning_interval=500, + ) + + itep_ebc = ITEPEmbeddingBagCollection( + embedding_bag_collection=self._embedding_bag_collection, + itep_module=itep_module, + ) + + # Test forward 
2000 times + for _ in range(2000): + input_kjt = self.generate_input_kjt_cuda( + self._feature_name_to_unpruned_hash_sizes + ) + _ = itep_ebc(input_kjt) + + # Check that reset_weight_momentum is called + self.assertEqual(mock_reset_weight_momentum.call_count, 5) + + # pyre-ignores + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Skip when not enough GPUs available", + ) + # Mock out reset_weight_momentum to count calls + @patch(f"{MOCK_NS}.GenericITEPModule.reset_weight_momentum") + def test_eval_forward( + self, + mock_reset_weight_momentum: MagicMock, + ) -> None: + itep_module = GenericITEPModule( + table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes, + lookups=self._mock_lookups, + enable_pruning=True, + pruning_interval=500, + ) + + itep_ebc = ITEPEmbeddingBagCollection( + embedding_bag_collection=self._embedding_bag_collection, + itep_module=itep_module, + ) + + # Set eval mode + itep_ebc.eval() + + # Test forward 2000 times + for _ in range(2000): + input_kjt = self.generate_input_kjt_cuda( + self._feature_name_to_unpruned_hash_sizes + ) + _ = itep_ebc(input_kjt) + + # Check that reset_weight_momentum is not called + self.assertEqual(mock_reset_weight_momentum.call_count, 0) diff --git a/torchrec/modules/tests/test_keyed_jagged_tensor_pool.py b/torchrec/modules/tests/test_keyed_jagged_tensor_pool.py new file mode 100644 index 000000000..041b95441 --- /dev/null +++ b/torchrec/modules/tests/test_keyed_jagged_tensor_pool.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import unittest + +import torch +from torchrec.modules.keyed_jagged_tensor_pool import KeyedJaggedTensorPool +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class KeyedJaggedTensorPoolTest(unittest.TestCase): + def test_update_lookup( + self, + ) -> None: + device = ( + torch.device("cpu") + if not torch.cuda.is_available() + else torch.device("cuda:0") + ) + + pool_size, feature_max_lengths = 4, {"f1": 2, "f2": 4} + values_dtype = torch.int64 + + keyed_jagged_tensor_pool = KeyedJaggedTensorPool( + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + values_dtype=values_dtype, + device=device, + ) + + # init global state is + # 4 8 + # f1 f2 + # [3,3] . [13,13,13] + # [2,2] . [12,12] + # [1] . 
[11] + # [4] [14,14,14,14] + + keyed_jagged_tensor_pool.update( + ids=torch.tensor([2, 0, 1, 3], device=device), + values=KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.tensor( + [1, 3, 3, 2, 2, 4, 11, 13, 13, 13, 12, 12, 14, 14, 14, 14], + dtype=values_dtype, + device=device, + ), + lengths=torch.tensor( + [1, 2, 2, 1, 1, 3, 2, 4], dtype=torch.int, device=device + ), + ), + ) + + kjt = keyed_jagged_tensor_pool.lookup( + ids=torch.tensor([2, 0], device=device), + ) + + # expected values + # KeyedJaggedTensor({ + # "f1": [[1], [3, 3]], + # "f2": [[11], [13, 13, 13]] + # }) + + torch.testing.assert_close( + kjt.values().cpu(), + torch.tensor( + [1, 3, 3, 11, 13, 13, 13], + dtype=values_dtype, + device=torch.device("cpu"), + ), + ) + + torch.testing.assert_close( + kjt.lengths().cpu(), + torch.tensor( + [1, 2, 1, 3], + dtype=torch.int, + device=torch.device("cpu"), + ), + ) + + kjt = keyed_jagged_tensor_pool.lookup( + ids=torch.tensor([1, 3, 0, 2], device=device), + ) + + # expected values + # KeyedJaggedTensor({ + # "f1": [[2, 2], [4], [3, 3], [1]], + # "f2": [[12, 12], [14, 14, 14, 14], [13, 13, 13], [11]] + # }) + + torch.testing.assert_close( + kjt.values().cpu(), + torch.tensor( + [2, 2, 4, 3, 3, 1, 12, 12, 14, 14, 14, 14, 13, 13, 13, 11], + dtype=values_dtype, + device=torch.device("cpu"), + ), + ) + + torch.testing.assert_close( + kjt.lengths().cpu(), + torch.tensor( + [2, 1, 2, 1, 2, 4, 3, 1], + dtype=torch.int, + device=torch.device("cpu"), + ), + ) + + def test_input_permute( + self, + ) -> None: + device = ( + torch.device("cpu") + if not torch.cuda.is_available() + else torch.device("cuda:0") + ) + + pool_size, feature_max_lengths = 4, {"f1": 2, "f2": 4} + values_dtype = torch.int32 + + keyed_jagged_tensor_pool = KeyedJaggedTensorPool( + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + values_dtype=values_dtype, + device=device, + ) + + # init global state is + # 4 8 + # f1 f2 f3 + # [3,3] . [13,13,13] [23] + # [2,2] . [12,12] [22, 22, 22] + # [1] . 
[11] [21, 21] + # [4] [14,14,14,14] [] + + keyed_jagged_tensor_pool.update( + ids=torch.tensor([2, 0, 1, 3], device=device), + values=KeyedJaggedTensor.from_lengths_sync( + keys=["f2", "f3", "f1"], + values=torch.tensor( + [ + 11, + 13, + 13, + 13, + 12, + 12, + 14, + 14, + 14, + 14, + 21, + 21, + 23, + 22, + 22, + 22, + 1, + 3, + 3, + 2, + 2, + 4, + ], + dtype=values_dtype, + device=device, + ), + lengths=torch.tensor( + [1, 3, 2, 4, 2, 1, 3, 0, 1, 2, 2, 1], dtype=torch.int, device=device + ), + ), + ) + + kjt = keyed_jagged_tensor_pool.lookup( + ids=torch.tensor([2, 0], device=device), + ) + + # expected values + # KeyedJaggedTensor({ + # "f1": [[1], [3, 3]], + # "f2": [[11], [13, 13, 13]] + # }) + + torch.testing.assert_close( + kjt.values().cpu(), + torch.tensor( + [1, 3, 3, 11, 13, 13, 13], + dtype=values_dtype, + device=torch.device("cpu"), + ), + ) + + torch.testing.assert_close( + kjt.lengths().cpu(), + torch.tensor( + [1, 2, 1, 3], + dtype=torch.int, + device=torch.device("cpu"), + ), + ) + + kjt = keyed_jagged_tensor_pool.lookup( + ids=torch.tensor([1, 3, 0, 2], device=device), + ) + + # expected values + # KeyedJaggedTensor({ + # "f1": [[2, 2], [4], [3, 3], [1]], + # "f2": [[12, 12], [14, 14, 14, 14], [13, 13, 13], [11]] + # }) + + torch.testing.assert_close( + kjt.values().cpu(), + torch.tensor( + [2, 2, 4, 3, 3, 1, 12, 12, 14, 14, 14, 14, 13, 13, 13, 11], + dtype=values_dtype, + device=torch.device("cpu"), + ), + ) + + torch.testing.assert_close( + kjt.lengths().cpu(), + torch.tensor( + [2, 1, 2, 1, 2, 4, 3, 1], + dtype=torch.int, + device=torch.device("cpu"), + ), + ) + + def test_conflict( + self, + ) -> None: + device = ( + torch.device("cpu") + if not torch.cuda.is_available() + else torch.device("cuda:0") + ) + + pool_size, feature_max_lengths = 4, {"f1": 2, "f2": 4} + values_dtype = torch.int32 + + keyed_jagged_tensor_pool = KeyedJaggedTensorPool( + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + values_dtype=values_dtype, + device=device, + ) + + # input is + # ids f1 f2 + # 2 [1] [11] + # 0 [3,3] . [13,13,13] + # 2 [2,2] [12,12] + # 3 [4] [14,14,14,14] + + keyed_jagged_tensor_pool.update( + ids=torch.tensor([2, 0, 2, 3], device=device), + values=KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.tensor( + [1, 3, 3, 2, 2, 4, 11, 13, 13, 13, 12, 12, 14, 14, 14, 14], + dtype=values_dtype, + device=device, + ), + lengths=torch.tensor( + [1, 2, 2, 1, 1, 3, 2, 4], dtype=torch.int, device=device + ), + ), + ) + + kjt = keyed_jagged_tensor_pool.lookup( + ids=torch.tensor([2, 0], device=device), + ) + + # expected values + # KeyedJaggedTensor({ + # "f1": [[2,2], [3, 3]], + # "f2": [[12, 12], [13, 13, 13]] + # }) + + torch.testing.assert_close( + kjt.values().cpu(), + torch.tensor( + [2, 2, 3, 3, 12, 12, 13, 13, 13], + dtype=values_dtype, + device=torch.device("cpu"), + ), + ) + + torch.testing.assert_close( + kjt.lengths().cpu(), + torch.tensor( + [2, 2, 2, 3], + dtype=torch.int, + device=torch.device("cpu"), + ), + ) + + def test_empty_lookup( + self, + ) -> None: + device = ( + torch.device("cpu") + if not torch.cuda.is_available() + else torch.device("cuda:0") + ) + + pool_size, feature_max_lengths = 4, {"f1": 2, "f2": 4} + values_dtype = torch.int32 + + keyed_jagged_tensor_pool = KeyedJaggedTensorPool( + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + values_dtype=values_dtype, + device=device, + ) + + # init global state is + # 4 8 + # f1 f2 + # [3,3] . [13,13,13] + # [2,2] . [12,12] + # [1] . 
[11] + # [4] [14,14,14,14] + + ids = torch.tensor([2, 0, 1, 3], device=device) + keyed_jagged_tensor_pool.update( + ids=ids, + values=KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.tensor( + [1, 3, 3, 2, 2, 4, 11, 13, 13, 13, 12, 12, 14, 14, 14, 14], + dtype=values_dtype, + device=device, + ), + lengths=torch.tensor( + [1, 2, 2, 1, 1, 3, 2, 4], dtype=torch.int, device=device + ), + ), + ) + + kjt = keyed_jagged_tensor_pool.lookup( + ids=torch.tensor([], dtype=ids.dtype, device=device), + ) + + # expected values + # KeyedJaggedTensor({ + # "f1": [], + # "f2": [], + # }) + + self.assertEqual(kjt.keys(), ["f1", "f2"]) + + torch.testing.assert_close( + kjt.values().cpu(), + torch.tensor([], dtype=values_dtype, device=torch.device("cpu")), + ) + + torch.testing.assert_close( + kjt.lengths().cpu(), + torch.tensor([], dtype=torch.int, device=torch.device("cpu")), + ) diff --git a/torchrec/modules/tests/test_kjt_pool_lookup.py b/torchrec/modules/tests/test_kjt_pool_lookup.py new file mode 100644 index 000000000..3b9617f3e --- /dev/null +++ b/torchrec/modules/tests/test_kjt_pool_lookup.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import unittest + +import torch +from torchrec.modules.object_pool_lookups import UVMCachingInt64Lookup +from torchrec.sparse.jagged_tensor import JaggedTensor + + +class KeyedJaggedTensorPoolLookupTest(unittest.TestCase): + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_uvm_caching_int64_lookup( + self, + ) -> None: + device = torch.device("cuda:0") + + pool_size, feature_max_lengths = 4, {"f1": 2, "f2": 4} + lookup = UVMCachingInt64Lookup( + pool_size=pool_size, + feature_max_lengths=feature_max_lengths, + is_weighted=False, + device=device, + ) + ids = torch.tensor([0, 1, 2, 3], device=device) + jt_values = torch.tensor( + [1, 3, 3, 2, 2, 4, 11, 13, 13, 13, 12, 12, 14, 14, 14, 14], + dtype=torch.int64, + device=device, + ) + jt_lengths = torch.tensor( + [1, 2, 2, 1, 1, 3, 2, 4], dtype=torch.int, device=device + ) + + lookup.update( + ids=ids, + values=JaggedTensor( + jt_values, + lengths=jt_lengths, + ), + ) + + torch.testing.assert_close(lookup.lookup(ids).values(), jt_values) + + INT64_VALUE_SHIFT = int(3e9) + lookup.update( + ids=ids, + values=JaggedTensor( + jt_values + INT64_VALUE_SHIFT, + lengths=jt_lengths, + ), + ) + + torch.testing.assert_close( + lookup.lookup(ids).values(), jt_values + INT64_VALUE_SHIFT + ) diff --git a/torchrec/modules/tests/test_lazy_extension.py b/torchrec/modules/tests/test_lazy_extension.py index eb251bd8f..0a3491225 100644 --- a/torchrec/modules/tests/test_lazy_extension.py +++ b/torchrec/modules/tests/test_lazy_extension.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
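The KeyedJaggedTensorPool tests above pin down two behaviors: lookup returns rows in the requested id order regardless of how they were inserted, and a duplicate id within one update overwrites the earlier rows for that id. A toy dict-based sketch of those semantics (pure Python, nothing like the pooled UVM/CUDA implementation):

    # Toy pool: id -> {feature: list of values}; later writes for an id win.
    pool = {}

    def update(ids, rows):
        for pool_id, row in zip(ids, rows):
            pool[pool_id] = row

    def lookup(ids, keys):
        # Gather per-key jagged rows in the requested id order.
        return {k: [pool[pool_id][k] for pool_id in ids] for k in keys}

    update(
        ids=[2, 0, 2, 3],
        rows=[
            {"f1": [1], "f2": [11]},
            {"f1": [3, 3], "f2": [13, 13, 13]},
            {"f1": [2, 2], "f2": [12, 12]},  # overwrites the first row for id 2
            {"f1": [4], "f2": [14, 14, 14, 14]},
        ],
    )
    print(lookup([2, 0], ["f1", "f2"]))
    # {'f1': [[2, 2], [3, 3]], 'f2': [[12, 12], [13, 13, 13]]}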
+# pyre-strict + import inspect import re import unittest @@ -66,13 +68,12 @@ def test_source_code_parity_on_infer_parameters(self) -> None: # reproduce the only changes: expected_lazy_ext_infer_parameters_src = original_infer_parameters_src.replace( - "def _infer_parameters(self: _LazyProtocol, module, input):", - "def _infer_parameters(self: _LazyExtensionProtocol, module, input, kwargs) -> None:", + "def _infer_parameters(self: _LazyProtocol, module, args, kwargs=None):", + "def _infer_parameters(self: _LazyExtensionProtocol, module, args, kwargs) -> None:", ).replace( - "module.initialize_parameters(*input)", - "module.initialize_parameters(*input, **kwargs)", + "module.initialize_parameters(*args)", + "module.initialize_parameters(*args, **kwargs)", ) - self.assertEqual( lazy_ext_infer_parameters_src, expected_lazy_ext_infer_parameters_src, @@ -150,8 +151,6 @@ def input_only_hook( return input[0] + 1 m = TestModule() - # pyre-fixme[6]: Expected `(...) -> None` for 1st param but got `(module: - # Module, input: Tuple[Tensor, ...]) -> Tensor`. m.register_forward_pre_hook(input_only_hook) output = m(torch.zeros(2, 2)) self.assertTrue(torch.allclose(output, torch.ones(2, 2))) @@ -291,8 +290,8 @@ def init_weights(m: torch.nn.Module) -> None: # and the function will be applied right after first forward pass. net = torch.nn.Sequential(TestModule(), TestModule()) net = lazy_apply(net, init_weights) - # pyre-ignore[29] + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. self.assertTrue(torch.allclose(net[0].param, torch.tensor(1.0))) net(torch.tensor(2.0)) - # pyre-ignore[29] + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. self.assertTrue(torch.allclose(net[0].param, torch.tensor(7.0))) diff --git a/torchrec/modules/tests/test_mc_embedding_modules.py b/torchrec/modules/tests/test_mc_embedding_modules.py new file mode 100644 index 000000000..58c4fa466 --- /dev/null +++ b/torchrec/modules/tests/test_mc_embedding_modules.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
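The lazy-extension changes above track the newer (module, args, kwargs) signature of _infer_parameters and drop a pyre-fixme around register_forward_pre_hook. The hook mechanics being tested are plain PyTorch: returning a value from a forward pre-hook replaces the module's inputs, as in this sketch:

    import torch

    class Identity(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x

    def input_only_hook(module: torch.nn.Module, inputs: tuple) -> torch.Tensor:
        # The returned tensor replaces the original positional inputs.
        return inputs[0] + 1

    m = Identity()
    m.register_forward_pre_hook(input_only_hook)
    print(m(torch.zeros(2, 2)))  # all ones: the hook ran before forward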
+ +# pyre-strict + +import unittest +from copy import deepcopy +from typing import cast, Dict + +import torch +from torchrec.fx import symbolic_trace +from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) +from torchrec.modules.mc_embedding_modules import ( + ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, +) +from torchrec.modules.mc_modules import ( + DistanceLFU_EvictionPolicy, + ManagedCollisionCollection, + ManagedCollisionModule, + MCHManagedCollisionModule, +) +from torchrec.sparse.jagged_tensor import ( + ComputeJTDictToKJT, + KeyedJaggedTensor, + KeyedTensor, +) + + +class MCHManagedCollisionEmbeddingBagCollectionTest(unittest.TestCase): + def test_zch_ebc_ec_train(self) -> None: + device = torch.device("cpu") + zch_size = 20 + update_interval = 2 + update_size = 10 + + embedding_bag_configs = [ + EmbeddingBagConfig( + name="t1", + embedding_dim=8, + num_embeddings=zch_size, + feature_names=["f1", "f2"], + ), + ] + embedding_configs = [ + EmbeddingConfig( + name="t1", + embedding_dim=8, + num_embeddings=zch_size, + feature_names=["f1", "f2"], + ), + ] + ebc = EmbeddingBagCollection( + tables=embedding_bag_configs, + device=device, + ) + + ec = EmbeddingCollection( + tables=embedding_configs, + device=device, + ) + + mc_modules = { + "t1": cast( + ManagedCollisionModule, + MCHManagedCollisionModule( + zch_size=zch_size, + device=device, + eviction_interval=update_interval, + eviction_policy=DistanceLFU_EvictionPolicy(), + ), + ), + } + mcc_ebc = ManagedCollisionCollection( + managed_collision_modules=mc_modules, + embedding_configs=embedding_bag_configs, + ) + + mcc_ec = ManagedCollisionCollection( + managed_collision_modules=deepcopy(mc_modules), + embedding_configs=embedding_configs, + ) + mc_ebc = ManagedCollisionEmbeddingBagCollection( + ebc, + mcc_ebc, + return_remapped_features=True, + ) + mc_ec = ManagedCollisionEmbeddingCollection( + ec, + mcc_ec, + return_remapped_features=True, + ) + + mc_modules = [mc_ebc, mc_ec] + + update_one = KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2", "f3"], + values=torch.concat( + [ + torch.arange(1000, 1000 + update_size, dtype=torch.int64), + torch.arange( + 1000 + update_size, + 1000 + 2 * update_size, + dtype=torch.int64, + ), + torch.arange( + 1000 + 2 * update_size, + 1000 + 3 * update_size, + dtype=torch.int64, + ), + ] + ), + lengths=torch.ones((3 * update_size,), dtype=torch.int64), + weights=None, + ) + + for mc_module in mc_modules: + + self.assertEqual( + mc_module._managed_collision_collection.open_slots()["t1"].item(), + zch_size - 1, + ) # (ZCH-1 slots) + + out1, remapped_kjt1 = mc_module.forward(update_one) + + self.assertEqual( + mc_module._managed_collision_collection.open_slots()["t1"].item(), + zch_size - 1, + ) # prior update, ZCH-1 slots + + out2, remapped_kjt2 = mc_module.forward(update_one) + + self.assertEqual( + mc_module._managed_collision_collection.open_slots()["t1"].item(), 0 + ) # post update, 0 slots + + assert remapped_kjt1 is not None + assert remapped_kjt1.keys() == ["f1", "f2"] + assert torch.all( + remapped_kjt1["f1"].values() == zch_size - 1 + ), "all remapped ids should be mapped to end of range" + assert torch.all( + remapped_kjt1["f2"].values() == zch_size - 1 + ), "all remapped ids should be mapped to end of range" + + assert remapped_kjt2 is not None + assert remapped_kjt2.keys() == ["f1", "f2"] + assert torch.all( + 
remapped_kjt2["f1"].values() == torch.arange(0, 10, dtype=torch.int64) + ) + assert torch.all( + remapped_kjt2["f2"].values() + == torch.cat( + [ + torch.arange(10, 19, dtype=torch.int64), + torch.tensor([zch_size - 1], dtype=torch.int64), # empty value + ] + ) + ) + + if isinstance(mc_module, ManagedCollisionEmbeddingCollection): + self.assertTrue(isinstance(out1, Dict)) + self.assertTrue(isinstance(out2, Dict)) + self.assertEqual(out1["f1"].values().size(), (update_size, 8)) + self.assertEqual(out2["f2"].values().size(), (update_size, 8)) + else: + self.assertTrue(isinstance(out1, KeyedTensor)) + self.assertTrue(isinstance(out2, KeyedTensor)) + self.assertEqual(out1["f1"].size(), (update_size, 8)) + self.assertEqual(out2["f2"].size(), (update_size, 8)) + + update_two = KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.concat( + [ + torch.arange(2000, 2000 + update_size, dtype=torch.int64), + torch.arange( + 1000 + update_size, + 1000 + 2 * update_size, + dtype=torch.int64, + ), + ] + ), + lengths=torch.ones((2 * update_size,), dtype=torch.int64), + weights=None, + ) + out3, remapped_kjt3 = mc_module.forward(update_two) + out4, remapped_kjt4 = mc_module.forward(update_two) + + assert remapped_kjt3 is not None + assert torch.all( + remapped_kjt3["f1"].values() == zch_size - 1 + ), "all remapped ids should be mapped to end of range" + + assert torch.all( + remapped_kjt3["f2"].values() == remapped_kjt2["f2"].values() + ) + + assert remapped_kjt4 is not None + assert torch.all( + remapped_kjt4["f1"].values() + == torch.cat( + [ + torch.arange(1, 10, dtype=torch.int64), + torch.tensor([zch_size - 1], dtype=torch.int64), # empty value + ] + ) + ) + assert torch.all( + remapped_kjt4["f2"].values() + == torch.cat( + [ + torch.arange(10, 19, dtype=torch.int64), + torch.tensor( + [0], dtype=torch.int64 + ), # assigned first open slot + ] + ) + ) + + if isinstance(mc_module, ManagedCollisionEmbeddingCollection): + self.assertTrue(isinstance(out3, Dict)) + self.assertTrue(isinstance(out4, Dict)) + self.assertEqual(out3["f1"].values().size(), (update_size, 8)) + self.assertEqual(out4["f2"].values().size(), (update_size, 8)) + else: + self.assertTrue(isinstance(out3, KeyedTensor)) + self.assertTrue(isinstance(out4, KeyedTensor)) + self.assertEqual(out3["f1"].size(), (update_size, 8)) + self.assertEqual(out4["f2"].size(), (update_size, 8)) + + def test_zch_ebc_ec_eval(self) -> None: + device = torch.device("cpu") + zch_size = 20 + update_interval = 2 + update_size = 10 + + embedding_bag_configs = [ + EmbeddingBagConfig( + name="t1", + embedding_dim=8, + num_embeddings=zch_size, + feature_names=["f1", "f2"], + ), + ] + embedding_configs = [ + EmbeddingConfig( + name="t1", + embedding_dim=8, + num_embeddings=zch_size, + feature_names=["f1", "f2"], + ), + ] + ebc = EmbeddingBagCollection( + tables=embedding_bag_configs, + device=device, + ) + ec = EmbeddingCollection( + tables=embedding_configs, + device=device, + ) + mc_modules = { + "t1": cast( + ManagedCollisionModule, + MCHManagedCollisionModule( + zch_size=zch_size, + device=device, + eviction_interval=update_interval, + eviction_policy=DistanceLFU_EvictionPolicy(), + ), + ), + } + mcc_ebc = ManagedCollisionCollection( + managed_collision_modules=mc_modules, + embedding_configs=embedding_bag_configs, + ) + + mcc_ec = ManagedCollisionCollection( + managed_collision_modules=deepcopy(mc_modules), + embedding_configs=embedding_configs, + ) + mc_ebc = ManagedCollisionEmbeddingBagCollection( + ebc, + mcc_ebc, + 
return_remapped_features=True, + ) + mc_ec = ManagedCollisionEmbeddingCollection( + ec, + mcc_ec, + return_remapped_features=True, + ) + + mc_modules = [mc_ebc, mc_ec] + + for mc_module in mc_modules: + update_one = KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.concat( + [ + torch.arange(1000, 1000 + update_size, dtype=torch.int64), + torch.arange( + 1000 + update_size, + 1000 + 2 * update_size, + dtype=torch.int64, + ), + ] + ), + lengths=torch.ones((2 * update_size,), dtype=torch.int64), + weights=None, + ) + _, remapped_kjt1 = mc_module.forward(update_one) + _, remapped_kjt2 = mc_module.forward(update_one) + + assert torch.all( + # pyre-ignore[16] + remapped_kjt1["f1"].values() + == zch_size - 1 + ), "all remapped ids should be mapped to end of range" + assert torch.all( + remapped_kjt1["f2"].values() == zch_size - 1 + ), "all remapped ids should be mapped to end of range" + + assert torch.all( + remapped_kjt2["f1"].values() == torch.arange(0, 10, dtype=torch.int64) + ) + assert torch.all( + remapped_kjt2["f2"].values() + == torch.cat( + [ + torch.arange(10, 19, dtype=torch.int64), + torch.tensor([zch_size - 1], dtype=torch.int64), # empty value + ] + ) + ) + + # Trigger eval mode, zch should not update + mc_module.eval() + + update_two = KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.concat( + [ + torch.arange(2000, 2000 + update_size, dtype=torch.int64), + torch.arange( + 1000 + update_size, + 1000 + 2 * update_size, + dtype=torch.int64, + ), + ] + ), + lengths=torch.ones((2 * update_size,), dtype=torch.int64), + weights=None, + ) + _, remapped_kjt3 = mc_module.forward(update_two) + _, remapped_kjt4 = mc_module.forward(update_two) + + assert torch.all( + remapped_kjt3["f1"].values() == zch_size - 1 + ), "all remapped ids should be mapped to end of range" + + assert torch.all( + remapped_kjt3["f2"].values() == remapped_kjt2["f2"].values() + ) + + assert torch.all( + remapped_kjt4["f1"].values() == zch_size - 1 + ), "all remapped ids should be mapped to end of range" + + assert torch.all( + remapped_kjt4["f2"].values() == remapped_kjt2["f2"].values() + ) + + def test_mc_collection_traceable(self) -> None: + device = torch.device("cpu") + zch_size = 20 + update_interval = 2 + + embedding_configs = [ + EmbeddingBagConfig( + name="t1", + embedding_dim=8, + num_embeddings=zch_size, + feature_names=["f1", "f2"], + ), + ] + mc_modules = { + "t1": cast( + ManagedCollisionModule, + MCHManagedCollisionModule( + zch_size=zch_size, + device=device, + input_hash_size=2 * zch_size, + eviction_interval=update_interval, + eviction_policy=DistanceLFU_EvictionPolicy(), + ), + ), + } + mcc = ManagedCollisionCollection( + managed_collision_modules=mc_modules, + embedding_configs=embedding_configs, + ) + mcc.train(False) + symbolic_trace(mcc, leaf_modules=[ComputeJTDictToKJT.__name__]) diff --git a/torchrec/modules/tests/test_mc_modules.py b/torchrec/modules/tests/test_mc_modules.py new file mode 100644 index 000000000..8fac2ac25 --- /dev/null +++ b/torchrec/modules/tests/test_mc_modules.py @@ -0,0 +1,429 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
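The managed-collision tests above all follow the same rhythm: on the first pass every unseen raw id remaps to the reserved last ZCH slot, after the eviction interval the ids are admitted into slots of their own, and eval mode freezes the mapping. A heavily simplified toy of that remap pattern, for illustration only (it is not the MCHManagedCollisionModule eviction algorithm):

    import torch

    class ToyZch:
        def __init__(self, zch_size: int) -> None:
            self.zch_size = zch_size
            self.slots: dict = {}  # raw id -> assigned slot
            self.training = True

        def remap(self, ids: torch.Tensor) -> torch.Tensor:
            # Unknown ids fall back to the reserved last slot.
            out = [self.slots.get(i, self.zch_size - 1) for i in ids.tolist()]
            if self.training:
                # Admit new ids after the fact, while open slots remain.
                for i in ids.tolist():
                    if i not in self.slots and len(self.slots) < self.zch_size - 1:
                        self.slots[i] = len(self.slots)
            return torch.tensor(out)

    zch = ToyZch(zch_size=20)
    ids = torch.arange(1000, 1010)
    print(zch.remap(ids))  # all 19: nothing has been admitted yet
    print(zch.remap(ids))  # tensor([0, 1, ..., 9]): the ids now own slots
    zch.training = False   # eval mode: the table stops changing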
+ +# pyre-strict + +import unittest +from typing import Dict + +import torch +from torchrec.modules.mc_modules import ( + average_threshold_filter, + DistanceLFU_EvictionPolicy, + dynamic_threshold_filter, + LFU_EvictionPolicy, + LRU_EvictionPolicy, + MCHManagedCollisionModule, + probabilistic_threshold_filter, +) +from torchrec.sparse.jagged_tensor import JaggedTensor + + +class TestEvictionPolicy(unittest.TestCase): + def test_lfu_eviction(self) -> None: + mc_module = MCHManagedCollisionModule( + zch_size=5, + device=torch.device("cpu"), + eviction_policy=LFU_EvictionPolicy(), + eviction_interval=1, + input_hash_size=100, + ) + + # check initial state + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5) + _mch_counts = mc_module._mch_counts + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_counts), [0] * 5) + + # insert some values to zch + # we have 10 counts of 4 and 1 count of 5 + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... + mc_module._mch_sorted_raw_ids[0:2] = torch.tensor([4, 5]) + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... + mc_module._mch_counts[0:2] = torch.tensor([10, 1]) + + ids = [3, 4, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 10] + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + + # 5, empty, empty, empty will be evicted + # 6, 7, 8 will be added + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual( + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + list(_mch_sorted_raw_ids), + [4, 6, 7, 8, torch.iinfo(torch.int64).max], + ) + # 11 counts of 5, 3 counts of 6, 3 counts of 7, 3 counts of 8 + _mch_counts = mc_module._mch_counts + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_counts), [11, 3, 3, 3, torch.iinfo(torch.int64).max]) + + def test_lru_eviction(self) -> None: + mc_module = MCHManagedCollisionModule( + zch_size=5, + device=torch.device("cpu"), + eviction_policy=LRU_EvictionPolicy(decay_exponent=1.0), + eviction_interval=1, + input_hash_size=100, + ) + + # check initial state + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5) + _mch_last_access_iter = mc_module._mch_last_access_iter + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. 
+ self.assertEqual(list(_mch_last_access_iter), [0] * 5) + + ids = [5, 6, 7] + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + self.assertEqual(mc_module.open_slots().item(), 1) + ids = [3, 4, 5] + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + self.assertEqual(mc_module.open_slots().item(), 0) + ids = [7, 8] + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + self.assertEqual(mc_module.open_slots().item(), 0) + + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual( + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + list(_mch_sorted_raw_ids), + [3, 4, 7, 8, torch.iinfo(torch.int64).max], + ) + _mch_last_access_iter = mc_module._mch_last_access_iter + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_last_access_iter), [2, 2, 3, 3, 3]) + self.assertEqual(mc_module.open_slots().item(), 0) + + def test_distance_lfu_eviction(self) -> None: + mc_module = MCHManagedCollisionModule( + zch_size=5, + device=torch.device("cpu"), + eviction_policy=DistanceLFU_EvictionPolicy(decay_exponent=1.0), + eviction_interval=1, + input_hash_size=100, + ) + + # check initial state + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5) + _mch_counts = mc_module._mch_counts + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_counts), [0] * 5) + _mch_last_access_iter = mc_module._mch_last_access_iter + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_last_access_iter), [0] * 5) + + ids = [5, 5, 5, 5, 5, 6] + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + + ids = [3, 4] + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + + ids = [7, 8] + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual( + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + list(_mch_sorted_raw_ids), + [3, 5, 7, 8, torch.iinfo(torch.int64).max], + ) + _mch_counts = mc_module._mch_counts + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. 
+ self.assertEqual(list(_mch_counts), [1, 5, 1, 1, torch.iinfo(torch.int64).max]) + _mch_last_access_iter = mc_module._mch_last_access_iter + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_last_access_iter), [2, 1, 3, 3, 3]) + + def test_distance_lfu_eviction_fast_decay(self) -> None: + mc_module = MCHManagedCollisionModule( + zch_size=5, + device=torch.device("cpu"), + eviction_policy=DistanceLFU_EvictionPolicy(decay_exponent=10.0), + eviction_interval=1, + input_hash_size=100, + ) + + # check initial state + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5) + _mch_counts = mc_module._mch_counts + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_counts), [0] * 5) + _mch_last_access_iter = mc_module._mch_last_access_iter + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_last_access_iter), [0] * 5) + + ids = [5, 5, 5, 5, 5, 6] + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + + ids = [3, 4] + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + + ids = [7, 8] + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual( + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + list(_mch_sorted_raw_ids), + [3, 4, 7, 8, torch.iinfo(torch.int64).max], + ) + _mch_counts = mc_module._mch_counts + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_counts), [1, 1, 1, 1, torch.iinfo(torch.int64).max]) + _mch_last_access_iter = mc_module._mch_last_access_iter + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_last_access_iter), [2, 2, 3, 3, 3]) + + def test_dynamic_threshold_filter(self) -> None: + mc_module = MCHManagedCollisionModule( + zch_size=5, + device=torch.device("cpu"), + eviction_policy=LFU_EvictionPolicy( + threshold_filtering_func=lambda tensor: dynamic_threshold_filter( + tensor, threshold_skew_multiplier=0.75 + ) + ), + eviction_interval=1, + input_hash_size=100, + ) + + # check initial state + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5) + _mch_counts = mc_module._mch_counts + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. 
+ self.assertEqual(list(_mch_counts), [0] * 5) + + ids = [5, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3, 2, 2, 1] + # threshold is len(ids) / unique_count(ids) * threshold_skew_multiplier + # = 15 / 5 * 0.5 = 2.25 + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual( + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + list(_mch_sorted_raw_ids), + [3, 4, 5, torch.iinfo(torch.int64).max, torch.iinfo(torch.int64).max], + ) + _mch_counts = mc_module._mch_counts + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_counts), [3, 4, 5, 0, torch.iinfo(torch.int64).max]) + + def test_average_threshold_filter(self) -> None: + mc_module = MCHManagedCollisionModule( + zch_size=5, + device=torch.device("cpu"), + eviction_policy=LFU_EvictionPolicy( + threshold_filtering_func=average_threshold_filter + ), + eviction_interval=1, + input_hash_size=100, + ) + + # check initial state + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5) + _mch_counts = mc_module._mch_counts + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_counts), [0] * 5) + + # insert some values to zch + # we have 10 counts of 4 and 1 count of 5 + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... + mc_module._mch_sorted_raw_ids[0:2] = torch.tensor([4, 5]) + # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque... + mc_module._mch_counts[0:2] = torch.tensor([10, 1]) + + ids = [3, 4, 5, 6, 6, 6, 7, 8, 8, 9, 10] + # threshold is 1.375 + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + mc_module.profile(features) + + # empty, empty will be evicted + # 6, 8 will be added + # 7 is not added because it's below the average threshold + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual( + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + list(_mch_sorted_raw_ids), + [4, 5, 6, 8, torch.iinfo(torch.int64).max], + ) + # count for 4 is not updated since it's below the average threshold + _mch_counts = mc_module._mch_counts + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_counts), [10, 1, 3, 2, torch.iinfo(torch.int64).max]) + + def test_probabilistic_threshold_filter(self) -> None: + mc_module = MCHManagedCollisionModule( + zch_size=5, + device=torch.device("cpu"), + eviction_policy=LFU_EvictionPolicy( + threshold_filtering_func=lambda tensor: probabilistic_threshold_filter( + tensor, + per_id_probability=0.01, + ) + ), + eviction_interval=1, + input_hash_size=100, + ) + + # check initial state + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. 
+ self.assertEqual(list(_mch_sorted_raw_ids), [torch.iinfo(torch.int64).max] * 5) + _mch_counts = mc_module._mch_counts + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Tensor, Module]`. + self.assertEqual(list(_mch_counts), [0] * 5) + + unique_ids = [5, 4, 3, 2, 1] + id_counts = [100, 80, 60, 40, 10] + ids = [id for id, count in zip(unique_ids, id_counts) for _ in range(count)] + # chance of being added is [0.63, 0.55, 0.45, 0.33] + features: Dict[str, JaggedTensor] = { + "f1": JaggedTensor( + values=torch.tensor(ids, dtype=torch.int64), + lengths=torch.tensor([1] * len(ids), dtype=torch.int64), + ) + } + + torch.manual_seed(42) + for _ in range(10): + mc_module.profile(features) + + _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids + self.assertEqual( + # pyre-fixme[29]: `Union[(self: TensorBase) -> list[Any], Tensor, + # Module]` is not a function. + sorted(_mch_sorted_raw_ids.tolist()), + [2, 3, 4, 5, torch.iinfo(torch.int64).max], + ) + # _mch_counts is like + # [80, 180, 160, 800, 9223372036854775807] + + def test_fx_jit_script_not_training(self) -> None: + model = MCHManagedCollisionModule( + zch_size=5, + device=torch.device("cpu"), + eviction_policy=LFU_EvictionPolicy(), + eviction_interval=1, + input_hash_size=100, + ) + + model.train(False) + gm = torch.fx.symbolic_trace(model) + torch.jit.script(gm) diff --git a/torchrec/modules/tests/test_mlp.py b/torchrec/modules/tests/test_mlp.py index ea2d78ff6..069d071b9 100644 --- a/torchrec/modules/tests/test_mlp.py +++ b/torchrec/modules/tests/test_mlp.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest from typing import Callable, List, Union diff --git a/torchrec/modules/tests/test_regroup.py b/torchrec/modules/tests/test_regroup.py new file mode 100644 index 000000000..14a79605e --- /dev/null +++ b/torchrec/modules/tests/test_regroup.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
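+
+# Editorial sketch (not part of the original change): a minimal, self-contained
+# example of the KTRegroupAsDict module exercised by the tests below. The tensor
+# shapes and group names here are illustrative only; the module concatenates the
+# requested feature columns of each input KeyedTensor into one dense tensor per
+# output key.
+def _kt_regroup_usage_sketch() -> None:
+    # local imports keep this sketch self-contained
+    import torch
+    from torchrec.modules.regroup import KTRegroupAsDict
+    from torchrec.sparse.jagged_tensor import KeyedTensor
+
+    batch_size = 3
+    kt0 = KeyedTensor(
+        keys=["dense_0", "dense_1"],
+        length_per_key=[4, 4],
+        values=torch.randn(batch_size, 8),
+    )
+    kt1 = KeyedTensor(
+        keys=["sparse_0", "sparse_1"],
+        length_per_key=[4, 4],
+        values=torch.randn(batch_size, 8),
+    )
+    regroup = KTRegroupAsDict(
+        groups=[["dense_0", "sparse_0"], ["dense_1", "sparse_1"]],
+        keys=["user", "object"],
+    )
+    out = regroup([kt0, kt1])
+    # each group is the concatenation of its member features along dim 1
+    assert out["user"].shape == (batch_size, 8)
+    assert out["object"].shape == (batch_size, 8)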
+ +# pyre-strict + +import unittest + +import torch +import torch.fx + +from torchrec.modules.regroup import KTRegroupAsDict +from torchrec.sparse.jagged_tensor import _all_keys_used_once, KeyedTensor +from torchrec.sparse.tests.utils import build_groups, build_kts + + +class KTRegroupAsDictTest(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + self.kts = build_kts( + dense_features=20, + sparse_features=20, + dim_dense=64, + dim_sparse=128, + batch_size=128, + device=torch.device("cpu"), + run_backward=True, + ) + self.num_groups = 2 + self.keys = ["user", "object"] + self.labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float() + + def new_kts(self) -> None: + self.kts = build_kts( + dense_features=20, + sparse_features=20, + dim_dense=64, + dim_sparse=128, + batch_size=128, + device=torch.device("cpu"), + run_backward=True, + ) + + def test_regroup_backward_skips_and_duplicates(self) -> None: + groups = build_groups( + kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True + ) + assert _all_keys_used_once(self.kts, groups) is False + + regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys) + + # first run + tensor_groups = regroup_module(self.kts) + pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred0, self.labels).sum() + actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad( + loss, [self.kts[0].values(), self.kts[1].values()] + ) + + # clear grads so can reuse inputs + self.kts[0].values().grad = None + self.kts[1].values().grad = None + + tensor_groups = KeyedTensor.regroup_as_dict( + keyed_tensors=self.kts, groups=groups, keys=self.keys + ) + pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred1, self.labels).sum() + expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad( + loss, [self.kts[0].values(), self.kts[1].values()] + ) + + torch.allclose(pred0, pred1) + torch.allclose(actual_kt_0_grad, expected_kt_0_grad) + torch.allclose(actual_kt_1_grad, expected_kt_1_grad) + + # second run + self.new_kts() + tensor_groups = regroup_module(self.kts) + pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred0, self.labels).sum() + actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad( + loss, [self.kts[0].values(), self.kts[1].values()] + ) + + # clear grads so can reuse inputs + self.kts[0].values().grad = None + self.kts[1].values().grad = None + + tensor_groups = KeyedTensor.regroup_as_dict( + keyed_tensors=self.kts, groups=groups, keys=self.keys + ) + pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred1, self.labels).sum() + expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad( + loss, [self.kts[0].values(), self.kts[1].values()] + ) + + torch.allclose(pred0, pred1) + torch.allclose(actual_kt_0_grad, expected_kt_0_grad) + torch.allclose(actual_kt_1_grad, expected_kt_1_grad) + + def test_regroup_backward(self) -> None: + groups = build_groups( + kts=self.kts, num_groups=self.num_groups, skips=False, duplicates=False + ) + assert _all_keys_used_once(self.kts, groups) is True + + regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys) + tensor_groups = regroup_module(self.kts) + pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred0, self.labels).sum() + 
actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad( + loss, [self.kts[0].values(), self.kts[1].values()] + ) + + # clear grads so can reuse inputs + self.kts[0].values().grad = None + self.kts[1].values().grad = None + + tensor_groups = KeyedTensor.regroup_as_dict( + keyed_tensors=self.kts, groups=groups, keys=self.keys + ) + pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred1, self.labels).sum() + expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad( + loss, [self.kts[0].values(), self.kts[1].values()] + ) + + torch.allclose(pred0, pred1) + torch.allclose(actual_kt_0_grad, expected_kt_0_grad) + torch.allclose(actual_kt_1_grad, expected_kt_1_grad) + + def test_fx_and_jit_regroup(self) -> None: + groups = build_groups( + kts=self.kts, num_groups=self.num_groups, skips=False, duplicates=False + ) + assert _all_keys_used_once(self.kts, groups) is True + + regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys) + # first pass + regroup_module(self.kts) + + # now trace + gm = torch.fx.symbolic_trace(regroup_module) + jit_gm = torch.jit.script(gm) + + out = jit_gm(self.kts) + eager_out = regroup_module(self.kts) + for key in out.keys(): + torch.allclose(out[key], eager_out[key]) + + def test_fx_and_jit_regroup_skips_and_duplicates(self) -> None: + groups = build_groups( + kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True + ) + assert _all_keys_used_once(self.kts, groups) is False + + regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys) + # first pass + regroup_module(self.kts) + + # now trace + gm = torch.fx.symbolic_trace(regroup_module) + jit_gm = torch.jit.script(gm) + + out = jit_gm(self.kts) + eager_out = regroup_module(self.kts) + for key in out.keys(): + torch.allclose(out[key], eager_out[key]) diff --git a/torchrec/modules/tests/test_tensor_pool.py b/torchrec/modules/tests/test_tensor_pool.py new file mode 100644 index 000000000..15c1f7bba --- /dev/null +++ b/torchrec/modules/tests/test_tensor_pool.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
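+
+# Editorial sketch (not part of the original change): the smallest TensorPool
+# round trip, using only the constructor/update/lookup API exercised by the tests
+# below. The pool is a dense [pool_size, dim] buffer: update() writes rows at the
+# given ids and lookup() reads them back; rows never written stay zero.
+def _tensor_pool_usage_sketch() -> None:
+    # local imports keep this sketch self-contained
+    import torch
+    from torchrec.modules.tensor_pool import TensorPool
+
+    pool = TensorPool(
+        pool_size=4, dim=2, dtype=torch.float, device=torch.device("cpu")
+    )
+    pool.update(
+        ids=torch.tensor([0, 2], dtype=torch.int),
+        values=torch.tensor([[1.0, 1.0], [2.0, 2.0]]),
+    )
+    out = pool.lookup(ids=torch.tensor([0, 1, 2], dtype=torch.int))
+    torch.testing.assert_close(out[0], torch.tensor([1.0, 1.0]))
+    # id 1 was never written, so its row is still zero
+    torch.testing.assert_close(out[1], torch.tensor([0.0, 0.0]))
+    torch.testing.assert_close(out[2], torch.tensor([2.0, 2.0]))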
+ +# pyre-strict +import unittest + +import torch + +from torchrec.modules.tensor_pool import TensorPool + + +class TensorPoolTest(unittest.TestCase): + def test_update_lookup( + self, + ) -> None: + device = ( + torch.device("cpu") + if not torch.cuda.is_available() + else torch.device("cuda:0") + ) + pool_size = 10 + dim = 3 + batch_size = 5 + dense_pool = TensorPool( + pool_size=pool_size, + dim=dim, + dtype=torch.float, + device=device, + ) + update_ids = [1, 9, 4, 2, 6] + ids_to_row = {1: 0, 9: 1, 4: 2, 2: 3, 6: 4} + ids = torch.tensor(update_ids, dtype=torch.int, device=device) + reference_values = torch.rand( + (batch_size, dim), dtype=torch.float, device=device + ) + dense_pool.update(ids=ids, values=reference_values) + + lookup_ids = torch.randint( + low=0, high=pool_size, size=(batch_size,), dtype=torch.int, device=device + ) + lookup_values = dense_pool.lookup(ids=lookup_ids) + for i in range(batch_size): + if lookup_ids[i] in update_ids: + # pyre-ignore + lookup_id: int = lookup_ids[i].int().item() + torch.testing.assert_close( + reference_values[ids_to_row[lookup_id]], + lookup_values[i], + ) + else: + torch.testing.assert_close( + lookup_values[i], + torch.zeros(3, dtype=torch.float, device=device), + msg=f"{dense_pool._pool[lookup_ids[i]]}", + ) + + torch.testing.assert_close(dense_pool.pool_size, pool_size) + + def test_conflict( + self, + ) -> None: + device = ( + torch.device("cpu") + if not torch.cuda.is_available() + else torch.device("cuda:0") + ) + pool_size = 10 + dim = 3 + batch_size = 5 + dense_pool = TensorPool( + pool_size=pool_size, + dim=dim, + dtype=torch.float, + device=device, + ) + update_ids = [1, 9, 4, 1, 6] + ids_to_row = {9: 1, 4: 2, 1: 3, 6: 4} # The first 1 is deduped and removed + ids = torch.tensor(update_ids, dtype=torch.int, device=device) + reference_values = torch.rand( + (batch_size, dim), dtype=torch.float, device=device + ) + dense_pool.update(ids=ids, values=reference_values) + + lookup_ids = torch.randint( + low=0, high=pool_size, size=(batch_size,), dtype=torch.int, device=device + ) + lookup_values = dense_pool.lookup(ids=lookup_ids) + for i in range(batch_size): + if lookup_ids[i] in update_ids: + # pyre-ignore + lookup_id: int = lookup_ids[i].int().item() + torch.testing.assert_close( + reference_values[ids_to_row[lookup_id]], + lookup_values[i], + ) + else: + torch.testing.assert_close( + lookup_values[i], + torch.zeros(3, dtype=torch.float, device=device), + msg=f"{dense_pool._pool[lookup_ids[i]]}", + ) + + torch.testing.assert_close(dense_pool.pool_size, pool_size) diff --git a/torchrec/modules/utils.py b/torchrec/modules/utils.py index df56ee63a..2d6f4b4a5 100644 --- a/torchrec/modules/utils.py +++ b/torchrec/modules/utils.py @@ -5,10 +5,50 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
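+# Editorial note (not part of the original change): among the helpers added below,
+# `deterministic_dedup` resolves duplicate ids so that only the last occurrence of
+# each id wins (the same resolution the TensorPool conflict test expects). Worked
+# example matching the semantics in its docstring:
+#
+#   ids = torch.tensor([1, 9, 4, 1, 6])
+#   deterministic_dedup(ids)
+#   # -> (tensor([1, 4, 6, 9]),   sorted unique ids
+#   #     tensor([3, 2, 4, 1]))   index of the last occurrence of each id
+#
+# so an update keyed on id 1 takes its values from row 3 of the incoming batch.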
+# pyre-strict + import copy -from typing import Callable, Iterable, Tuple, Union +from collections import defaultdict +from dataclasses import dataclass +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import torch +from torch import Tensor +from torch.profiler import record_function +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from torchrec.streamable import Multistreamable +from torchrec.types import CacheMixin + +torch.fx.wrap("len") + + +@dataclass +class SequenceVBEContext(Multistreamable): + recat: torch.Tensor + unpadded_lengths: torch.Tensor + reindexed_lengths: torch.Tensor + reindexed_length_per_key: List[int] + reindexed_values: Optional[torch.Tensor] = None + + def record_stream(self, stream: torch.Stream) -> None: + self.recat.record_stream(stream) + self.unpadded_lengths.record_stream(stream) + self.reindexed_lengths.record_stream(stream) + if self.reindexed_values is not None: + self.reindexed_values.record_stream(stream) + + +@torch.fx.wrap +def _fx_to_list(tensor: torch.Tensor) -> List[int]: + return tensor.long().tolist() + + +@torch.fx.wrap +def _slice_1d_tensor(tensor: torch.Tensor, start: int, end: int) -> torch.Tensor: + """ + Slice tensor. + """ + return tensor[start:end] def extract_module_or_tensor_callable( @@ -104,14 +144,285 @@ def convert_list_of_modules_to_modulelist( # `Iterable[torch.nn.Module]`. len(modules) == sizes[0] + # pyre-fixme[6]: For 1st argument expected `pyre_extensions.PyreReadOnly[Sized]` + # but got `Iterable[Module]`. ), f"the counts of modules ({len(modules)}) do not match with the required counts {sizes}" if len(sizes) == 1: return torch.nn.ModuleList(modules) else: # recursively create nested list return torch.nn.ModuleList( - # pyre-fixme[6]: Expected `Iterable[torch.nn.Module]` for 1st param but - # got `Module`. + # pyre-fixme[6]: For 1st argument expected `Iterable[Module]` but got + # `Module`. convert_list_of_modules_to_modulelist(m, sizes[1:]) for m in modules ) + + +def _permute_tensor_by_segments( + tensor: torch.Tensor, + segment_sizes: torch.Tensor, + recat: torch.Tensor, + weights: Optional[torch.Tensor] = None, + output_size: Optional[int] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + TODO: remove and import from `jagged_tensor.py` once packaging issue is resolved + + Permutes a tensor by segments according to recat tensor. + + For variable stride tensors we permute across length per key, which reduces the + number of permute indices and lengthens each sequence. + `keyed_jagged_index_select_dim1` more efficiently parallelizes work for each permute + index and sequence across multiple thread blocks. + + NOTE: + `keyed_jagged_index_select_dim1` is only supported for CUDA. 
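+        On CPU tensors this function falls back to `permute_1D_sparse_data` instead.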
+ """ + if tensor.device.type == "cuda": + output = torch.ops.fbgemm.keyed_jagged_index_select_dim1( + values=tensor, + lengths=segment_sizes, + offsets=torch.ops.fbgemm.asynchronous_complete_cumsum(segment_sizes), + indices=recat, + batch_size=segment_sizes.numel(), + weights=weights, + selected_lengths_sum=output_size, + ) + permuted_tensor = output[0] + permuted_weights = None if weights is None else output[2] + else: + ( + _, + permuted_tensor, + permuted_weights, + ) = torch.ops.fbgemm.permute_1D_sparse_data( + recat, + segment_sizes, + tensor, + weights, + output_size, + ) + return permuted_tensor, permuted_weights + + +def _vbe_reindex( + embeddings: torch.Tensor, + seq_vbe_ctx: SequenceVBEContext, +) -> Tuple[torch.Tensor, torch.Tensor, List[int], Optional[torch.Tensor]]: + """ + Reindexes embeddings for variable batch size per feature scenarios. + + Returns: + Tuple[torch.Tensor, torch.Tensor, List[int], torch.Tensor]: the reindexed + embeddings, lengths, length_per_key, and values + """ + dim = embeddings.shape[1] + output_size = sum(seq_vbe_ctx.reindexed_length_per_key) * dim + reindexed_embeddings, _ = _permute_tensor_by_segments( + tensor=embeddings.flatten(), + segment_sizes=seq_vbe_ctx.unpadded_lengths * dim, + recat=seq_vbe_ctx.recat, + weights=None, + output_size=output_size, + ) + reindexed_embeddings = reindexed_embeddings.view(-1, dim) + assert len(seq_vbe_ctx.reindexed_lengths.shape) == 2 + return ( + reindexed_embeddings, + seq_vbe_ctx.reindexed_lengths, + seq_vbe_ctx.reindexed_length_per_key, + seq_vbe_ctx.reindexed_values, + ) + + +def construct_jagged_tensors( + embeddings: torch.Tensor, + features: KeyedJaggedTensor, + embedding_names: List[str], + need_indices: bool = False, + features_to_permute_indices: Optional[Dict[str, List[int]]] = None, + original_features: Optional[KeyedJaggedTensor] = None, + reverse_indices: Optional[torch.Tensor] = None, + seq_vbe_ctx: Optional[SequenceVBEContext] = None, +) -> Dict[str, JaggedTensor]: + with record_function("## construct_jagged_tensors ##"): + if original_features is not None: + features = original_features + if reverse_indices is not None: + embeddings = torch.index_select( + embeddings, 0, reverse_indices.to(torch.int32) + ) + ret: Dict[str, JaggedTensor] = {} + + if seq_vbe_ctx is not None: + embeddings, lengths, length_per_key, values = _vbe_reindex( + embeddings=embeddings, seq_vbe_ctx=seq_vbe_ctx + ) + else: + lengths = features.lengths().view(-1, features.stride()) + length_per_key = features.length_per_key() + values = features.values() + + lengths_tuple = torch.unbind(lengths, dim=0) + embeddings_list = torch.split(embeddings, length_per_key, dim=0) + values_list = ( + torch.split(values, length_per_key) + if need_indices and values is not None + else None + ) + + key_indices = defaultdict(list) + for i, key in enumerate(embedding_names): + key_indices[key].append(i) + for key, indices in key_indices.items(): + # combines outputs in correct order for CW sharding + indices = ( + _permute_indices(indices, features_to_permute_indices[key]) + if features_to_permute_indices and key in features_to_permute_indices + else indices + ) + ret[key] = JaggedTensor( + lengths=lengths_tuple[indices[0]], + values=( + embeddings_list[indices[0]] + if len(indices) == 1 + else torch.cat([embeddings_list[i] for i in indices], dim=1) + ), + weights=( + values_list[indices[0]] + if need_indices and values_list is not None + else None + ), + ) + return ret + + +def construct_jagged_tensors_inference( + embeddings: torch.Tensor, + 
lengths: torch.Tensor, + values: torch.Tensor, + embedding_names: List[str], + need_indices: bool = False, + features_to_permute_indices: Optional[Dict[str, List[int]]] = None, + reverse_indices: Optional[torch.Tensor] = None, + remove_padding: bool = False, +) -> Dict[str, JaggedTensor]: + with record_function("## construct_jagged_tensors_inference ##"): + if reverse_indices is not None: + embeddings = torch.index_select( + embeddings, 0, reverse_indices.to(torch.int32) + ) + elif remove_padding: + embeddings = _slice_1d_tensor(embeddings, 0, lengths.sum().item()) + + ret: Dict[str, JaggedTensor] = {} + + length_per_key: List[int] = _fx_to_list(torch.sum(lengths, dim=1)) + + lengths_tuple = torch.unbind(lengths, dim=0) + + embeddings_list = torch.split(embeddings, length_per_key, dim=0) + values_list = torch.split(values, length_per_key) if need_indices else None + + key_indices = defaultdict(list) + for i, key in enumerate(embedding_names): + key_indices[key].append(i) + for key, indices in key_indices.items(): + # combines outputs in correct order for CW sharding + indices = ( + _permute_indices(indices, features_to_permute_indices[key]) + if features_to_permute_indices and key in features_to_permute_indices + else indices + ) + ret[key] = JaggedTensor( + lengths=lengths_tuple[indices[0]], + values=( + embeddings_list[indices[0]] + if len(indices) == 1 + else torch.cat([embeddings_list[i] for i in indices], dim=1) + ), + # pyre-ignore + weights=values_list[indices[0]] if need_indices else None, + ) + return ret + + +def _permute_indices(indices: List[int], permute: List[int]) -> List[int]: + permuted_indices = [0] * len(indices) + for i, permuted_index in enumerate(permute): + permuted_indices[i] = indices[permuted_index] + return permuted_indices + + +@torch.fx.wrap +def jagged_index_select_with_empty( + values: torch.Tensor, + ids: torch.Tensor, + offsets: torch.Tensor, + output_offsets: torch.Tensor, +) -> torch.Tensor: + if ids.size()[0] == 0: + return torch.empty(0, device=values.device, dtype=values.dtype) + output_values = torch.ops.fbgemm.jagged_index_select_2d_forward_v2( + values.flatten().unsqueeze(-1), + ids, + offsets.long(), + output_offsets.long(), + ) + return output_values + + +def deterministic_dedup(ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + To remove race condition in conflict update, remove duplicated IDs. Only the last existence of duplicated ID will be kept. + Return sorted unique ids and the position of the last existence + """ + sorted_id_values, sorted_id_indices = ids.sort() + sorted_unique_ids, sorted_unique_inverses = sorted_id_values.unique_consecutive( + return_counts=False, + return_inverse=True, + ) + last_existence_index = torch.scatter_reduce( + input=torch.zeros_like(sorted_unique_ids, dtype=torch.int64), + dim=0, + index=sorted_unique_inverses, + src=sorted_id_indices, + reduce="amax", + ) + + return sorted_unique_ids.view(-1), last_existence_index.flatten() + + +def reset_module_states_post_sharding( + module: torch.nn.Module, +) -> None: + """ + Reset the module states post sharding. + Involves clearing cached tensors if they exist + from unsharded version. + """ + + # Clear Cache for TorchRec modules that have cache. Normally would happen in sharding + # but cached modules might not be part of the TorchRec modules being sharded. 
+ # For example, necessary for KTRegroupAsDict correctness, + for submod in module.modules(): + if isinstance(submod, CacheMixin): + submod.clear_cache() + + +@torch.fx.wrap +def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor: + # this is a fx rule to help with batching hinting jagged sequence tensor coalescing. + return output + + +@torch.fx.wrap +def _fx_trec_get_feature_length( + features: KeyedJaggedTensor, embedding_names: List[str] +) -> torch.Tensor: + torch._assert( + len(embedding_names) == len(features.keys()), + "embedding output and features mismatch", + ) + return features.lengths() diff --git a/torchrec/optim/__init__.py b/torchrec/optim/__init__.py index 1ac91a517..2b09f5258 100644 --- a/torchrec/optim/__init__.py +++ b/torchrec/optim/__init__.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + """Torchrec Optimizers Torchrec contains a special optimizer called KeyedOptimizer. KeyedOptimizer exposes the state_dict with meaningful keys- it enables loading both diff --git a/torchrec/optim/apply_optimizer_in_backward.py b/torchrec/optim/apply_optimizer_in_backward.py index 89140f83d..99ba72df5 100644 --- a/torchrec/optim/apply_optimizer_in_backward.py +++ b/torchrec/optim/apply_optimizer_in_backward.py @@ -5,7 +5,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Any, Dict, Iterable, Type +from warnings import warn import torch @@ -16,6 +19,8 @@ def apply_optimizer_in_backward( optimizer_kwargs: Dict[str, Any], ) -> None: """ + NOTE: This API is deprecated. Please use Pytorch Distributed's _apply_optimizer_in_backward instead. + Upon backwards(), parameters will fire the corresponding optimizer Each parameter will have the optimizer_class and optimizer_kwargs attached to _optimizer and _optimizer_kwargs. @@ -41,40 +46,14 @@ def apply_optimizer_in_backward( >> torch.optim.SGD, {"lr": .02} """ - def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None: - # acc_grad creates a new node in the auto_grad graph that comes after - # the parameter node. In this node's backwards we call the overlapped - # optimizer.step(). We cannot do the backwards hook on param because - # the gradients are not fully ready by then. - # for more details see https://github.com/pytorch/pytorch/issues/76464 - acc_grad = param.view_as(param).grad_fn.next_functions[0][0] - optimizer = optimizer_class([param], **optimizer_kwargs) - - if hasattr(param, "_acc_grad") and hasattr(param, "_overlapped_optimizer"): - raise ValueError( - # pyre-ignore - f"{param} already has {param._overlapped_optimizer} applied as an overlapped optimizer. Cannot apply again" - ) - - # The grad accumulator is a weak ref, so we need to keep it - # alive until the Tensor is alive. 
- # Store it on the module to avoid uncollectable ref-cycle - # pyre-ignore - param._acc_grad = acc_grad - param._overlapped_optimizer = optimizer - - # pyre-ignore - param._optimizer_class = optimizer_class - # pyre-ignore - param._optimizer_kwargs = optimizer_kwargs - - # pyre-ignore - def optimizer_hook(*_unused) -> None: - # pyre-ignore - param._overlapped_optimizer.step() - param.grad = None - - param._acc_grad.register_hook(optimizer_hook) - - for param in params: - _apply_optimizer_in_backward_to_param(param) + from torch.distributed.optim import _apply_optimizer_in_backward + + warn( + "This API is deprecated. Please use Pytorch Distributed's _apply_optimizer_in_backward API instead.", + DeprecationWarning, + ) + _apply_optimizer_in_backward( + optimizer_class=optimizer_class, + params=params, + optimizer_kwargs=optimizer_kwargs, + ) diff --git a/torchrec/optim/clipping.py b/torchrec/optim/clipping.py index 681ad5b4e..bd916c6f5 100644 --- a/torchrec/optim/clipping.py +++ b/torchrec/optim/clipping.py @@ -5,12 +5,23 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict +import logging +from collections import defaultdict from enum import Enum, unique -from typing import Any, List +from typing import Any, cast, Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist +from torch.distributed._tensor.api import DTensor + from torchrec.optim.keyed import KeyedOptimizer, OptimizerWrapper +logger: logging.Logger = logging.getLogger() + +log_grad_norm: bool = False +use_64bit_grad_norm: bool = False + @unique class GradientClipping(Enum): @@ -27,6 +38,11 @@ class GradientClippingOptimizer(OptimizerWrapper): optimizer (KeyedOptimizer): optimizer to wrap clipping (GradientClipping): how to clip gradients max_gradient (float): max value for clipping + norm_type (float or str): type of the used p-norm. Can be ``'inf'`` for infinity norm. + enable_global_grad_clip (bool): whether to enable global gradient clipping. + param_to_pgs (Dict[torch.nn.Parameter, List[dist.ProcessGroup]], optional): Mapping of parameters + to process groups. Used for global gradient clipping in n-D model parallelism case. + Defaults to None, local gradient clipping is used. """ def __init__( @@ -34,20 +50,233 @@ def __init__( optimizer: KeyedOptimizer, clipping: GradientClipping = GradientClipping.NONE, max_gradient: float = 0.1, + norm_type: Union[float, str] = 2.0, + enable_global_grad_clip: bool = False, + param_to_pgs: Optional[ + Dict[torch.nn.Parameter, List[dist.ProcessGroup]] + ] = None, ) -> None: super().__init__(optimizer) self._clipping = clipping self._max_gradient = max_gradient + self._norm_type = norm_type + self._check_meta: bool = True + self._enable_global_grad_clip = enable_global_grad_clip + self._step_num = 0 + + # Group parameters by model parallelism process group if global clipping is enabled. + # Otherwise, all parameters are treated as replicated and will be clipped locally. 
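+        # `_sharded_params` maps the tuple of process groups a parameter is sharded
+        # over to the parameters in that bucket, so clip_grad_norm_() can all-reduce
+        # each bucket's partial norm over exactly those ranks; everything else goes
+        # to `_replicate_params` and is clipped with a purely local norm.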
+ sharded_param_cnt = 0 + self._replicate_params: List[torch.Tensor] = [] + self._sharded_params: Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]] = ( + defaultdict(list) + ) - self._params: List[torch.Tensor] = [] for param_group in self.param_groups: - self._params += list(param_group["params"]) + for param in param_group["params"]: + if not self._enable_global_grad_clip: + self._replicate_params.append(param) + continue + if param_to_pgs is None or len(param_to_pgs) == 0: + self._replicate_params.append(param) + continue + + # Group parameters by model parallelism process group. + if param in param_to_pgs and len(param_to_pgs[param]) != 0: + self._sharded_params[tuple(param_to_pgs[param])].append(param) + sharded_param_cnt += 1 + else: + self._replicate_params.append(param) + logger.info( + f"Optimizer found {sharded_param_cnt} dist params and {len(self._replicate_params)} replicate params." + ) + + # Sanity check: this path is currently not used in any production. + if self._clipping == GradientClipping.VALUE: + if sharded_param_cnt > 0: + raise NotImplementedError( + "clip_grad_value_ for sharded parameters is not supported yet" + ) # pyre-ignore [2] def step(self, closure: Any = None) -> None: + if self._check_meta: + # skip gradient clipping and early return + if any(t.device.type == "meta" for t in self._replicate_params): + super().step(closure) + return + if any( + t.device.type == "meta" + for params in self._sharded_params.values() + for t in params + ): + super().step(closure) + return + self._check_meta = False + if self._clipping == GradientClipping.NORM: - torch.nn.utils.clip_grad_norm_(self._params, self._max_gradient) + # No sharded parameters, local gradient clipping == global gradient clipping + if len(self._sharded_params) == 0: + replicate_params = [ + p._local_tensor if isinstance(p, DTensor) else p + for p in self._replicate_params + ] + torch.nn.utils.clip_grad_norm_( + replicate_params, + self._max_gradient, + norm_type=self._norm_type, + ) + else: + self.clip_grad_norm_() + elif self._clipping == GradientClipping.VALUE: - torch.nn.utils.clip_grad_value_(self._params, self._max_gradient) + torch.nn.utils.clip_grad_value_(self._replicate_params, self._max_gradient) super().step(closure) + self._step_num += 1 + + @torch.no_grad() + def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]: + """Clip the gradient norm of all parameters.""" + max_norm = self._max_gradient + norm_type = float(self._norm_type) + all_grads = [] + total_grad_norm = None + + # Process distributed parameters and gradients + for pgs, dist_params in self._sharded_params.items(): + sharded_grads = [ + p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad + for p in dist_params + if p.grad is not None and p.grad.numel() > 0 + ] + if len(sharded_grads) == 0: + continue + all_grads.extend(sharded_grads) + + sharded_grad_norm = _batch_cal_norm( + sharded_grads, + max_norm, + norm_type, + pgs, + ) + total_grad_norm = ( + sharded_grad_norm + if total_grad_norm is None + else ( + torch.maximum(total_grad_norm, sharded_grad_norm) + if self._norm_type == torch.inf + else total_grad_norm + sharded_grad_norm + ) + ) + + square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0 + + # Process replicated parameters and gradients + if self._replicate_params: + replicated_grads = [ + p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad + for p in self._replicate_params + if p.grad is not None and p.grad.numel() > 0 + ] + all_grads.extend(replicated_grads) 
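+            # NOTE (editorial): for finite p-norms _batch_cal_norm returns the *sum*
+            # of |g|**p over these grads, not the norm itself; the partial sums from
+            # the sharded buckets above are added to it and the final 1/p root is
+            # taken once in the aggregation step below (inf-norm uses max instead).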
+ + replicated_grad_norm = _batch_cal_norm( + replicated_grads, + max_norm, + norm_type, + None, + ) + total_grad_norm = ( + replicated_grad_norm + if total_grad_norm is None + else ( + torch.maximum(total_grad_norm, replicated_grad_norm) + if self._norm_type == torch.inf + else total_grad_norm + replicated_grad_norm + ) + ) + square_replicated_grad_norm = replicated_grad_norm + else: + square_replicated_grad_norm = 0 + + global log_grad_norm + if log_grad_norm: + if total_grad_norm is not None and self._norm_type != torch.inf: + # pyre-ignore[58] + grad_norm = total_grad_norm ** (1.0 / norm_type) + else: + grad_norm = 0 + + rank = dist.get_rank() + logger.info( + f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}" + ) + + # Aggregation + if total_grad_norm is None: + return + + if self._norm_type != torch.inf: + # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float. + total_grad_norm = total_grad_norm ** (1.0 / norm_type) + # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor]. + clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6)) + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + torch._foreach_mul_(all_grads, clip_coef_clamped) + return total_grad_norm + + +def _batch_cal_norm( + grad_list: List[torch.Tensor], + max_norm: float, + norm_type: float = 2.0, + process_groups: Optional[Tuple[dist.ProcessGroup]] = None, +) -> torch.Tensor: + """Helper function that calculates the norm of a list of gradients in batches. If process_groups + are passed in, the norm will be aggregated across all ranks in the process group. + """ + + global use_64bit_grad_norm + if use_64bit_grad_norm: + grad_norms = torch.linalg.vector_norm( + torch.stack(torch._foreach_norm(grad_list, norm_type, dtype=torch.float64)), + norm_type, + ) + else: + grad_norms = torch.linalg.vector_norm( + torch.stack(torch._foreach_norm(grad_list, norm_type)), + norm_type, + ) + + if norm_type == torch.inf: + if process_groups is not None: + for pg in process_groups: + dist.all_reduce(grad_norms, op=dist.ReduceOp.MAX, group=pg) + else: + grad_norms = grad_norms**norm_type + if process_groups is not None: + for pg in process_groups: + dist.all_reduce(grad_norms, group=pg) + + if use_64bit_grad_norm: + grad_norms = grad_norms.to(torch.float32) + + return grad_norms + + +def _dedup_to_base_tensors(tensors: List[torch.Tensor]) -> List[torch.Tensor]: + """ + This is a performance optimization specific to FSDP2. Each gradient tensor + of the same FSDP module share the same base tensor, so for the total norm + computation, we can directly use the base tensor to reduce the number of + tensors to compute norm over. + """ + seen_base_tensors = set() + base_tensors = [] + for tensor in tensors: + base_tensor = tensor._base if tensor._base is not None else tensor + if base_tensor not in seen_base_tensors: + seen_base_tensors.add(base_tensor) + base_tensors.append(base_tensor) + return base_tensors diff --git a/torchrec/optim/fused.py b/torchrec/optim/fused.py index 45ea10f67..91607e30b 100644 --- a/torchrec/optim/fused.py +++ b/torchrec/optim/fused.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
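+# Editorial note (not part of the original change): the new EmptyFusedOptimizer
+# below is a KeyedOptimizer with no state and no-op step()/zero_grad(); it gives
+# FusedOptimizerModule implementations something valid to return from
+# `fused_optimizer` when there is nothing to update.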
+# pyre-strict + import abc from typing import Any @@ -20,17 +22,31 @@ class FusedOptimizer(KeyedOptimizer, abc.ABC): @abc.abstractmethod # pyre-ignore [2] - def step(self, closure: Any = None) -> None: - ... + def step(self, closure: Any = None) -> None: ... @abc.abstractmethod - def zero_grad(self, set_to_none: bool = False) -> None: - ... + def zero_grad(self, set_to_none: bool = False) -> None: ... def __repr__(self) -> str: return optim.Optimizer.__repr__(self) +class EmptyFusedOptimizer(FusedOptimizer): + """ + Fused Optimizer class with no-op step and no parameters to optimize over + """ + + def __init__(self) -> None: + super().__init__({}, {}, {}) + + # pyre-ignore + def step(self, closure: Any = None) -> None: + pass + + def zero_grad(self, set_to_none: bool = False) -> None: + pass + + class FusedOptimizerModule(abc.ABC): """ Module, which does weight update during backward pass. @@ -38,5 +54,4 @@ class FusedOptimizerModule(abc.ABC): @property @abc.abstractmethod - def fused_optimizer(self) -> KeyedOptimizer: - ... + def fused_optimizer(self) -> KeyedOptimizer: ... diff --git a/torchrec/optim/keyed.py b/torchrec/optim/keyed.py index 2fcdb31f9..2f6b75d5f 100644 --- a/torchrec/optim/keyed.py +++ b/torchrec/optim/keyed.py @@ -5,6 +5,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + +import json from copy import deepcopy from typing import ( Any, @@ -14,15 +17,17 @@ List, Mapping, Optional, + OrderedDict, Set, Tuple, Union, ) import torch -from torch import optim -from torchrec.distributed.types import ShardedTensor +from torch import optim +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed.tensor import DTensor OptimizerFactory = Callable[[List[Union[torch.Tensor, ShardedTensor]]], optim.Optimizer] @@ -51,7 +56,16 @@ def __init__( param_groups: Collection[Mapping[str, Any]], ) -> None: torch._C._log_api_usage_once(f"torchrec.optim.{self.__class__.__name__}") - # pyre-ignore [4] + + # TODO: remove these and call super().__init__() + # super().__init__ calls add_param_group, which we've explicitly marked as not implemented. + # However, we need to ensure that all Optimizer member variables are created. + # pyre-ignore + self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict() + # pyre-ignore + self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict() + + # pyre-ignore self.state: Mapping[Any, Any] = state self.param_groups: Collection[Mapping[str, Any]] = param_groups self.params = params @@ -66,20 +80,101 @@ def __init__( ) ) + @staticmethod + def _extract_state_dict_content( + input_dict: Dict[str, Any], + ) -> Dict[str, Any]: + """Converts nested dictionary with objects with state dict functionality. + + Args: + input_dict (Dict[str, Any]): Nested dictionary containing objects with + state dict functionality. + + Output: + output_dict (Dict[str, Any]): Nested dictionary where the terminal values + cannot have state dict functionality. 
+ + """ + result = {} + for k, v in input_dict.items(): + if isinstance(v, dict): + result[k] = KeyedOptimizer._extract_state_dict_content(v) + elif hasattr(v, "state_dict") and callable(v.state_dict): + result[k] = v.state_dict() + else: + result[k] = v + return result + + @staticmethod + def _update_param_state_dict_object( + current_param_state_dict: Dict[str, Any], + param_state_dict_to_load: Dict[str, Any], + parent_keys: List[Union[str, int, float, bool, None]], + ) -> None: + # Import at function level to avoid circular dependency. + from torchrec.distributed.shards_wrapper import LocalShardsWrapper + + for k, v in current_param_state_dict.items(): + new_v = param_state_dict_to_load[k] + parent_keys.append(k) + + if isinstance(v, dict): + KeyedOptimizer._update_param_state_dict_object( + v, + new_v, + parent_keys, + ) + elif hasattr(v, "load_state_dict") and callable(v.load_state_dict): + v.load_state_dict(new_v) + elif isinstance(v, ShardedTensor): + assert isinstance(new_v, ShardedTensor) + num_shards = len(v.local_shards()) + num_new_shards = len(new_v.local_shards()) + if num_shards != num_new_shards: + raise ValueError( + f"Different number of shards {num_shards} vs {num_new_shards} for the path of {json.dumps(parent_keys)}" + ) + for shard, new_shard in zip(v.local_shards(), new_v.local_shards()): + shard.tensor.detach().copy_(new_shard.tensor) + elif isinstance(v, DTensor): + assert isinstance(new_v, DTensor) + if isinstance(v.to_local(), LocalShardsWrapper): + assert isinstance(new_v.to_local(), LocalShardsWrapper) + num_shards = len(v.to_local().local_shards()) # pyre-ignore[16] + num_new_shards = len(new_v.to_local().local_shards()) + if num_shards != num_new_shards: + raise ValueError( + f"Different number of shards {num_shards} vs {num_new_shards} for the path of {json.dumps(parent_keys)}" + ) + for shard, new_shard in zip( + v.to_local().local_shards(), new_v.to_local().local_shards() + ): + shard.detach().copy_(new_shard) + else: + assert isinstance(new_v.to_local(), torch.Tensor) + v.detach().copy_(new_v) + elif isinstance(v, torch.Tensor): + v.detach().copy_(new_v) + else: + current_param_state_dict[k] = deepcopy(new_v) + def state_dict(self) -> Dict[str, Any]: """ Returned state and param_groups will contain parameter keys instead of parameter indices in torch.Optimizer. This allows for advanced functionality like optimizer re-sharding to be implemented. + + Can also handle classes and supported data structures that follow the PyTorch stateful + protocol. 
""" - state = self.state param_groups = self.param_groups params = self.params param_to_key = {param: key for key, param in params.items()} ret_state = { - param_to_key[param]: state_val for param, state_val in state.items() + param_to_key[param]: self._extract_state_dict_content(param_state) + for param, param_state in self.state.items() } ret_groups = [] @@ -134,30 +229,12 @@ def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: raise ValueError( f"Different state size: {len(state[param])} vs {len(new_state[param_key])}" ) - for state_key, state_val in state[param].items(): - if state_key not in new_state[param_key]: - raise ValueError( - f"State key {state_key} not found for param {param_key}" - ) - new_state_val = new_state[param_key][state_key] - if isinstance(state_val, ShardedTensor): - assert isinstance(new_state_val, ShardedTensor) - num_shards = len(state_val.local_shards()) - num_new_shards = len(new_state_val.local_shards()) - if num_shards != num_new_shards: - raise ValueError( - f"Different number of shards {num_shards} vs {num_new_shards} for {param_key}/{state_key}" - ) - for shard, new_shard in zip( - state_val.local_shards(), new_state_val.local_shards() - ): - shard.tensor.detach().copy_(new_shard.tensor) - elif isinstance(state_val, torch.Tensor): - assert isinstance(new_state_val, torch.Tensor) - state_val.detach().copy_(new_state_val) - else: - state[param][state_key] = deepcopy(new_state_val) + KeyedOptimizer._update_param_state_dict_object( + current_param_state_dict=state[param], + param_state_dict_to_load=new_state[param_key], + parent_keys=[param_key], + ) # Load param_groups. if self.defaults["_save_param_groups"]: @@ -220,7 +297,6 @@ def init_state( and key in sparse_grad_parameter_names ): t = t.to_sparse() - # pyre-fixme[8, 9, 19] param.grad = torch.autograd.Variable(t) self.step(closure=None) @@ -267,10 +343,17 @@ def __init__( raise ValueError(f"Duplicate param key {new_param}") all_keys.add(new_param) + # pyre-ignore + self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict() + # pyre-ignore + self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict() + + self._patch_step_function() + def __repr__(self) -> str: ret = [] - for _, opt in self._optims: - ret.append(opt.__repr__()) + for key, opt in self._optims: + ret.append(f"{key}: {opt.__repr__()}") return ",".join(ret) def zero_grad(self, set_to_none: bool = False) -> None: @@ -324,6 +407,18 @@ def save_param_groups(self, save: bool) -> None: for _, opt in self._optims: opt.save_param_groups(save) + def set_optimizer_step(self, step: int) -> None: + for _, opt in self._optims: + if hasattr(opt, "set_optimizer_step"): + # pyre-ignore [16]: Undefined attribute [16]: `KeyedOptimizer` has no attribute `set_optimizer_step`. + opt.set_optimizer_step(step) + + def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None: + for _, opt in self._optims: + if hasattr(opt, "update_hyper_parameters"): + # pyre-ignore [16]. + opt.update_hyper_parameters(params_dict) + class KeyedOptimizerWrapper(KeyedOptimizer): """ @@ -347,6 +442,16 @@ def zero_grad(self, set_to_none: bool = False) -> None: def step(self, closure: Any = None) -> None: self._optimizer.step(closure=closure) + def set_optimizer_step(self, step: int) -> None: + if hasattr(self._optimizer, "set_optimizer_step"): + # pyre-ignore [16]. 
+ self._optimizer.set_optimizer_step(step) + + def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None: + if hasattr(self._optimizer, "update_hyper_parameters"): + # pyre-ignore [16]. + self._optimizer.update_hyper_parameters(params_dict) + class OptimizerWrapper(KeyedOptimizer): """ @@ -394,3 +499,13 @@ def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: def save_param_groups(self, save: bool) -> None: self._optimizer.save_param_groups(save) + + def set_optimizer_step(self, step: int) -> None: + if hasattr(self._optimizer, "set_optimizer_step"): + # pyre-ignore [16]. + self._optimizer.set_optimizer_step(step) + + def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None: + if hasattr(self._optimizer, "update_hyper_parameters"): + # pyre-ignore [16]. + self._optimizer.update_hyper_parameters(params_dict) diff --git a/torchrec/optim/optimizers.py b/torchrec/optim/optimizers.py index f5c9aa152..10ffcdb2b 100644 --- a/torchrec/optim/optimizers.py +++ b/torchrec/optim/optimizers.py @@ -5,15 +5,36 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + #!/usr/bin/env python3 -from typing import Iterable +from typing import Iterable, Iterator, Tuple import torch +from torch import nn from torch.optim.optimizer import Optimizer +def in_backward_optimizer_filter( + named_parameters: Iterator[Tuple[str, nn.Parameter]], include: bool = False +) -> Iterator[Tuple[str, nn.Parameter]]: + """ + Filters named_parameters for whether they are or or not params that use + the in_backward_optimizer. + Note: This only supports the in_backward_optimizer from PT-D's API. + The torchrec's equivalent API is deprecated and is not supported. + Args: + named_parameters(Iterator[Tuple[str, nn.Parameter]]): named_parameters + include(bool): If true, only yields params with in_backward_optimizer. If false, returns the outside set + Defaults to include params that are not in_backward (False) + """ + for fqn, param in named_parameters: + if hasattr(param, "_in_backward_optimizers") == include: + yield fqn, param + + class SGD(Optimizer): r""" Placeholder for SGD. This optimizer will not functionally run. diff --git a/torchrec/optim/rowwise_adagrad.py b/torchrec/optim/rowwise_adagrad.py index ee188af72..6712b115e 100644 --- a/torchrec/optim/rowwise_adagrad.py +++ b/torchrec/optim/rowwise_adagrad.py @@ -5,8 +5,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + #!/usr/bin/env python3 +import logging from typing import Any, Dict, Iterable, List import torch @@ -14,6 +17,8 @@ from torch.optim.optimizer import Optimizer +logger: logging.Logger = logging.getLogger(__name__) + class RowWiseAdagrad(Optimizer): r"""Implements Row wise Adagrad algorithm. This is an extension of the Adagrad algorithm @@ -85,8 +90,6 @@ def __init__( ) state["sum"] = ( # pyre-fixme[28]: Unexpected keyword argument `axis`. - # pyre-fixme[6]: For 2nd param expected `Union[bool, float, - # int]` but got `complex`. torch.full_like(p, init_value, memory_format=torch.preserve_format) .mean(axis=1) .view(-1, 1) @@ -172,7 +175,7 @@ def adagrad( See :class:`~torch.optim.Adagrad` for details. 
""" - if not all([isinstance(t, torch.Tensor) for t in state_steps]): + if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( "API has changed, `state_steps` argument must contain a list of singleton tensors" ) @@ -202,8 +205,15 @@ def _single_tensor_adagrad( eps: float, maximize: bool, ) -> None: + if weight_decay != 0 and len(state_steps) > 0 and state_steps[0].item() < 1.0: + logger.warning( + "Note that the weight decay mode of this optimizer may produce " + "different results compared to the one by FBGEMM TBE. This is " + "due to FBGEMM TBE rowwise adagrad is sparse, and will only " + "update the optimizer states if that row has nonzero gradients." + ) - for (param, grad, state_sum, step_t) in zip(params, grads, state_sums, state_steps): + for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps): if grad.is_sparse: raise RuntimeError("RowWise adagrad cannot be used with sparse gradients") # update step @@ -211,14 +221,12 @@ def _single_tensor_adagrad( step = step_t.item() grad = grad if not maximize else -grad - row_wise_grad = grad.mean(axis=1).view(-1, 1) if weight_decay != 0: - grad = grad.add(param, alpha=weight_decay) - row_wise_grad = grad.add(param, alpha=weight_decay) + + state_sum += grad.pow(2).mean(axis=1).view(-1, 1) + std = state_sum.sqrt().add_(eps) clr = lr / (1 + (step - 1) * lr_decay) - state_sum.addcmul_(row_wise_grad, row_wise_grad, value=1) - std = state_sum.sqrt().add_(eps) - param.addcdiv_(row_wise_grad, std, value=-clr) + param.addcdiv_(grad, std, value=-clr) diff --git a/torchrec/optim/test_utils/__init__.py b/torchrec/optim/test_utils/__init__.py index 87ab1a2cc..e9c2599d3 100644 --- a/torchrec/optim/test_utils/__init__.py +++ b/torchrec/optim/test_utils/__init__.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Any from torchrec.optim.keyed import KeyedOptimizer diff --git a/torchrec/optim/tests/test_apply_optimizer_in_backward.py b/torchrec/optim/tests/test_apply_optimizer_in_backward.py deleted file mode 100644 index cee1c2bfb..000000000 --- a/torchrec/optim/tests/test_apply_optimizer_in_backward.py +++ /dev/null @@ -1,96 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- - -import unittest - -import torch -from torchrec import EmbeddingBagCollection, EmbeddingBagConfig, KeyedJaggedTensor -from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward - - -class ApplyOverlappedOptimizerTest(unittest.TestCase): - def test_apply_optimizer_in_backward(self) -> None: - ebc = EmbeddingBagCollection( - tables=[ - EmbeddingBagConfig( - name="t1", embedding_dim=4, num_embeddings=2, feature_names=["f1"] - ), - EmbeddingBagConfig( - name="t2", embedding_dim=4, num_embeddings=2, feature_names=["f2"] - ), - ] - ) - - apply_optimizer_in_backward( - torch.optim.SGD, - ebc.embedding_bags["t1"].parameters(), - optimizer_kwargs={"lr": 1.0}, - ) - - apply_optimizer_in_backward( - torch.optim.SGD, - ebc.embedding_bags["t2"].parameters(), - optimizer_kwargs={"lr": 2.0}, - ) - - ebc.load_state_dict( - { - "embedding_bags.t1.weight": torch.FloatTensor( - [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]] - ), - "embedding_bags.t2.weight": torch.FloatTensor( - [[10.0, 10.0, 10.0, 10.0], [12.0, 12.0, 12.0, 12.0]] - ), - } - ) - - # 0 1 2 <-- batch - # f1 [0,1] None [0] - # f2 [0,1] [1] [0] - # ^ - # feature - kjt = KeyedJaggedTensor.from_lengths_sync( - keys=["f1", "f2"], - values=torch.tensor([0, 1, 0, 0, 1, 1, 0]), - lengths=torch.tensor([2, 0, 1, 2, 1, 1]), - ) - - kt_out = ebc(kjt).to_dict() - stack = [] - for _key, val in kt_out.items(): - stack.append(val) - torch.stack(stack).sum().backward() - - t1_weight = next(ebc.embedding_bags["t1"].parameters()) - t2_weight = next(ebc.embedding_bags["t2"].parameters()) - - self.assertIsNone(t1_weight.grad) - self.assertIsNone(t2_weight.grad) - - self.assertTrue(hasattr(t1_weight, "_optimizer_class")) - self.assertEqual(t1_weight._optimizer_class, torch.optim.SGD) - self.assertTrue(hasattr(t1_weight, "_optimizer_kwargs")) - self.assertEqual(t1_weight._optimizer_kwargs, {"lr": 1.0}) - - self.assertTrue(hasattr(t2_weight, "_optimizer_class")) - self.assertEqual(t2_weight._optimizer_class, torch.optim.SGD) - self.assertTrue(hasattr(t2_weight, "_optimizer_kwargs")) - self.assertEqual(t2_weight._optimizer_kwargs, {"lr": 2.0}) - - expected_state_dict = { - "embedding_bags.t1.weight": torch.FloatTensor( - [[-1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0]] - ), - "embedding_bags.t2.weight": torch.FloatTensor( - [[6.0, 6.0, 6.0, 6.0], [8.0, 8.0, 8.0, 8.0]] - ), - } - - for key, state in ebc.state_dict().items(): - self.assertIn(key, expected_state_dict) - torch.testing.assert_close(state, expected_state_dict[key]) diff --git a/torchrec/optim/tests/test_clipping.py b/torchrec/optim/tests/test_clipping.py index f41c19606..f26fd8884 100644 --- a/torchrec/optim/tests/test_clipping.py +++ b/torchrec/optim/tests/test_clipping.py @@ -5,7 +5,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest +from unittest.mock import MagicMock, patch import torch from torch.autograd import Variable @@ -16,7 +19,6 @@ class TestGradientClippingOptimizer(unittest.TestCase): def test_clip_all_gradients_norm(self) -> None: # Clip all gradients to zero - # pyre-fixme[28]: Unexpected keyword argument `requires_grad`. param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) keyed_optimizer = DummyKeyedOptimizer( @@ -28,15 +30,14 @@ def test_clip_all_gradients_norm(self) -> None: ) gradient_clipping_optimizer.zero_grad() - # pyre-fixme[16]: `Variable` has no attribute `grad`. 
param_1.grad = torch.tensor([1.0, 2.0]) gradient_clipping_optimizer.step() + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Optional[Tensor]`. self.assertTrue(torch.equal(param_1.grad, torch.tensor([0.0, 0.0]))) def test_clip_no_gradients_norm(self) -> None: # gradients are too small to be clipped - # pyre-fixme[28]: Unexpected keyword argument `requires_grad`. param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) keyed_optimizer = DummyKeyedOptimizer( @@ -48,15 +49,14 @@ def test_clip_no_gradients_norm(self) -> None: ) gradient_clipping_optimizer.zero_grad() - # pyre-fixme[16]: `Variable` has no attribute `grad`. param_1.grad = torch.tensor([0.5, 0.5]) gradient_clipping_optimizer.step() + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Optional[Tensor]`. self.assertTrue(torch.equal(param_1.grad, torch.tensor([0.5, 0.5]))) def test_clip_partial_gradients_norm(self) -> None: # test partial clipping - # pyre-fixme[28]: Unexpected keyword argument `requires_grad`. param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) keyed_optimizer = DummyKeyedOptimizer( @@ -69,20 +69,18 @@ def test_clip_partial_gradients_norm(self) -> None: gradient_clipping_optimizer.zero_grad() - # pyre-fixme[16]: `Variable` has no attribute `grad`. param_1.grad = torch.tensor([2.0, 4.0]) gradient_clipping_optimizer.step() norm = 2.0**2 + 4.0**2 expected_grad = torch.tensor([2.0, 4.0]) * norm ** (-0.5) + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Optional[Tensor]`. self.assertTrue(torch.allclose(param_1.grad, expected_grad)) def test_clip_partial_gradients_norm_multi_params(self) -> None: # test partial clipping max_gradient = 2.0 - # pyre-fixme[28]: Unexpected keyword argument `requires_grad`. param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) - # pyre-fixme[28]: Unexpected keyword argument `requires_grad`. param_2 = Variable(torch.tensor([2.0, 4.0]), requires_grad=True) keyed_optimizer = DummyKeyedOptimizer( @@ -99,7 +97,6 @@ def test_clip_partial_gradients_norm_multi_params(self) -> None: gradient_clipping_optimizer.zero_grad() - # pyre-fixme[16]: `Variable` has no attribute `grad`. param_1.grad = torch.tensor([2.0, 4.0]) param_2.grad = torch.tensor([4.0, 8.0]) @@ -113,12 +110,13 @@ def test_clip_partial_gradients_norm_multi_params(self) -> None: print(param_1.grad, param_2.grad, expected_grad_1, expected_grad_2) + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Optional[Tensor]`. self.assertTrue(torch.allclose(param_1.grad, expected_grad_1)) + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Optional[Tensor]`. self.assertTrue(torch.allclose(param_2.grad, expected_grad_2)) def test_clip_all_gradients_value(self) -> None: # Clip all gradients to zero - # pyre-fixme[28]: Unexpected keyword argument `requires_grad`. param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) keyed_optimizer = DummyKeyedOptimizer( @@ -130,15 +128,14 @@ def test_clip_all_gradients_value(self) -> None: ) gradient_clipping_optimizer.zero_grad() - # pyre-fixme[16]: `Variable` has no attribute `grad`. param_1.grad = torch.tensor([1.0, 2.0]) gradient_clipping_optimizer.step() + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Optional[Tensor]`. self.assertTrue(torch.equal(param_1.grad, torch.tensor([0.0, 0.0]))) def test_clip_no_gradients_value(self) -> None: # gradients are too small to be clipped - # pyre-fixme[28]: Unexpected keyword argument `requires_grad`. 
param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) keyed_optimizer = DummyKeyedOptimizer( @@ -150,15 +147,14 @@ def test_clip_no_gradients_value(self) -> None: ) gradient_clipping_optimizer.zero_grad() - # pyre-fixme[16]: `Variable` has no attribute `grad`. param_1.grad = torch.tensor([0.5, 0.5]) gradient_clipping_optimizer.step() + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Optional[Tensor]`. self.assertTrue(torch.equal(param_1.grad, torch.tensor([0.5, 0.5]))) def test_clip_gradients_value(self) -> None: # test partial clipping - # pyre-fixme[28]: Unexpected keyword argument `requires_grad`. param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) keyed_optimizer = DummyKeyedOptimizer( @@ -171,20 +167,18 @@ def test_clip_gradients_value(self) -> None: gradient_clipping_optimizer.zero_grad() - # pyre-fixme[16]: `Variable` has no attribute `grad`. param_1.grad = torch.tensor([2.0, 4.0]) gradient_clipping_optimizer.step() expected_grad = torch.tensor([1.0, 1.0]) + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Optional[Tensor]`. self.assertTrue(torch.allclose(param_1.grad, expected_grad)) def test_clip_partial_gradients_value_multi_params(self) -> None: # test partial clipping max_gradient = 2.0 - # pyre-fixme[28]: Unexpected keyword argument `requires_grad`. param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) - # pyre-fixme[28]: Unexpected keyword argument `requires_grad`. param_2 = Variable(torch.tensor([2.0, 4.0]), requires_grad=True) keyed_optimizer = DummyKeyedOptimizer( @@ -201,7 +195,6 @@ def test_clip_partial_gradients_value_multi_params(self) -> None: gradient_clipping_optimizer.zero_grad() - # pyre-fixme[16]: `Variable` has no attribute `grad`. param_1.grad = torch.tensor([2.0, 4.0]) param_2.grad = torch.tensor([4.0, 8.0]) @@ -210,5 +203,29 @@ def test_clip_partial_gradients_value_multi_params(self) -> None: expected_grad_1 = torch.tensor([2.0, 2.0]) expected_grad_2 = torch.tensor([2.0, 2.0]) + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Optional[Tensor]`. self.assertTrue(torch.allclose(param_1.grad, expected_grad_1)) + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Optional[Tensor]`. self.assertTrue(torch.allclose(param_2.grad, expected_grad_2)) + + @patch("torch.nn.utils.clip_grad_norm_") + def test_clip_no_gradients_norm_meta_device( + self, mock_clip_grad_norm: MagicMock + ) -> None: + # Clip all gradients to zero + param_1 = Variable( + torch.tensor([1.0, 2.0], device=torch.device("meta")), requires_grad=True + ) + + keyed_optimizer = DummyKeyedOptimizer( + {"param_1": param_1}, {}, [{"params": [param_1]}] + ) + + gradient_clipping_optimizer = GradientClippingOptimizer( + optimizer=keyed_optimizer, max_gradient=0.0, clipping=GradientClipping.NORM + ) + + gradient_clipping_optimizer.zero_grad() + gradient_clipping_optimizer.step() + + mock_clip_grad_norm.assert_not_called() diff --git a/torchrec/optim/tests/test_keyed.py b/torchrec/optim/tests/test_keyed.py index ff0efd614..4be370e6a 100644 --- a/torchrec/optim/tests/test_keyed.py +++ b/torchrec/optim/tests/test_keyed.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
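
As a companion to the norm-clipping expectations asserted above, the same numbers can be reproduced with the underlying torch utility; this sketch assumes max_norm=1.0, matching what the tests exercise:

import torch

p = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
p.grad = torch.tensor([2.0, 4.0])

torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)

# The gradient is rescaled to unit L2 norm: g * (2^2 + 4^2) ** -0.5
expected = torch.tensor([2.0, 4.0]) * (2.0**2 + 4.0**2) ** (-0.5)
assert torch.allclose(p.grad, expected, atol=1e-5)
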
+# pyre-strict + import io import os import unittest @@ -23,6 +25,20 @@ from torchrec.test_utils import get_free_port +class DummyOptimizerModule: + def __init__( + self, + tensor: torch.Tensor, + ) -> None: + self.tensor = tensor + + def state_dict(self) -> Dict[str, Any]: + return {"tensor": self.tensor} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.tensor.detach().copy_(state_dict["tensor"]) + + class TestKeyedOptimizer(unittest.TestCase): def _assert_state_dict_equals( self, dict1: Dict[str, Any], dict2: Dict[str, Any] @@ -36,6 +52,14 @@ def _assert_state_dict_equals( dict1["state"]["param_1"]["tensor"], dict2["state"]["param_1"]["tensor"], ) + torch.testing.assert_close( + dict1["state"]["param_1"]["nested_dictionary"]["tensor"], + dict2["state"]["param_1"]["nested_dictionary"]["tensor"], + ) + torch.testing.assert_close( + dict1["state"]["param_1"]["optimizer_module"]["tensor"], + dict2["state"]["param_1"]["optimizer_module"]["tensor"], + ) torch.testing.assert_close( dict1["state"]["param_1"]["sharded_tensor"].local_shards()[0].tensor, @@ -49,11 +73,8 @@ def test_load_state_dict(self) -> None: # Set up example KeyedOptimizer. param_1_t, param_2_t = torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]) - # pyre-fixme[19]: Expected 0 positional arguments. param_1, param_2 = Variable(param_1_t), Variable(param_2_t) keyed_optimizer = KeyedOptimizer( - # pyre-fixme[6]: For 1st param expected `Mapping[str, - # Union[ShardedTensor, Tensor]]` but got `Mapping[str, Variable]`. {"param_1": param_1, "param_2": param_2}, { param_1: { @@ -67,6 +88,10 @@ def test_load_state_dict(self) -> None: (4,), fill_value=1.0, ), + "nested_dictionary": { + "tensor": torch.tensor([7.0, 8.0]), + }, + "optimizer_module": DummyOptimizerModule(torch.tensor([9.0, 10.0])), }, param_2: {"two": 2.0}, }, @@ -96,6 +121,12 @@ def test_load_state_dict(self) -> None: (4,), fill_value=1.0, ), + "nested_dictionary": { + "tensor": torch.tensor([7.0, 8.0]), + }, + "optimizer_module": { + "tensor": torch.tensor([9.0, 10.0]), + }, }, "param_2": {"two": 2.0}, } @@ -132,6 +163,14 @@ def test_load_state_dict(self) -> None: fill_value=10.0, ) # pyre-ignore [6] + expected_state_dict["state"]["param_1"]["nested_dictionary"]["tensor"] = ( + torch.tensor([70.0, 80.0]) + ) + # pyre-ignore [6] + expected_state_dict["state"]["param_1"]["optimizer_module"]["tensor"] = ( + torch.tensor([90.0, 100.0]) + ) + # pyre-ignore [6] expected_state_dict["param_groups"][0]["param_group_val_0"] = 8.0 # pyre-ignore [6] expected_state_dict["param_groups"][1]["param_group_val_1"] = 9.0 @@ -145,22 +184,19 @@ def test_load_state_dict(self) -> None: def test_non_param_state_key(self) -> None: with self.assertRaisesRegex(ValueError, "All state keys must be params."): param_1_t = torch.tensor([1.0, 2.0]) - # pyre-fixme[19]: Expected 0 positional arguments. param_1 = Variable(param_1_t) KeyedOptimizer( - # pyre-fixme[6]: For 1st param expected `Mapping[str, - # Union[ShardedTensor, Tensor]]` but got `Mapping[str, Variable]`. 
{"param_1": param_1}, {param_1: 1.0, "non_param_state_key": 2.0}, [{"params": [param_1], "param_group_val_0": 3.0}], ) - def test_init_state(self) -> None: + def test_init_state_with_momentum(self) -> None: dense = torch.nn.Parameter(torch.ones((2, 3), dtype=torch.float)) sparse = torch.nn.Parameter(torch.ones((1, 4), dtype=torch.float)) opt = KeyedOptimizerWrapper( {"dense": dense, "sparse": sparse}, - lambda params: torch.optim.SGD(params, lr=0.1), + lambda params: torch.optim.SGD(params, lr=0.1, momentum=0.1), ) opt.init_state({"sparse"}) @@ -172,6 +208,24 @@ def test_init_state(self) -> None: self.assertTrue(sparse.grad.is_sparse) self.assertTrue("momentum_buffer" in opt.state_dict()["state"]["sparse"]) + def test_init_state_no_momentum(self) -> None: + dense = torch.nn.Parameter(torch.ones((2, 3), dtype=torch.float)) + sparse = torch.nn.Parameter(torch.ones((1, 4), dtype=torch.float)) + opt = KeyedOptimizerWrapper( + {"dense": dense, "sparse": sparse}, + lambda params: torch.optim.SGD(params, lr=0.1), + ) + opt.init_state({"sparse"}) + + self.assertTrue(dense.grad is not None) + self.assertFalse(dense.grad.is_sparse) + + self.assertTrue(sparse.grad is not None) + self.assertTrue(sparse.grad.is_sparse) + + self.assertTrue("state" in opt.state_dict()) + self.assertFalse(opt.state_dict()["state"]) + def test_pickle(self) -> None: dense = torch.nn.Parameter(torch.ones((2, 3), dtype=torch.float)) sparse = torch.nn.Parameter(torch.ones((1, 4), dtype=torch.float)) @@ -184,7 +238,7 @@ def test_pickle(self) -> None: bytesIO = io.BytesIO() torch.save(opt, bytesIO) bytesIO.seek(0) - reload_opt = torch.load(bytesIO) + reload_opt = torch.load(bytesIO, weights_only=False) for k in reload_opt.state_dict(): self.assertEqual( @@ -197,11 +251,8 @@ class TestCombinedOptimizer(unittest.TestCase): def test_pickle(self) -> None: # Set up example KeyedOptimizer 1. param_1_t = torch.tensor([1.0, 2.0]) - # pyre-fixme[19]: Expected 0 positional arguments. param_1 = Variable(param_1_t) keyed_optimizer_1 = KeyedOptimizer( - # pyre-fixme[6]: For 1st param expected `Mapping[str, - # Union[ShardedTensor, Tensor]]` but got `Mapping[str, Variable]`. {"param_1": param_1}, {param_1: {"one": 1.0}}, [{"params": [param_1], "param_group_val_0": 2.0}], @@ -209,11 +260,8 @@ def test_pickle(self) -> None: # Set up example KeyedOptimizer 2. param_2_t = torch.tensor([-1.0, -2.0]) - # pyre-fixme[19]: Expected 0 positional arguments. param_2 = Variable(param_2_t) keyed_optimizer_2 = KeyedOptimizer( - # pyre-fixme[6]: For 1st param expected `Mapping[str, - # Union[ShardedTensor, Tensor]]` but got `Mapping[str, Variable]`. {"param_2": param_2}, {param_2: {"two": -1.0}}, [{"params": [param_2], "param_group_val_0": -2.0}], @@ -226,7 +274,7 @@ def test_pickle(self) -> None: bytesIO = io.BytesIO() torch.save(combined_optimizer, bytesIO) bytesIO.seek(0) - reload_combined_optimizer = torch.load(bytesIO) + reload_combined_optimizer = torch.load(bytesIO, weights_only=False) for k in reload_combined_optimizer.state_dict(): self.assertEqual( @@ -237,11 +285,8 @@ def test_pickle(self) -> None: def test_load_state_dict(self) -> None: # Set up example KeyedOptimizer 1. param_1_t = torch.tensor([1.0, 2.0]) - # pyre-fixme[19]: Expected 0 positional arguments. param_1 = Variable(param_1_t) keyed_optimizer_1 = KeyedOptimizer( - # pyre-fixme[6]: For 1st param expected `Mapping[str, - # Union[ShardedTensor, Tensor]]` but got `Mapping[str, Variable]`. 
{"param_1": param_1}, {param_1: {"one": 1.0}}, [{"params": [param_1], "param_group_val_0": 2.0}], @@ -249,11 +294,8 @@ def test_load_state_dict(self) -> None: # Set up example KeyedOptimizer 2. param_2_t = torch.tensor([-1.0, -2.0]) - # pyre-fixme[19]: Expected 0 positional arguments. param_2 = Variable(param_2_t) keyed_optimizer_2 = KeyedOptimizer( - # pyre-fixme[6]: For 1st param expected `Mapping[str, - # Union[ShardedTensor, Tensor]]` but got `Mapping[str, Variable]`. {"param_2": param_2}, {param_2: {"two": -1.0}}, [{"params": [param_2], "param_group_val_0": -2.0}], @@ -284,11 +326,8 @@ def test_load_state_dict(self) -> None: class TestOptimizerWrapper(unittest.TestCase): def test_load_state_dict(self) -> None: param_1_t = torch.tensor([1.0, 2.0]) - # pyre-fixme[19]: Expected 0 positional arguments. param_1 = Variable(param_1_t) keyed_optimizer = KeyedOptimizer( - # pyre-fixme[6]: For 1st param expected `Mapping[str, - # Union[ShardedTensor, Tensor]]` but got `Mapping[str, Variable]`. {"param_1": param_1}, {param_1: {"one": 1.0}}, [{"params": [param_1], "param_group_val_0": 2.0}], diff --git a/torchrec/optim/tests/test_optim.py b/torchrec/optim/tests/test_optim.py new file mode 100644 index 000000000..3c3bb6487 --- /dev/null +++ b/torchrec/optim/tests/test_optim.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +import torch +from torch.distributed.optim import ( + _apply_optimizer_in_backward as apply_optimizer_in_backward, +) +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.optim.optimizers import in_backward_optimizer_filter + + +class TestInBackwardOptimizerFilter(unittest.TestCase): + def test_in_backward_optimizer_filter(self) -> None: + ebc = EmbeddingBagCollection( + tables=[ + EmbeddingBagConfig( + name="t1", embedding_dim=4, num_embeddings=2, feature_names=["f1"] + ), + EmbeddingBagConfig( + name="t2", embedding_dim=4, num_embeddings=2, feature_names=["f2"] + ), + ] + ) + apply_optimizer_in_backward( + torch.optim.SGD, + ebc.embedding_bags["t1"].parameters(), + optimizer_kwargs={"lr": 1.0}, + ) + in_backward_params = dict( + in_backward_optimizer_filter(ebc.named_parameters(), include=True) + ) + non_in_backward_params = dict( + in_backward_optimizer_filter(ebc.named_parameters(), include=False) + ) + assert set(in_backward_params.keys()) == {"embedding_bags.t1.weight"} + assert set(non_in_backward_params.keys()) == {"embedding_bags.t2.weight"} diff --git a/torchrec/optim/tests/test_rowwise_adagrad.py b/torchrec/optim/tests/test_rowwise_adagrad.py index 14a7f6582..3a221b412 100644 --- a/torchrec/optim/tests/test_rowwise_adagrad.py +++ b/torchrec/optim/tests/test_rowwise_adagrad.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import unittest @@ -14,17 +16,24 @@ class RowWiseAdagradTest(unittest.TestCase): def test_optim(self) -> None: - embedding_bag = torch.nn.EmbeddingBag(num_embeddings=4, embedding_dim=4) + embedding_bag = torch.nn.EmbeddingBag( + num_embeddings=4, embedding_dim=4, mode="sum" + ) opt = torchrec.optim.RowWiseAdagrad(embedding_bag.parameters()) index, offsets = torch.tensor([0, 3]), torch.tensor([0, 1]) embedding_bag_out = embedding_bag(index, offsets) opt.zero_grad() embedding_bag_out.sum().backward() + opt.step() def test_optim_equivalence(self) -> None: # If rows are initialized to be the same and uniform, then RowWiseAdagrad and canonical Adagrad are identical - rowwise_embedding_bag = torch.nn.EmbeddingBag(num_embeddings=4, embedding_dim=4) - embedding_bag = torch.nn.EmbeddingBag(num_embeddings=4, embedding_dim=4) + rowwise_embedding_bag = torch.nn.EmbeddingBag( + num_embeddings=4, embedding_dim=4, mode="sum" + ) + embedding_bag = torch.nn.EmbeddingBag( + num_embeddings=4, embedding_dim=4, mode="sum" + ) state_dict = { "weight": torch.Tensor( [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]] diff --git a/torchrec/optim/tests/test_warmup.py b/torchrec/optim/tests/test_warmup.py index cec6ffb49..446f6f895 100644 --- a/torchrec/optim/tests/test_warmup.py +++ b/torchrec/optim/tests/test_warmup.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest from collections import defaultdict from typing import Any @@ -28,7 +30,6 @@ class TestWarmupOptimizer(unittest.TestCase): def test_load_state_dict(self) -> None: def get_optimizer() -> WarmupOptimizer: param_1_t = torch.tensor([1.0, 2.0]) - # pyre-fixme[19]: Expected 0 positional arguments. param_1 = Variable(param_1_t) keyed_optimizer = DummyKeyedOptimizer( {"param_1": param_1}, defaultdict(dict), [{"params": [param_1]}] diff --git a/torchrec/optim/warmup.py b/torchrec/optim/warmup.py index 50579bf11..beebfec01 100644 --- a/torchrec/optim/warmup.py +++ b/torchrec/optim/warmup.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import logging import math from dataclasses import dataclass @@ -25,6 +27,7 @@ class WarmupPolicy(Enum): POLY = "poly" STEP = "step" INVSQRT = "inv_sqrt" # inverse square root + COSINE_ANNEALING_WARM_RESTARTS = "cosine_annealing_warm_restarts" @dataclass @@ -38,6 +41,7 @@ class WarmupStage: # also used as stepsize in step decay # default to 1 if not set to value > 0 decay_iters: int = -1 + sgdr_period: int = 1 def _lr_stages(stages: List[WarmupStage]) -> List[WarmupStage]: @@ -72,6 +76,16 @@ def _get_multiplier(stage: WarmupStage, iter: int) -> float: multiplier = math.pow(stage.value, iter // stage.decay_iters) elif stage.policy == WarmupPolicy.INVSQRT: multiplier = 1.0 / math.sqrt(iter) + elif stage.policy == WarmupPolicy.COSINE_ANNEALING_WARM_RESTARTS: + # SGDR: Stochastic Gradient Descent with Warm Restarts: + # https://arxiv.org/abs/1608.03983. + # Forgo period multiplier T_mult, as lr multiplier is a stateless + # computation without knowledge of previous period size. 
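
The cosine-annealing-with-warm-restarts policy being added to _get_multiplier (the assignments continue just below in the diff) boils down to a closed-form per-iteration multiplier. A small sketch of that arithmetic with illustrative values:

import math

def sgdr_multiplier(iteration: int, sgdr_period: int, eta_min: float) -> float:
    # Decays from 1.0 to eta_min over each period, then restarts.
    t_cur = iteration % sgdr_period
    cos_iter = 0.5 * (1 + math.cos(math.pi * t_cur / sgdr_period))
    return eta_min + (1.0 - eta_min) * cos_iter

assert abs(sgdr_multiplier(0, 10, 0.1) - 1.0) < 1e-9    # start of a period
assert abs(sgdr_multiplier(5, 10, 0.1) - 0.55) < 1e-9   # halfway: 0.1 + 0.9 * 0.5
assert abs(sgdr_multiplier(10, 10, 0.1) - 1.0) < 1e-9   # warm restart
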
+ eta_min = stage.value + t_0 = stage.sgdr_period + t_cur = iter % t_0 + cos_iter = 0.5 * (1 + math.cos(math.pi * t_cur / t_0)) + multiplier = eta_min + (1.0 - eta_min) * cos_iter return multiplier * stage.lr_scale diff --git a/examples/torcharrow/__init__.py b/torchrec/pt2/__init__.py similarity index 81% rename from examples/torcharrow/__init__.py rename to torchrec/pt2/__init__.py index 4c72865f7..7704fda5f 100644 --- a/examples/torcharrow/__init__.py +++ b/torchrec/pt2/__init__.py @@ -5,4 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from . import dataloader # noqa +# pyre-strict + +# __init__ for python module packaging diff --git a/torchrec/pt2/checks.py b/torchrec/pt2/checks.py new file mode 100644 index 000000000..76626a9f8 --- /dev/null +++ b/torchrec/pt2/checks.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import List + +import torch + +from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + +USE_TORCHDYNAMO_COMPILING_PATH: bool = False + + +def set_use_torchdynamo_compiling_path(val: bool) -> None: + global USE_TORCHDYNAMO_COMPILING_PATH + USE_TORCHDYNAMO_COMPILING_PATH = val + + +def get_use_torchdynamo_compiling_path() -> bool: + global USE_TORCHDYNAMO_COMPILING_PATH + return USE_TORCHDYNAMO_COMPILING_PATH + + +try: + if torch.jit.is_scripting(): + raise Exception() + + from torch.compiler import ( + is_compiling as is_compiler_compiling, + is_dynamo_compiling as _is_torchdynamo_compiling, + ) + + def is_torchdynamo_compiling() -> bool: + if torch.jit.is_scripting(): + return False + + # Can not use global variable here, as it is not supported in TorchScript + # (It parses full method src even there is a guard torch.jit.is_scripting()) + return get_use_torchdynamo_compiling_path() or _is_torchdynamo_compiling() + + def is_non_strict_exporting() -> bool: + return not is_torchdynamo_compiling() and is_compiler_compiling() + +except Exception: + # BC for torch versions without compiler and torch deploy path + def is_torchdynamo_compiling() -> bool: + return False + + def is_non_strict_exporting() -> bool: + return False + + +def is_pt2_compiling() -> bool: + return is_torchdynamo_compiling() or is_compiler_compiling() + + +def pt2_checks_tensor_slice( + tensor: torch.Tensor, start_offset: int, end_offset: int, dim: int = 0 +) -> None: + if torch.jit.is_scripting() or not is_pt2_compiling(): + return + + torch._check_is_size(start_offset) + torch._check_is_size(end_offset) + torch._check_is_size(end_offset - start_offset) + torch._check(start_offset <= tensor.size(dim)) + torch._check(end_offset <= tensor.size(dim)) + torch._check(end_offset >= start_offset) + + +def pt2_checks_all_is_size(x: List[int]) -> List[int]: + if torch.jit.is_scripting() or not is_pt2_compiling(): + return x + + for i in x: + torch._check_is_size(i) + return x + + +def pt2_check_size_nonzero(x: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting() or not is_pt2_compiling(): + return x + + for i in range(x.dim()): + torch._check(x.size(i) > 0) + return x + + +def pt2_guard_size_oblivious(x: bool) -> bool: + if torch.jit.is_scripting() or not is_pt2_compiling(): + return x + + return guard_size_oblivious(x) diff --git a/torchrec/pt2/utils.py 
b/torchrec/pt2/utils.py new file mode 100644 index 000000000..55accff68 --- /dev/null +++ b/torchrec/pt2/utils.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +import functools +from typing import Any, Callable + +import torch +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + +""" +Prepares KJT for PT2 tracing. + +KJT contains caching/lazy compute logic. +For tracing we need to drop all caches to have all compute logic in the graph. +This is done by recreation of KJT with minimal specified data. + +convert_to_vb - If True recreates KJT as Variable Batch. +""" + + +def kjt_for_pt2_tracing( + kjt: KeyedJaggedTensor, + convert_to_vb: bool = False, +) -> KeyedJaggedTensor: + # Breaking dependency cycle + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + is_vb = kjt.variable_stride_per_key() + if convert_to_vb and not is_vb: + stride: int = kjt.stride() + n = len(kjt.keys()) + + inverse_indices_tensor = ( + torch.arange(stride).expand(n, stride).contiguous().to(device=kjt.device()) + ) + torch._dynamo.decorators.mark_static(inverse_indices_tensor, 0) + torch._dynamo.decorators.mark_static(inverse_indices_tensor, 1) + + lengths = kjt.lengths().long() + # We can mark static lengths dimension as we have fixed batch_size, but using VB path for tracing + torch._dynamo.decorators.mark_static(lengths, 0) + values = kjt.values().long() + torch._dynamo.decorators.mark_unbacked(values, 0) + + return KeyedJaggedTensor( + keys=kjt.keys(), + values=values, + lengths=lengths, + weights=kjt.weights_or_none(), + stride_per_key_per_rank=[[stride]] * n, + inverse_indices=(kjt.keys(), inverse_indices_tensor), + ) + + inverse_indices = None + stride = None + + if is_vb: + inverse_indices = kjt.inverse_indices_or_none() + + if inverse_indices is not None: + inverse_indices_tensor = inverse_indices[1] + torch._dynamo.decorators.mark_static(inverse_indices_tensor, 0) + torch._dynamo.decorators.mark_static(inverse_indices_tensor, 1) + + lengths = kjt.lengths().long() + + stride = kjt.stride() + + values = kjt.values().long() + torch._dynamo.decorators.mark_unbacked(values, 0) + weights = kjt.weights_or_none() + if weights is not None: + torch._dynamo.decorators.mark_unbacked(weights, 0) + + return KeyedJaggedTensor( + keys=kjt.keys(), + values=values, + lengths=lengths, + weights=weights, + stride=stride if not is_vb else None, + stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None, + inverse_indices=inverse_indices, + ) + + +# pyre-ignore +def default_pipeline_input_transformer(inp): + for attr_name in ["id_list_features", "id_score_list_features"]: + if hasattr(inp, attr_name): + attr = getattr(inp, attr_name) + if isinstance(attr, KeyedJaggedTensor): + setattr(inp, attr_name, kjt_for_pt2_tracing(attr)) + return inp + + +def register_fake_classes() -> None: + @torch._library.register_fake_class("fbgemm::AtomicCounter") + class FakeAtomicCounter: + def __init__(self, counter_): + self.counter_ = counter_ + + @classmethod + def __obj_unflatten__(cls, flat_obj): + return cls(**dict(flat_obj)) + + def increment(self) -> int: + self.counter_ += 1 + return self.counter_ + + def decrement(self) -> int: + self.counter_ -= 1 + return self.counter_ + + def reset(self): + self.counter_ = 0 + + def get(self) -> int: + return self.counter_ + + def 
set(self, val): + self.counter_ = val + + @torch._library.register_fake_class("fbgemm::TensorQueue") + class FakeTensorQueue: + def __init__(self, queue, init_tensor): + self.queue = queue + self.init_tensor = init_tensor + + @classmethod + def __obj_unflatten__(cls, flattened_ctx): + return cls(**dict(flattened_ctx)) + + def push(self, x): + self.queue.append(x) + + def pop(self): + if len(self.queue) == 0: + return self.init_tensor + return self.queue.pop(0) + + def top(self): + if len(self.queue) == 0: + return self.init_tensor + return self.queue[0] + + def size(self): + return len(self.queue) + + +def deregister_fake_classes() -> None: + torch._library.fake_class_registry.deregister_fake_class("fbgemm::AtomicCounter") + torch._library.fake_class_registry.deregister_fake_class("fbgemm::TensorQueue") + + +# pyre-ignore[24] +def pt2_compile_callable(f: Callable) -> Callable: + """ + This method is used to decorate the update and compute methods of a metric computation class. + If the metric computation class has enable_pt2_compile attribute set to True, + then the update and compute methods will be compiled using torch.compile. + """ + + @functools.wraps(f) + # pyre-ignore[3] + def inner_forward( + ref: torch.nn.Module, + *args: Any, + **kwargs: Any, + ) -> Any: + if hasattr(ref, "enable_pt2_compile") and ref.enable_pt2_compile: + pt2_compiled_attr_name = f"_{f.__name__}_pt2_compiled" + if not hasattr(ref, pt2_compiled_attr_name): + setattr(ref, pt2_compiled_attr_name, torch.compile(f)) + return getattr(ref, pt2_compiled_attr_name)(ref, *args, **kwargs) + + return f(ref, *args, **kwargs) + + return inner_forward diff --git a/torchrec/quant/__init__.py b/torchrec/quant/__init__.py index 0fa242550..0923240f7 100644 --- a/torchrec/quant/__init__.py +++ b/torchrec/quant/__init__.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + """Torchrec Quantization Torchrec provides a quantized version of EmbeddingBagCollection for inference. diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index db522ddc0..81b9b8bfc 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -5,20 +5,35 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
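
pt2_compile_callable above compiles the decorated method with torch.compile on first use and caches the compiled callable on the instance, but only when the instance opts in via an enable_pt2_compile attribute. A sketch of the opt-in pattern; the module and method here are made up for illustration:

import torch
from torchrec.pt2.utils import pt2_compile_callable

class ToyMetric(torch.nn.Module):
    def __init__(self, enable_pt2_compile: bool = True) -> None:
        super().__init__()
        self.enable_pt2_compile = enable_pt2_compile

    @pt2_compile_callable
    def update(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2 + 1

m = ToyMetric()
out = m.update(torch.ones(3))  # first call compiles and caches the compiled method
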
+# pyre-strict + import copy import itertools -from collections import defaultdict, OrderedDict -from typing import Any, Dict, List, Optional, Tuple +from collections import defaultdict +from typing import ( + Any, + Callable, + cast, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) import torch import torch.nn as nn -from fbgemm_gpu.split_table_batched_embeddings_ops import ( +from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( EmbeddingLocation, IntNBitTableBatchedEmbeddingBagsCodegen, PoolingMode, ) from torch import Tensor +from torchrec.distributed.utils import none_throws from torchrec.modules.embedding_configs import ( + BaseEmbeddingConfig, DATA_TYPE_NUM_BITS, data_type_to_sparse_type, DataType, @@ -27,6 +42,7 @@ EmbeddingConfig, pooling_type_to_pooling_mode, PoolingType, + QuantConfig, ) from torchrec.modules.embedding_modules import ( EmbeddingBagCollection as OriginalEmbeddingBagCollection, @@ -35,9 +51,30 @@ EmbeddingCollectionInterface, get_embedding_names_by_table, ) -from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor +from torchrec.modules.feature_processor_ import FeatureProcessorsCollection +from torchrec.modules.fp_embedding_modules import ( + FeatureProcessedEmbeddingBagCollection as OriginalFeatureProcessedEmbeddingBagCollection, +) +from torchrec.modules.mc_embedding_modules import ( + ManagedCollisionEmbeddingCollection as OriginalManagedCollisionEmbeddingCollection, +) +from torchrec.modules.mc_modules import ManagedCollisionCollection +from torchrec.modules.utils import ( + _get_batching_hinted_output, + construct_jagged_tensors_inference, +) +from torchrec.sparse.jagged_tensor import ( + ComputeKJTToJTDict, + JaggedTensor, + KeyedJaggedTensor, + KeyedTensor, +) +from torchrec.tensor_types import UInt2Tensor, UInt4Tensor from torchrec.types import ModuleNoCopyMixin +torch.fx.wrap("_get_batching_hinted_output") +torch.fx.wrap("len") + try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") @@ -46,28 +83,152 @@ # OSS try: - import fbgemm_gpu # @manual # noqa + pass except ImportError: pass +MODULE_ATTR_REGISTER_TBES_BOOL: str = "__register_tbes_in_named_modules" + +MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: str = ( + "__quant_state_dict_split_scale_bias" +) + +MODULE_ATTR_ROW_ALIGNMENT_INT: str = "__register_row_alignment_in_named_modules" + +MODULE_ATTR_EMB_CONFIG_NAME_TO_NUM_ROWS_POST_PRUNING_DICT: str = ( + "__emb_name_to_num_rows_post_pruning" +) + +MODULE_ATTR_REMOVE_STBE_PADDING_BOOL: str = "__remove_stbe_padding" + +MODULE_ATTR_USE_UNFLATTENED_LENGTHS_FOR_BATCHING: str = ( + "__use_unflattened_lengths_for_batching" +) + +MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT: str = "__use_batching_hinted_output" + +MODULE_ATTR_CACHE_FEATURES_ORDER: str = "__cache_features_order" + +DEFAULT_ROW_ALIGNMENT = 16 + + +@torch.fx.wrap +def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor: + return feature.lengths() + + +@torch.fx.wrap +def _get_kjt_keys(feature: KeyedJaggedTensor) -> List[str]: + # this is a fx rule to help with batching hinting jagged sequence tensor coalescing. 
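
The torch.fx.wrap(...) registrations and @torch.fx.wrap helpers introduced above exist so that FX tracing records these helpers as opaque call_function nodes rather than tracing through KJT internals. A self-contained sketch of the mechanism with a toy helper:

import torch
import torch.fx

@torch.fx.wrap
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

class M(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return add_one(x) * 2

gm = torch.fx.symbolic_trace(M())
# add_one shows up as a single call_function node instead of being inlined.
assert any(n.op == "call_function" and n.target is add_one for n in gm.graph.nodes)
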
+ return feature.keys() + + +@torch.fx.wrap +def _permute_kjt( + features: KeyedJaggedTensor, + permute_order: List[int], + permute_order_tensor: Optional[Tensor] = None, +) -> KeyedJaggedTensor: + if permute_order == list(range(len(permute_order))): + return features.flatten_lengths() + return features.permute(permute_order, permute_order_tensor) + + +@torch.fx.wrap +def _cat_embeddings(embeddings: List[Tensor]) -> Tensor: + return embeddings[0] if len(embeddings) == 1 else torch.cat(embeddings, dim=1) + + +@torch.fx.wrap +def _get_unflattened_lengths(lengths: torch.Tensor, num_features: int) -> torch.Tensor: + """ + Unflatten lengths tensor from [F * B] to [F, B]. + """ + return lengths.view(num_features, -1) + + +def for_each_module_of_type_do( + module: nn.Module, + module_types: List[Type[torch.nn.Module]], + op: Callable[[torch.nn.Module], None], +) -> None: + for m in module.modules(): + if any([isinstance(m, t) for t in module_types]): + op(m) + + +def quant_prep_enable_quant_state_dict_split_scale_bias(module: nn.Module) -> None: + setattr(module, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, True) + + +def quant_prep_enable_quant_state_dict_split_scale_bias_for_types( + module: nn.Module, module_types: List[Type[torch.nn.Module]] +) -> None: + for_each_module_of_type_do( + module, + module_types, + lambda m: setattr(m, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, True), + ) + + +def quant_prep_enable_register_tbes( + module: nn.Module, module_types: List[Type[torch.nn.Module]] +) -> None: + for_each_module_of_type_do( + module, + module_types, + lambda m: setattr(m, MODULE_ATTR_REGISTER_TBES_BOOL, True), + ) + + +def quant_prep_customize_row_alignment( + module: nn.Module, module_types: List[Type[torch.nn.Module]], row_alignment: int +) -> None: + for_each_module_of_type_do( + module, + module_types, + lambda m: setattr(m, MODULE_ATTR_ROW_ALIGNMENT_INT, row_alignment), + ) + + +def quant_prep_enable_cache_features_order( + module: nn.Module, module_types: List[Type[torch.nn.Module]] +) -> None: + for_each_module_of_type_do( + module, + module_types, + lambda m: setattr(m, MODULE_ATTR_CACHE_FEATURES_ORDER, True), + ) + def quantize_state_dict( module: nn.Module, table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]], - data_type: DataType, + table_name_to_data_type: Dict[str, DataType], + table_name_to_num_embeddings_post_pruning: Optional[Dict[str, int]] = None, ) -> torch.device: device = torch.device("cpu") + if not table_name_to_num_embeddings_post_pruning: + table_name_to_num_embeddings_post_pruning = {} + for key, tensor in module.state_dict().items(): # Extract table name from state dict key. # e.g. 
ebc.embedding_bags.t1.weight splits = key.split(".") assert splits[-1] == "weight" table_name = splits[-2] + data_type = table_name_to_data_type[table_name] + num_rows = tensor.shape[0] + + if table_name in table_name_to_num_embeddings_post_pruning: + num_rows = table_name_to_num_embeddings_post_pruning[table_name] + device = tensor.device num_bits = DATA_TYPE_NUM_BITS[data_type] + if tensor.is_meta: quant_weight = torch.empty( - (tensor.shape[0], (tensor.shape[1] * num_bits) // 8), + (num_rows, (tensor.shape[1] * num_bits) // 8), device="meta", dtype=torch.uint8, ) @@ -77,18 +238,24 @@ def quantize_state_dict( or data_type == DataType.INT2 ): scale_shift = torch.empty( - (tensor.shape[0], 4), + (num_rows, 4), device="meta", dtype=torch.uint8, ) else: scale_shift = None else: + if num_rows != tensor.shape[0]: + tensor = tensor[:num_rows, :] if tensor.dtype == torch.float or tensor.dtype == torch.float16: if data_type == DataType.FP16: if tensor.dtype == torch.float: tensor = tensor.half() quant_res = tensor.view(torch.uint8) + elif data_type == DataType.FP32: + if tensor.dtype == torch.float16: + tensor = tensor.float() + quant_res = tensor.view(torch.uint8) else: quant_res = ( torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf( @@ -112,66 +279,64 @@ def quantize_state_dict( return device -class EmbeddingBagCollection(EmbeddingBagCollectionInterface, ModuleNoCopyMixin): - """ - EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags). - This EmbeddingBagCollection is quantized for lower precision. It relies on fbgemm quantized ops and provides - table batching. - - It processes sparse data in the form of KeyedJaggedTensor - with values of the form [F X B X L] - F: features (keys) - B: batch size - L: Length of sparse features (jagged) - - and outputs a KeyedTensor with values of the form [B * (F * D)] - where - F: features (keys) - D: each feature's (key's) embedding dimension - B: batch size +def _get_device(module: nn.Module) -> torch.device: + device = torch.device("cpu") - Args: - table_name_to_quantized_weights (Dict[str, Tuple[Tensor, Tensor]]): map of tables to quantized weights - embedding_configs (List[EmbeddingBagConfig]): list of embedding tables - is_weighted: (bool): whether input KeyedJaggedTensor is weighted - device: (Optional[torch.device]): default compute device + for _, tensor in module.state_dict().items(): + device = tensor.device + break + return device - Call Args: - features: KeyedJaggedTensor, - Returns: - KeyedTensor +def _update_embedding_configs( + embedding_configs: Sequence[BaseEmbeddingConfig], + quant_config: Union[QuantConfig, torch.quantization.QConfig], + tables_to_rows_post_pruning: Optional[Dict[str, int]] = None, +) -> None: + per_table_weight_dtype = ( + quant_config.per_table_weight_dtype + if isinstance(quant_config, QuantConfig) and quant_config.per_table_weight_dtype + else {} + ) + for config in embedding_configs: + config.data_type = dtype_to_data_type( + per_table_weight_dtype[config.name] + if config.name in per_table_weight_dtype + else quant_config.weight().dtype + ) - Example:: + if tables_to_rows_post_pruning and config.name in tables_to_rows_post_pruning: + config.num_embeddings_post_pruning = tables_to_rows_post_pruning[ + config.name + ] - table_0 = EmbeddingBagConfig( - name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] - ) - table_1 = EmbeddingBagConfig( - name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] - ) - ebc = EmbeddingBagCollection(tables=[eb1_config, 
eb2_config]) - - # 0 1 2 <-- batch - # "f1" [0,1] None [2] - # "f2" [3] [4] [5,6,7] - # ^ - # feature - features = KeyedJaggedTensor( - keys=["f1", "f2"], - values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), - offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), - ) - ebc.qconfig = torch.quantization.QConfig( - activation=torch.quantization.PlaceholderObserver.with_args( - dtype=torch.qint8 - ), - weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8), - ) +@torch.fx.wrap +def _fx_trec_unwrap_kjt( + kjt: KeyedJaggedTensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forced conversions to support TBE + CPU - int32 or int64, offsets dtype must match + GPU - int32 only, offsets dtype must match + """ + indices = kjt.values() + offsets = kjt.offsets() + if kjt.device().type == "cpu": + return indices, offsets.type(dtype=indices.dtype) + else: + return indices.int(), offsets.int() - qebc = QuantEmbeddingBagCollection.from_float(ebc) - quantized_embeddings = qebc(features) + +class EmbeddingBagCollection(EmbeddingBagCollectionInterface, ModuleNoCopyMixin): + """ + This class represents a reimplemented version of the EmbeddingBagCollection + class found in `torchrec/modules/embedding_modules.py`. + However, it is quantized for lower precision. + It relies on fbgemm quantized ops and provides table batching. + + For more details, including examples, please refer to + `torchrec/modules/embedding_modules.py` """ def __init__( @@ -183,6 +348,10 @@ def __init__( table_name_to_quantized_weights: Optional[ Dict[str, Tuple[Tensor, Tensor]] ] = None, + register_tbes: bool = False, + quant_state_dict_split_scale_bias: bool = False, + row_alignment: int = DEFAULT_ROW_ALIGNMENT, + cache_features_order: bool = False, ) -> None: super().__init__() self._is_weighted = is_weighted @@ -190,7 +359,10 @@ def __init__( self._key_to_tables: Dict[ Tuple[PoolingType, DataType], List[EmbeddingBagConfig] ] = defaultdict(list) + self._feature_names: List[str] = [] + self._feature_splits: List[int] = [] self._length_per_key: List[int] = [] + self._features_order: Optional[List[int]] = None # Registering in a List instead of ModuleList because we want don't want them to be auto-registered. 
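
The float-to-quantized conversion path that the trimmed docstring used to spell out still applies: attach a qconfig to the float module (a plain torch.quantization.QConfig, or a torchrec QuantConfig when per-table weight dtypes are needed) and call from_float. A compressed sketch mirroring the removed example; table sizes and dtypes are illustrative:

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.quant.embedding_modules import (
    EmbeddingBagCollection as QuantEmbeddingBagCollection,
)

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=4, num_embeddings=10, feature_names=["f1"]
        )
    ]
)
ebc.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
    weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)
qebc = QuantEmbeddingBagCollection.from_float(ebc)
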
# Their states will be modified via self.embedding_bags self._emb_modules: List[nn.Module] = [] @@ -199,63 +371,78 @@ def __init__( self._table_name_to_quantized_weights: Optional[ Dict[str, Tuple[Tensor, Tensor]] ] = None + self.row_alignment = row_alignment + self._kjt_to_jt_dict = ComputeKJTToJTDict() table_names = set() for table in self._embedding_bag_configs: if table.name in table_names: raise ValueError(f"Duplicate table name {table.name}") table_names.add(table.name) - self._length_per_key.extend( - [table.embedding_dim] * len(table.feature_names) - ) - key = (table.pooling, table.data_type) - self._key_to_tables[key].append(table) - - self._sum_length_per_key: int = sum(self._length_per_key) + # pyre-ignore + self._key_to_tables[table.pooling].append(table) location = ( EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE ) - for key, emb_configs in self._key_to_tables.items(): - (pooling, data_type) = key + for pooling, emb_configs in self._key_to_tables.items(): embedding_specs = [] - weight_lists: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = ( - [] if table_name_to_quantized_weights else None - ) + weight_lists: Optional[ + List[Tuple[torch.Tensor, Optional[torch.Tensor]]] + ] = ([] if table_name_to_quantized_weights else None) feature_table_map: List[int] = [] for idx, table in enumerate(emb_configs): embedding_specs.append( ( table.name, - table.num_embeddings, + ( + table.num_embeddings_post_pruning + # TODO: Need to check if attribute exists for BC + if getattr(table, "num_embeddings_post_pruning", None) + is not None + else table.num_embeddings + ), table.embedding_dim, - data_type_to_sparse_type(data_type), + data_type_to_sparse_type(table.data_type), location, ) ) if table_name_to_quantized_weights: - # pyre-ignore - weight_lists.append(table_name_to_quantized_weights[table.name]) + none_throws(weight_lists).append( + table_name_to_quantized_weights[table.name] + ) feature_table_map.extend([idx] * table.num_features()) emb_module = IntNBitTableBatchedEmbeddingBagsCodegen( embedding_specs=embedding_specs, + # pyre-ignore pooling_mode=pooling_type_to_pooling_mode(pooling), weight_lists=weight_lists, device=device, output_dtype=data_type_to_sparse_type(dtype_to_data_type(output_dtype)), - row_alignment=16, + row_alignment=row_alignment, feature_table_map=feature_table_map, ) - if device != torch.device("meta") and weight_lists is None: + if weight_lists is None: emb_module.initialize_weights() self._emb_modules.append(emb_module) + for table in emb_configs: + self._feature_names.extend(table.feature_names) + self._feature_splits.append( + sum(table.num_features() for table in emb_configs) + ) + ordered_tables = list(itertools.chain(*self._key_to_tables.values())) self._embedding_names: List[str] = list( - itertools.chain(*get_embedding_names_by_table(self._embedding_bag_configs)) + itertools.chain(*get_embedding_names_by_table(ordered_tables)) ) + for table in ordered_tables: + self._length_per_key.extend( + [table.embedding_dim] * len(table.feature_names) + ) + # We map over the parameters from FBGEMM backed kernels to the canonical nn.EmbeddingBag # representation. 
This provides consistency between this class and the EmbeddingBagCollection # nn.Module API calls (state_dict, named_modules, etc) @@ -263,16 +450,45 @@ def __init__( for (_key, tables), emb_module in zip( self._key_to_tables.items(), self._emb_modules ): - for embedding_config, (weight, _) in zip( - tables, emb_module.split_embedding_weights(split_scale_shifts=False) + for embedding_config, (weight, qscale, qbias) in zip( + tables, + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + emb_module.split_embedding_weights_with_scale_bias( + split_scale_bias_mode=2 if quant_state_dict_split_scale_bias else 0 + ), ): self.embedding_bags[embedding_config.name] = torch.nn.Module() # register as a buffer so it's exposed in state_dict. + # TODO: register as param instead of buffer # however, since this is only needed for inference, we do not need to expose it as part of parameters. # Additionally, we cannot expose uint8 weights as parameters due to autograd restrictions. + + if embedding_config.data_type == DataType.INT4: + weight = UInt4Tensor(weight) + elif embedding_config.data_type == DataType.INT2: + weight = UInt2Tensor(weight) + self.embedding_bags[embedding_config.name].register_buffer( "weight", weight ) + if quant_state_dict_split_scale_bias: + self.embedding_bags[embedding_config.name].register_buffer( + "weight_qscale", qscale + ) + self.embedding_bags[embedding_config.name].register_buffer( + "weight_qbias", qbias + ) + + setattr( + self, + MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, + quant_state_dict_split_scale_bias, + ) + setattr(self, MODULE_ATTR_REGISTER_TBES_BOOL, register_tbes) + self.register_tbes = register_tbes + if register_tbes: + self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules) + setattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, cache_features_order) def forward( self, @@ -286,48 +502,56 @@ def forward( KeyedTensor """ - feature_dict = features.to_dict() embeddings = [] + kjt_keys = _get_kjt_keys(features) + # Cache the features order since the features will always have the same order of keys in inference. + if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False): + if self._features_order is None: + self._features_order = [kjt_keys.index(k) for k in self._feature_names] + if self._features_order: + self.register_buffer( + "_features_order_tensor", + torch.tensor( + data=self._features_order, + device=features.device(), + dtype=torch.int32, + ), + persistent=False, + ) + kjt_permute = _permute_kjt( + features, + self._features_order, + getattr(self, "_features_order_tensor", None), + ) + else: + kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names] + kjt_permute = _permute_kjt(features, kjt_permute_order) + kjts_per_key = kjt_permute.split(self._feature_splits) - # TODO ideally we can accept KJTs with any feature order. However, this will require an order check + permute, which will break torch.script. - # Once torchsccript is no longer a requirement, we should revisit this. 
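
The forward path above permutes the incoming KJT once so that feature order matches the order the batched TBE kernels were built with (and optionally caches that order). The reordering itself is ordinary KeyedJaggedTensor.permute; a toy sketch:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f2", "f1"],
    values=torch.tensor([3, 4, 5, 0, 1]),
    lengths=torch.tensor([2, 1, 1, 1]),
)
feature_names = ["f1", "f2"]  # order fixed at module construction
order = [features.keys().index(k) for k in feature_names]  # -> [1, 0]
permuted = features.permute(order)
assert permuted.keys() == ["f1", "f2"]
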
- - for emb_op, (_key, tables) in zip( - self._emb_modules, self._key_to_tables.items() + for i, (emb_op, _) in enumerate( + zip(self._emb_modules, self._key_to_tables.keys()) ): - indices = [] - lengths = [] - offsets = [] - weights = [] - - for table in tables: - for feature in table.feature_names: - f = feature_dict[feature] - indices.append(f.values()) - lengths.append(f.lengths()) - if self._is_weighted: - weights.append(f.weights()) - - indices = torch.cat(indices) - lengths = torch.cat(lengths) - - offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) - if self._is_weighted: - weights = torch.cat(weights) + f = kjts_per_key[i] + indices, offsets = _fx_trec_unwrap_kjt(f) embeddings.append( + # Syntax for FX to generate call_module instead of call_function to keep TBE copied unchanged to fx.GraphModule, can be done only for registered module emb_op( - indices=indices.int(), - offsets=offsets.int(), - per_sample_weights=weights if self._is_weighted else None, + indices=indices, + offsets=offsets, + per_sample_weights=f.weights() if self._is_weighted else None, + ) + if self.register_tbes + else emb_op.forward( + indices=indices, + offsets=offsets, + per_sample_weights=f.weights() if self._is_weighted else None, ) ) - embeddings = torch.stack(embeddings).reshape(-1, self._sum_length_per_key) - return KeyedTensor( keys=self._embedding_names, - values=embeddings, + values=_cat_embeddings(embeddings), length_per_key=self._length_per_key, ) @@ -336,27 +560,50 @@ def _get_name(self) -> str: @classmethod def from_float( - cls, module: OriginalEmbeddingBagCollection + cls, + module: OriginalEmbeddingBagCollection, + use_precomputed_fake_quant: bool = False, ) -> "EmbeddingBagCollection": assert hasattr( module, "qconfig" ), "EmbeddingBagCollection input float module must have qconfig defined" - - # pyre-ignore [16] - data_type = dtype_to_data_type(module.qconfig.weight().dtype) + pruning_dict: Dict[str, int] = getattr( + module, MODULE_ATTR_EMB_CONFIG_NAME_TO_NUM_ROWS_POST_PRUNING_DICT, {} + ) embedding_bag_configs = copy.deepcopy(module.embedding_bag_configs()) - for config in embedding_bag_configs: - config.data_type = data_type + _update_embedding_configs( + cast(List[BaseEmbeddingConfig], embedding_bag_configs), + # pyre-fixme[6]: For 2nd argument expected `Union[QuantConfig, QConfig]` + # but got `Union[Module, Tensor]`. + module.qconfig, + pruning_dict, + ) table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]] = {} - device = quantize_state_dict(module, table_name_to_quantized_weights, data_type) + device = quantize_state_dict( + module, + table_name_to_quantized_weights, + {table.name: table.data_type for table in embedding_bag_configs}, + pruning_dict, + ) return cls( embedding_bag_configs, module.is_weighted(), device=device, - # pyre-ignore [16] + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `activation`. 
output_dtype=module.qconfig.activation().dtype, table_name_to_quantized_weights=table_name_to_quantized_weights, + register_tbes=getattr(module, MODULE_ATTR_REGISTER_TBES_BOOL, False), + quant_state_dict_split_scale_bias=getattr( + module, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False + ), + row_alignment=getattr( + module, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT + ), + cache_features_order=getattr( + module, MODULE_ATTR_CACHE_FEATURES_ORDER, False + ), ) def embedding_bag_configs( @@ -375,59 +622,116 @@ def device(self) -> torch.device: return self._device -class EmbeddingCollection(EmbeddingCollectionInterface, ModuleNoCopyMixin): - """ - EmbeddingCollection represents a collection of non-pooled embeddings. - - It processes sparse data in the form of `KeyedJaggedTensor` of the form [F X B X L] - where: +class FeatureProcessedEmbeddingBagCollection(EmbeddingBagCollection): + def __init__( + self, + tables: List[EmbeddingBagConfig], + is_weighted: bool, + device: torch.device, + output_dtype: torch.dtype = torch.float, + table_name_to_quantized_weights: Optional[ + Dict[str, Tuple[Tensor, Tensor]] + ] = None, + register_tbes: bool = False, + quant_state_dict_split_scale_bias: bool = False, + row_alignment: int = DEFAULT_ROW_ALIGNMENT, + # feature processor is Optional only for the sake of the last position in constructor + # Enforcing it to be non-None, for None case EmbeddingBagCollection must be used. + feature_processor: Optional[FeatureProcessorsCollection] = None, + cache_features_order: bool = False, + ) -> None: + super().__init__( + tables, + is_weighted, + device, + output_dtype, + table_name_to_quantized_weights, + register_tbes, + quant_state_dict_split_scale_bias, + row_alignment, + cache_features_order, + ) + assert ( + feature_processor is not None + ), "Use EmbeddingBagCollection for no feature_processor" + self.feature_processor: FeatureProcessorsCollection = feature_processor - * F: features (keys) - * B: batch size - * L: length of sparse features (variable) + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedTensor: + features = self.feature_processor(features) + return super().forward(features) - and outputs `Dict[feature (key), JaggedTensor]`. - Each `JaggedTensor` contains values of the form (B * L) X D - where: + def _get_name(self) -> str: + return "QuantFeatureProcessedEmbeddingBagCollection" - * B: batch size - * L: length of sparse features (jagged) - * D: each feature's (key's) embedding dimension and lengths are of the form L + @classmethod + # pyre-ignore + def from_float( + cls, + module: OriginalFeatureProcessedEmbeddingBagCollection, + use_precomputed_fake_quant: bool = False, + ) -> "FeatureProcessedEmbeddingBagCollection": + fp_ebc = module + ebc = module._embedding_bag_collection + qconfig = module.qconfig + assert hasattr( + module, "qconfig" + ), "FeatureProcessedEmbeddingBagCollection input float module must have qconfig defined" - Args: - tables (List[EmbeddingConfig]): list of embedding tables. - device (Optional[torch.device]): default compute device. 
- need_indices (bool): if we need to pass indices to the final lookup result dict + pruning_dict: Dict[str, int] = getattr( + module, MODULE_ATTR_EMB_CONFIG_NAME_TO_NUM_ROWS_POST_PRUNING_DICT, {} + ) - Example:: + embedding_bag_configs = copy.deepcopy(ebc.embedding_bag_configs()) + _update_embedding_configs( + cast(List[BaseEmbeddingConfig], embedding_bag_configs), + # pyre-fixme[6]: For 2nd argument expected `Union[QuantConfig, QConfig]` + # but got `Union[Module, Tensor]`. + qconfig, + pruning_dict, + ) - e1_config = EmbeddingConfig( - name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]] = {} + device = quantize_state_dict( + ebc, + table_name_to_quantized_weights, + {table.name: table.data_type for table in embedding_bag_configs}, + pruning_dict, ) - e2_config = EmbeddingConfig( - name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"] + return cls( + embedding_bag_configs, + ebc.is_weighted(), + device=device, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `activation`. + output_dtype=qconfig.activation().dtype, + table_name_to_quantized_weights=table_name_to_quantized_weights, + register_tbes=getattr(module, MODULE_ATTR_REGISTER_TBES_BOOL, False), + quant_state_dict_split_scale_bias=getattr( + ebc, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False + ), + row_alignment=getattr( + ebc, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT + ), + # pyre-ignore + feature_processor=fp_ebc._feature_processors, + cache_features_order=getattr( + module, MODULE_ATTR_CACHE_FEATURES_ORDER, False + ), ) - ec = EmbeddingCollection(tables=[e1_config, e2_config]) - # 0 1 2 <-- batch - # 0 [0,1] None [2] - # 1 [3] [4] [5,6,7] - # ^ - # feature +class EmbeddingCollection(EmbeddingCollectionInterface, ModuleNoCopyMixin): + """ + This class represents a reimplemented version of the EmbeddingCollection + class found in `torchrec/modules/embedding_modules.py`. + However, it is quantized for lower precision. + It relies on fbgemm quantized ops and provides table batching. 
- features = KeyedJaggedTensor.from_offsets_sync( - keys=["f1", "f2"], - values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), - offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), - ) - feature_embeddings = ec(features) - print(feature_embeddings['f2'].values()) - tensor([[-0.2050, 0.5478, 0.6054], - [ 0.7352, 0.3210, -3.0399], - [ 0.1279, -0.1756, -0.4130], - [ 0.7519, -0.4341, -0.0499], - [ 0.9329, -1.0697, -0.8095]], grad_fn=) + For more details, including examples, please refer to + `torchrec/modules/embedding_modules.py` """ def __init__( # noqa C901 @@ -439,62 +743,133 @@ def __init__( # noqa C901 table_name_to_quantized_weights: Optional[ Dict[str, Tuple[Tensor, Tensor]] ] = None, + register_tbes: bool = False, + quant_state_dict_split_scale_bias: bool = False, + row_alignment: int = DEFAULT_ROW_ALIGNMENT, + cache_features_order: bool = False, ) -> None: super().__init__() - self.embeddings: nn.ModuleList = nn.ModuleList() + self._emb_modules: List[IntNBitTableBatchedEmbeddingBagsCodegen] = [] self._embedding_configs = tables self._embedding_dim: int = -1 self._need_indices: bool = need_indices self._output_dtype = output_dtype self._device = device + self.row_alignment = row_alignment + self._key_to_tables: Dict[DataType, List[EmbeddingConfig]] = defaultdict(list) + self._feature_names: List[str] = [] + self._features_order: Optional[List[int]] = None + + self._table_name_to_quantized_weights: Optional[ + Dict[str, Tuple[Tensor, Tensor]] + ] = table_name_to_quantized_weights table_names = set() - for config in tables: - if config.name in table_names: - raise ValueError(f"Duplicate table name {config.name}") - table_names.add(config.name) + for table in self._embedding_configs: + if table.name in table_names: + raise ValueError(f"Duplicate table name {table.name}") + table_names.add(table.name) self._embedding_dim = ( - config.embedding_dim if self._embedding_dim < 0 else self._embedding_dim + table.embedding_dim if self._embedding_dim < 0 else self._embedding_dim ) - if self._embedding_dim != config.embedding_dim: + if self._embedding_dim != table.embedding_dim: raise ValueError( "All tables in a EmbeddingCollection are required to have same embedding dimension." 
+ + f" Violating case: {table.name}'s embedding_dim {table.embedding_dim} !=" + + f" {self._embedding_dim}" ) - weight_lists: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = ( - [] if table_name_to_quantized_weights else None - ) - if table_name_to_quantized_weights: - # pyre-ignore - weight_lists.append(table_name_to_quantized_weights[config.name]) - emb_module = IntNBitTableBatchedEmbeddingBagsCodegen( - embedding_specs=[ + key = table.data_type + self._key_to_tables[key].append(table) + self._feature_names.extend(table.feature_names) + self._feature_splits: List[int] = [] + for key, emb_configs in self._key_to_tables.items(): + data_type = key + embedding_specs = [] + weight_lists: Optional[ + List[Tuple[torch.Tensor, Optional[torch.Tensor]]] + ] = ([] if table_name_to_quantized_weights else None) + feature_table_map: List[int] = [] + for idx, table in enumerate(emb_configs): + embedding_specs.append( ( - "", - config.num_embeddings, - config.embedding_dim, - data_type_to_sparse_type(config.data_type), - EmbeddingLocation.HOST - if device.type == "cpu" - else EmbeddingLocation.DEVICE, + table.name, + table.num_embeddings, + table.embedding_dim, + data_type_to_sparse_type(data_type), + ( + EmbeddingLocation.HOST + if device.type == "cpu" + else EmbeddingLocation.DEVICE + ), + ) + ) + if table_name_to_quantized_weights: + none_throws(weight_lists).append( + table_name_to_quantized_weights[table.name] ) - ], + feature_table_map.extend([idx] * table.num_features()) + emb_module = IntNBitTableBatchedEmbeddingBagsCodegen( + embedding_specs=embedding_specs, pooling_mode=PoolingMode.NONE, weight_lists=weight_lists, device=device, output_dtype=data_type_to_sparse_type(dtype_to_data_type(output_dtype)), - row_alignment=16, + row_alignment=row_alignment, + feature_table_map=feature_table_map, ) - if device != torch.device("meta") and weight_lists is None: + if weight_lists is None: emb_module.initialize_weights() + self._emb_modules.append(emb_module) + self._feature_splits.append( + sum(table.num_features() for table in emb_configs) + ) - self.embeddings.append(emb_module) + self.embeddings: nn.ModuleDict = nn.ModuleDict() + for (_key, tables), emb_module in zip( + self._key_to_tables.items(), self._emb_modules + ): + for embedding_config, (weight, qscale, qbias) in zip( + tables, + emb_module.split_embedding_weights_with_scale_bias( + split_scale_bias_mode=2 if quant_state_dict_split_scale_bias else 0 + ), + ): + self.embeddings[embedding_config.name] = torch.nn.Module() + # register as a buffer so it's exposed in state_dict. + # TODO: register as param instead of buffer + # however, since this is only needed for inference, we do not need to expose it as part of parameters. + # Additionally, we cannot expose uint8 weights as parameters due to autograd restrictions. 
+ if embedding_config.data_type == DataType.INT4: + weight = UInt4Tensor(weight) + elif embedding_config.data_type == DataType.INT2: + weight = UInt2Tensor(weight) + self.embeddings[embedding_config.name].register_buffer("weight", weight) + if quant_state_dict_split_scale_bias: + self.embeddings[embedding_config.name].register_buffer( + "weight_qscale", qscale + ) + self.embeddings[embedding_config.name].register_buffer( + "weight_qbias", qbias + ) - if not config.feature_names: - config.feature_names = [config.name] + self._embedding_names_by_batched_tables: Dict[DataType, List[str]] = { + key: list(itertools.chain(*get_embedding_names_by_table(table))) + for key, table in self._key_to_tables.items() + } self._embedding_names_by_table: List[List[str]] = get_embedding_names_by_table( - tables + self._embedding_configs ) + setattr( + self, + MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, + quant_state_dict_split_scale_bias, + ) + setattr(self, MODULE_ATTR_REGISTER_TBES_BOOL, register_tbes) + self.register_tbes = register_tbes + if register_tbes: + self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules) + setattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, cache_features_order) def forward( self, @@ -509,70 +884,104 @@ def forward( """ feature_embeddings: Dict[str, JaggedTensor] = {} - jt_dict: Dict[str, JaggedTensor] = features.to_dict() - for config, embedding_names, emb_module in zip( - self._embedding_configs, - self._embedding_names_by_table, - self.embeddings, - ): - for feature_name, embedding_name in zip( - config.feature_names, embedding_names - ): - f = jt_dict[feature_name] - values = f.values() - offsets = f.offsets() - lookup = emb_module( - indices=values.int(), - offsets=offsets.int(), - ) - feature_embeddings[embedding_name] = JaggedTensor( - values=lookup, - lengths=f.lengths(), - weights=f.values() if self.need_indices else None, - ) - return feature_embeddings + kjt_keys = _get_kjt_keys(features) + # Cache the features order since the features will always have the same order of keys in inference. + if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False): + if self._features_order is None: + self._features_order = [kjt_keys.index(k) for k in self._feature_names] + if self._features_order: + self.register_buffer( + "_features_order_tensor", + torch.tensor( + data=self._features_order, + device=features.device(), + dtype=torch.int32, + ), + persistent=False, + ) + kjt_permute = _permute_kjt( + features, + self._features_order, + getattr(self, "_features_order_tensor", None), + ) + else: + kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names] + kjt_permute = _permute_kjt(features, kjt_permute_order) + kjts_per_key = kjt_permute.split(self._feature_splits) - # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently. 
- def state_dict( - self, - destination: Optional[Dict[str, Any]] = None, - prefix: str = "", - keep_vars: bool = False, - ) -> Dict[str, Any]: - if destination is None: - destination = OrderedDict() - # pyre-ignore [16] - destination._metadata = OrderedDict() - for emb_config, emb_module in zip( - self._embedding_configs, - self.embeddings, + for i, (emb_module, key) in enumerate( + zip(self._emb_modules, self._key_to_tables.keys()) ): - (weight, _) = emb_module.split_embedding_weights(split_scale_shifts=False)[ - 0 - ] - destination[prefix + f"embeddings.{emb_config.name}.weight"] = weight - return destination + f = kjts_per_key[i] + lengths = _get_feature_length(f) + indices, offsets = _fx_trec_unwrap_kjt(f) + embedding_names = self._embedding_names_by_batched_tables[key] + lookup = ( + emb_module(indices=indices, offsets=offsets) + if self.register_tbes + else emb_module.forward(indices=indices, offsets=offsets) + ) + if getattr(self, MODULE_ATTR_USE_UNFLATTENED_LENGTHS_FOR_BATCHING, False): + lengths = _get_unflattened_lengths(lengths, len(embedding_names)) + lookup = _get_batching_hinted_output(lengths=lengths, output=lookup) + else: + if getattr(self, MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT, True): + lookup = _get_batching_hinted_output(lengths=lengths, output=lookup) + lengths = _get_unflattened_lengths(lengths, len(embedding_names)) + jt = construct_jagged_tensors_inference( + embeddings=lookup, + lengths=lengths, + values=indices, + embedding_names=embedding_names, + need_indices=self.need_indices(), + remove_padding=getattr( + self, MODULE_ATTR_REMOVE_STBE_PADDING_BOOL, False + ), + ) + for embedding_name in embedding_names: + feature_embeddings[embedding_name] = jt[embedding_name] + return feature_embeddings @classmethod - def from_float(cls, module: OriginalEmbeddingCollection) -> "EmbeddingCollection": + def from_float( + cls, + module: OriginalEmbeddingCollection, + use_precomputed_fake_quant: bool = False, + ) -> "EmbeddingCollection": assert hasattr( module, "qconfig" ), "EmbeddingCollection input float module must have qconfig defined" - - # pyre-ignore [16] - data_type = dtype_to_data_type(module.qconfig.weight().dtype) - tables = copy.deepcopy(module.embedding_configs()) - for config in tables: - config.data_type = data_type - + embedding_configs = copy.deepcopy(module.embedding_configs()) + _update_embedding_configs( + cast(List[BaseEmbeddingConfig], embedding_configs), + # pyre-fixme[6]: For 2nd argument expected `Union[QuantConfig, QConfig]` + # but got `Union[Module, Tensor]`. + module.qconfig, + ) table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]] = {} - device = quantize_state_dict(module, table_name_to_quantized_weights, data_type) - + device = quantize_state_dict( + module, + table_name_to_quantized_weights, + {table.name: table.data_type for table in embedding_configs}, + ) return cls( - tables, + embedding_configs, device=device, need_indices=module.need_indices(), + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `activation`. 
+ output_dtype=module.qconfig.activation().dtype, table_name_to_quantized_weights=table_name_to_quantized_weights, + register_tbes=getattr(module, MODULE_ATTR_REGISTER_TBES_BOOL, False), + quant_state_dict_split_scale_bias=getattr( + module, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False + ), + row_alignment=getattr( + module, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT + ), + cache_features_order=getattr( + module, MODULE_ATTR_CACHE_FEATURES_ORDER, False + ), ) def _get_name(self) -> str: @@ -596,3 +1005,153 @@ def output_dtype(self) -> torch.dtype: @property def device(self) -> torch.device: return self._device + + +class QuantManagedCollisionEmbeddingCollection(EmbeddingCollection): + """ + QuantManagedCollisionEmbeddingCollection represents a quantized EC module and a set of managed collision modules. + The inputs into the MC-EC/EBC will first be modified by the managed collision module before being passed into the embedding collection. + + Args: + tables (List[EmbeddingConfig]): A list of EmbeddingConfig objects representing the embedding tables in the collection. + device (torch.device): The device on which the embedding collection will be allocated. + need_indices (bool, optional): Whether to return the indices along with the embeddings. Defaults to False. + output_dtype (torch.dtype, optional): The data type of the output embeddings. Defaults to torch.float. + table_name_to_quantized_weights (Dict[str, Tuple[Tensor, Tensor]], optional): A dictionary mapping table names to their corresponding quantized weights. Defaults to None. + register_tbes (bool, optional): Whether to register the TBEs in the model. Defaults to False. + quant_state_dict_split_scale_bias (bool, optional): Whether to split the scale and bias parameters when saving the quantized state dict. Defaults to False. + row_alignment (int, optional): The alignment of rows in the quantized weights. Defaults to DEFAULT_ROW_ALIGNMENT. + managed_collision_collection (ManagedCollisionCollection, optional): The managed collision collection to use for managing collisions. Defaults to None. + return_remapped_features (bool, optional): Whether to return the remapped input features in addition to the embeddings. Defaults to False. 
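For context, a hedged sketch of the call pattern this class supports. It assumes `quant_mc_ec` is an already constructed instance and `features` a KeyedJaggedTensor; both are placeholders, and construction of the ManagedCollisionCollection is not shown in this hunk:

# forward() first remaps ids through the managed collision modules, then runs
# the quantized lookup, and returns both results as a tuple.
embeddings, remapped_features = quant_mc_ec(features)
# embeddings: Dict[str, JaggedTensor] keyed by embedding name
# remapped_features: the KeyedJaggedTensor after collision remapping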
+ """ + + def __init__( + self, + tables: List[EmbeddingConfig], + device: torch.device, + need_indices: bool = False, + output_dtype: torch.dtype = torch.float, + table_name_to_quantized_weights: Optional[ + Dict[str, Tuple[Tensor, Tensor]] + ] = None, + register_tbes: bool = False, + quant_state_dict_split_scale_bias: bool = False, + row_alignment: int = DEFAULT_ROW_ALIGNMENT, + managed_collision_collection: Optional[ManagedCollisionCollection] = None, + return_remapped_features: bool = False, + cache_features_order: bool = False, + ) -> None: + super().__init__( + tables, + device, + need_indices, + output_dtype, + table_name_to_quantized_weights, + register_tbes, + quant_state_dict_split_scale_bias, + row_alignment, + cache_features_order, + ) + assert ( + managed_collision_collection + ), "Managed collision collection cannot be None" + self._managed_collision_collection: ManagedCollisionCollection = ( + managed_collision_collection + ) + self._return_remapped_features = return_remapped_features + + assert str(self.embedding_configs()) == str( + self._managed_collision_collection.embedding_configs() + ), "Embedding Collection and Managed Collision Collection must contain the same Embedding Configs" + + # Assuming quantized MCEC is used in inference only + for ( + managed_collision_module + ) in self._managed_collision_collection._managed_collision_modules.values(): + managed_collision_module.reset_inference_mode() + + def to( + self, *args: List[Any], **kwargs: Dict[str, Any] + ) -> "QuantManagedCollisionEmbeddingCollection": + device, dtype, non_blocking, _ = torch._C._nn._parse_to( + *args, # pyre-ignore + **kwargs, # pyre-ignore + ) + for param in self.parameters(): + if param.device.type != "meta": + param.to(device) + + for buffer in self.buffers(): + if buffer.device.type != "meta": + buffer.to(device) + # Skip device movement and continue with other args + super().to( + dtype=dtype, + non_blocking=non_blocking, + ) + return self + + # pyre-ignore + def forward( + self, + features: KeyedJaggedTensor, + ) -> Tuple[ + Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor] + ]: + features = self._managed_collision_collection(features) + + return (super().forward(features), features) + + def _get_name(self) -> str: + return "QuantManagedCollisionEmbeddingCollection" + + @classmethod + # pyre-ignore + def from_float( + cls, + module: OriginalManagedCollisionEmbeddingCollection, + return_remapped_features: bool = False, + ) -> "QuantManagedCollisionEmbeddingCollection": + mc_ec = module + ec = module._embedding_module + + # pyre-ignore[9] + qconfig: torch.quantization.QConfig = module.qconfig + assert hasattr( + module, "qconfig" + ), "QuantManagedCollisionEmbeddingCollection input float module must have qconfig defined" + + # pyre-ignore[29] + embedding_configs = copy.deepcopy(ec.embedding_configs()) + _update_embedding_configs( + cast(List[BaseEmbeddingConfig], embedding_configs), + qconfig, + ) + _update_embedding_configs( + mc_ec._managed_collision_collection._embedding_configs, + qconfig, + ) + + # pyre-ignore[9] + table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]] | None = ( + ec._table_name_to_quantized_weights + if hasattr(ec, "_table_name_to_quantized_weights") + else None + ) + device = _get_device(ec) + return cls( + embedding_configs, + device=device, + output_dtype=qconfig.activation().dtype, + table_name_to_quantized_weights=table_name_to_quantized_weights, + register_tbes=getattr(module, MODULE_ATTR_REGISTER_TBES_BOOL, False), + 
quant_state_dict_split_scale_bias=getattr( + ec, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False + ), + row_alignment=getattr( + ec, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT + ), + managed_collision_collection=mc_ec._managed_collision_collection, + return_remapped_features=mc_ec._return_remapped_features, + cache_features_order=getattr(ec, MODULE_ATTR_CACHE_FEATURES_ORDER, False), + ) diff --git a/torchrec/quant/tests/test_embedding_modules.py b/torchrec/quant/tests/test_embedding_modules.py index 7449cb441..e04c0bf11 100644 --- a/torchrec/quant/tests/test_embedding_modules.py +++ b/torchrec/quant/tests/test_embedding_modules.py @@ -5,26 +5,44 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest -from typing import List +from dataclasses import replace +from typing import Dict, List, Optional, Type import hypothesis.strategies as st + import torch from hypothesis import given, settings, Verbosity +from torchrec import inference as trec_infer +from torchrec.distributed.quant_embedding_kernel import _unwrap_kjt, _unwrap_kjt_for_cpu from torchrec.modules.embedding_configs import ( DataType, EmbeddingBagConfig, EmbeddingConfig, + PoolingType, + QuantConfig, ) from torchrec.modules.embedding_modules import ( EmbeddingBagCollection, EmbeddingCollection, ) from torchrec.quant.embedding_modules import ( + _fx_trec_unwrap_kjt, + _get_batching_hinted_output, + _get_unflattened_lengths, EmbeddingBagCollection as QuantEmbeddingBagCollection, EmbeddingCollection as QuantEmbeddingCollection, + MODULE_ATTR_USE_UNFLATTENED_LENGTHS_FOR_BATCHING, + quant_prep_enable_quant_state_dict_split_scale_bias, +) +from torchrec.sparse.jagged_tensor import ( + ComputeKJTToJTDict, + JaggedTensor, + KeyedJaggedTensor, + KeyedTensor, ) -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor class EmbeddingBagCollectionTest(unittest.TestCase): @@ -34,8 +52,9 @@ def _asserting_same_embeddings( pooled_embeddings_2: KeyedTensor, atol: float = 1e-08, ) -> None: - - self.assertEqual(pooled_embeddings_1.keys(), pooled_embeddings_2.keys()) + self.assertEqual( + set(pooled_embeddings_1.keys()), set(pooled_embeddings_2.keys()) + ) for key in pooled_embeddings_1.keys(): self.assertEqual( pooled_embeddings_1[key].shape, pooled_embeddings_2[key].shape @@ -54,31 +73,50 @@ def _test_ebc( features: KeyedJaggedTensor, quant_type: torch.dtype = torch.qint8, output_type: torch.dtype = torch.float, + quant_state_dict_split_scale_bias: bool = False, + per_table_weight_dtype: Optional[Dict[str, torch.dtype]] = None, ) -> None: ebc = EmbeddingBagCollection(tables=tables) + if quant_state_dict_split_scale_bias: + quant_prep_enable_quant_state_dict_split_scale_bias(ebc) embeddings = ebc(features) # test forward - # pyre-ignore [16] - ebc.qconfig = torch.quantization.QConfig( - activation=torch.quantization.PlaceholderObserver.with_args( - dtype=output_type - ), - weight=torch.quantization.PlaceholderObserver.with_args(dtype=quant_type), - ) + if not per_table_weight_dtype: + # pyre-fixme[16]: `EmbeddingBagCollection` has no attribute `qconfig`. 
+ ebc.qconfig = torch.quantization.QConfig( + activation=torch.quantization.PlaceholderObserver.with_args( + dtype=output_type + ), + weight=torch.quantization.PlaceholderObserver.with_args( + dtype=quant_type + ), + ) + else: + ebc.qconfig = QuantConfig( + activation=torch.quantization.PlaceholderObserver.with_args( + dtype=output_type + ), + weight=torch.quantization.PlaceholderObserver.with_args( + dtype=quant_type + ), + per_table_weight_dtype=per_table_weight_dtype, + ) qebc = QuantEmbeddingBagCollection.from_float(ebc) quantized_embeddings = qebc(features) self.assertEqual(quantized_embeddings.values().dtype, output_type) - self._asserting_same_embeddings(embeddings, quantized_embeddings, atol=1.0) + self._asserting_same_embeddings(embeddings, quantized_embeddings, atol=0.1) # test state dict state_dict = ebc.state_dict() quantized_state_dict = qebc.state_dict() - self.assertEqual(state_dict.keys(), quantized_state_dict.keys()) + self.assertTrue( + set(state_dict.keys()).issubset(set(quantized_state_dict.keys())) + ) # pyre-fixme[56] @given( @@ -101,6 +139,13 @@ def _test_ebc( ] ), permute_order=st.booleans(), + quant_state_dict_split_scale_bias=st.booleans(), + per_table_weight_dtype=st.sampled_from( + [ + {"t1": torch.quint4x2, "t2": torch.qint8}, + {"t1": torch.qint8, "t2": torch.quint4x2}, + ] + ), ) @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) def test_ebc( @@ -109,6 +154,8 @@ def test_ebc( quant_type: torch.dtype, output_type: torch.dtype, permute_order: bool, + quant_state_dict_split_scale_bias: bool, + per_table_weight_dtype: Dict[str, torch.dtype], ) -> None: eb1_config = EmbeddingBagConfig( name="t1", @@ -117,27 +164,62 @@ def test_ebc( feature_names=["f1"], data_type=data_type, ) - eb2_config = EmbeddingBagConfig( - name="t2", - embedding_dim=16, - num_embeddings=10, - feature_names=["f2"], - data_type=data_type, + eb1_mean_config = replace( + eb1_config, + name="t1_mean", + pooling=PoolingType.MEAN, + embedding_dim=32, ) + eb2_config = replace(eb1_config, name="t2", feature_names=["f2"]) features = ( KeyedJaggedTensor( keys=["f1", "f2"], - values=torch.as_tensor([0, 1]), - lengths=torch.as_tensor([1, 1]), + values=torch.as_tensor([0, 2, 1, 3]), + lengths=torch.as_tensor([1, 1, 2, 0]), ) if not permute_order else KeyedJaggedTensor( keys=["f2", "f1"], - values=torch.as_tensor([1, 0]), - lengths=torch.as_tensor([1, 1]), + values=torch.as_tensor([1, 3, 0, 2]), + lengths=torch.as_tensor([2, 0, 1, 1]), ) ) - self._test_ebc([eb1_config, eb2_config], features, quant_type, output_type) + # The key for grouping tables is (pooling, data_type). Test having a different + # key value in the middle. 
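As the test comment above notes, the quantized EmbeddingBagCollection groups tables by the (pooling, data_type) pair; a small sketch of that key, with the same three table names the test uses:

from collections import defaultdict

# The middle table (MEAN pooling) breaks the run of identical keys, which is
# exactly the case the test exercises.
table_keys = [
    ("t1", ("SUM", "INT8")),
    ("t1_mean", ("MEAN", "INT8")),
    ("t2", ("SUM", "INT8")),
]
groups = defaultdict(list)
for name, key in table_keys:
    groups[key].append(name)
print(dict(groups))  # {('SUM', 'INT8'): ['t1', 't2'], ('MEAN', 'INT8'): ['t1_mean']}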
+ self._test_ebc( + [eb1_config, eb1_mean_config, eb2_config], + features, + quant_type, + output_type, + quant_state_dict_split_scale_bias, + ) + + self._test_ebc( + [eb1_config, eb1_mean_config, eb2_config], + features, + quant_type, + output_type, + quant_state_dict_split_scale_bias, + per_table_weight_dtype, + ) + + def test_create_on_meta_device_without_providing_weights(self) -> None: + emb_bag = EmbeddingBagConfig( + name="t1", + embedding_dim=16, + num_embeddings=10, + feature_names=["f1"], + ) + QuantEmbeddingBagCollection( + [emb_bag], is_weighted=False, device=torch.device("meta") + ) + emb = EmbeddingConfig( + name="t1", + embedding_dim=16, + num_embeddings=10, + feature_names=["f1"], + ) + QuantEmbeddingCollection([emb], device=torch.device("meta")) def test_shared_tables(self) -> None: eb_config = EmbeddingBagConfig( @@ -225,7 +307,7 @@ def test_save_load_state_dict( ebc = EmbeddingBagCollection(tables=tables) # test forward - # pyre-ignore [16] + # pyre-fixme[16]: `EmbeddingBagCollection` has no attribute `qconfig`. ebc.qconfig = torch.quantization.QConfig( activation=torch.quantization.PlaceholderObserver.with_args( dtype=output_type @@ -324,9 +406,8 @@ def forward(self, kjt: KeyedJaggedTensor) -> torch.Tensor: tables = [eb1_config, eb2_config] ebc = EmbeddingBagCollection(tables=tables) - # test forward - # pyre-ignore [16] + # pyre-fixme[16]: `EmbeddingBagCollection` has no attribute `qconfig`. ebc.qconfig = torch.quantization.QConfig( activation=torch.quantization.PlaceholderObserver.with_args( dtype=output_type @@ -340,47 +421,142 @@ def forward(self, kjt: KeyedJaggedTensor) -> torch.Tensor: test_model.ebc = QuantEmbeddingBagCollection.from_float(ebc) state_dict = test_model.state_dict() - self.assertEqual(state_dict.keys(), before_quant_state_dict.keys()) + self.assertTrue( + set(before_quant_state_dict.keys()).issubset(set(state_dict.keys())) + ) test_model.load_state_dict(state_dict) + def test_trace_and_script(self) -> None: + data_type = DataType.FP16 + quant_type = torch.half + output_type = torch.half -class EmbeddingCollectionTest(unittest.TestCase): - def _test_ec( - self, tables: List[EmbeddingConfig], features: KeyedJaggedTensor - ) -> None: - eb = EmbeddingCollection(tables=tables) - - embeddings = eb(features) + eb1_config = EmbeddingBagConfig( + name="t1", + embedding_dim=16, + num_embeddings=10, + feature_names=["f1"], + data_type=data_type, + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=16, + num_embeddings=10, + feature_names=["f1"], + data_type=data_type, + ) - # test forward - # pyre-ignore [16] - eb.qconfig = torch.quantization.QConfig( + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + # pyre-fixme[16]: `EmbeddingBagCollection` has no attribute `qconfig`. 
+ ebc.qconfig = torch.quantization.QConfig( activation=torch.quantization.PlaceholderObserver.with_args( - dtype=torch.qint8, + dtype=output_type ), - weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8), + weight=torch.quantization.PlaceholderObserver.with_args(dtype=quant_type), ) - qeb = QuantEmbeddingCollection.from_float(eb) - quantized_embeddings = qeb(features) + qebc = QuantEmbeddingBagCollection.from_float(ebc) + + from torchrec.fx import symbolic_trace + + gm = symbolic_trace(qebc, leaf_modules=[ComputeKJTToJTDict.__name__]) - self.assertEqual(embeddings.keys(), quantized_embeddings.keys()) + non_placeholder_nodes = [ + node for node in gm.graph.nodes if node.op != "placeholder" + ] + self.assertTrue( + len(non_placeholder_nodes) > 0, "Graph must have non-placeholder nodes" + ) + self.assertEqual( + non_placeholder_nodes[0].op, + "call_function", + f"First non-placeholder node must be call_function, got {non_placeholder_nodes[0].op} instead", + ) + self.assertEqual( + non_placeholder_nodes[0].name, + "_get_kjt_keys", + f"First non-placeholder node must be '_get_kjt_keys', got {non_placeholder_nodes[0].name} instead", + ) + + features = KeyedJaggedTensor( + keys=["f1", "f2"], + values=torch.as_tensor([0, 1]), + lengths=torch.as_tensor([1, 1]), + ) + + original_out = qebc(features) + traced_out = gm(features) + + scripted_module = torch.jit.script(gm) + scripted_out = scripted_module(features) + + self.assertEqual(original_out.keys(), traced_out.keys()) + torch.testing.assert_close(original_out.values(), traced_out.values()) + self.assertEqual(original_out.offset_per_key(), traced_out.offset_per_key()) + + self.assertEqual(original_out.keys(), scripted_out.keys()) + torch.testing.assert_close(original_out.values(), scripted_out.values()) + self.assertEqual(original_out.offset_per_key(), scripted_out.offset_per_key()) + + +class EmbeddingCollectionTest(unittest.TestCase): + def _comp_ec_output( + self, + embeddings: Dict[str, JaggedTensor], + transformed_graph_embeddings: Dict[str, JaggedTensor], + atol: int = 1, + ) -> None: + self.assertEqual(embeddings.keys(), transformed_graph_embeddings.keys()) for key in embeddings.keys(): self.assertEqual( embeddings[key].values().size(), - quantized_embeddings[key].values().size(), + transformed_graph_embeddings[key].values().size(), ) self.assertTrue( torch.allclose( embeddings[key].values().cpu().float(), - quantized_embeddings[key].values().cpu().float(), - atol=1, + transformed_graph_embeddings[key].values().cpu().float(), + atol=atol, ) ) + def _test_ec( + self, + tables: List[EmbeddingConfig], + features: KeyedJaggedTensor, + quant_type: torch.dtype = torch.qint8, + output_type: torch.dtype = torch.float, + quant_state_dict_split_scale_bias: bool = False, + ) -> None: + ec = EmbeddingCollection(tables=tables) + if quant_state_dict_split_scale_bias: + quant_prep_enable_quant_state_dict_split_scale_bias(ec) + + embeddings = ec(features) + + # test forward + # pyre-fixme[16]: `EmbeddingCollection` has no attribute `qconfig`. 
+ ec.qconfig = QuantConfig( + activation=torch.quantization.PlaceholderObserver.with_args( + dtype=output_type + ), + weight=torch.quantization.PlaceholderObserver.with_args(dtype=quant_type), + per_table_weight_dtype={ + x.name: torch.quint4x2 if x.data_type == DataType.INT4 else torch.qint8 + for x in ec._embedding_configs + }, + ) + + qec = QuantEmbeddingCollection.from_float(ec) + quantized_embeddings = qec(features) + self.assertEqual( + list(quantized_embeddings.values())[0].values().dtype, output_type + ) + self._comp_ec_output(embeddings, quantized_embeddings) + # test state dict - state_dict = eb.state_dict() - quantized_state_dict = qeb.state_dict() + state_dict = ec.state_dict() + quantized_state_dict = ec.state_dict() self.assertEqual(state_dict.keys(), quantized_state_dict.keys()) # pyre-fixme[56] @@ -391,29 +567,118 @@ def _test_ec( DataType.INT8, ] ), + quant_type=st.sampled_from( + [ + torch.half, + torch.qint8, + ] + ), + output_type=st.sampled_from( + [ + torch.half, + torch.float, + ] + ), + quant_state_dict_split_scale_bias=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) - def test_ec(self, data_type: DataType) -> None: + def test_ec( + self, + data_type: DataType, + quant_type: torch.dtype, + output_type: torch.dtype, + quant_state_dict_split_scale_bias: bool, + ) -> None: eb1_config = EmbeddingConfig( name="t1", embedding_dim=16, num_embeddings=10, - feature_names=["f1"], + feature_names=["f1", "f2"], data_type=data_type, ) eb2_config = EmbeddingConfig( name="t2", embedding_dim=16, num_embeddings=10, - feature_names=["f2"], + feature_names=["f3", "f4"], data_type=data_type, ) + eb3_config = EmbeddingConfig( + name="t3", + embedding_dim=16, + num_embeddings=10, + feature_names=["f5", "f6"], + data_type=DataType.INT4, + ) features = KeyedJaggedTensor( - keys=["f1", "f2"], - values=torch.as_tensor([0, 1]), - lengths=torch.as_tensor([1, 1]), + keys=["f1", "f2", "f3", "f4", "f5", "f6"], + values=torch.as_tensor( + [ + 5, + 1, + 0, + 0, + 4, + 3, + 4, + 9, + 2, + 2, + 3, + 3, + 1, + 5, + 0, + 7, + 5, + 0, + 9, + 9, + 3, + 5, + 6, + 6, + 9, + 3, + 7, + 8, + 7, + 7, + 9, + 1, + 2, + 6, + 7, + 6, + 1, + 8, + 3, + 8, + 1, + 9, + 7, + 7, + 9, + 1, + 2, + 6, + 7, + 6, + 1, + 8, + 3, + 8, + 1, + 9, + ] + ), + lengths=torch.as_tensor([9, 12, 9, 12, 5, 9]), + ) + self._test_ec( + tables=[eb3_config, eb1_config, eb2_config], + features=features, + quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, ) - self._test_ec([eb1_config, eb2_config], features) def test_shared_tables(self) -> None: eb_config = EmbeddingConfig( @@ -439,3 +704,266 @@ def test_shared_features(self) -> None: lengths=torch.as_tensor([1, 1]), ) self._test_ec([eb1_config, eb2_config], features) + + def test_different_quantization_dtype_per_ec_table(self) -> None: + class TestModule(torch.nn.Module): + def __init__(self, m: torch.nn.Module) -> None: + super().__init__() + self.m = m + + eb1_config = EmbeddingConfig( + name="t1", embedding_dim=16, num_embeddings=10, feature_names=["f1"] + ) + eb2_config = EmbeddingConfig( + name="t2", embedding_dim=16, num_embeddings=10, feature_names=["f1"] + ) + ec = EmbeddingCollection(tables=[eb1_config, eb2_config]) + model = TestModule(ec) + qconfig_spec_keys: List[Type[torch.nn.Module]] = [EmbeddingCollection] + quant_mapping: Dict[Type[torch.nn.Module], Type[torch.nn.Module]] = { + EmbeddingCollection: QuantEmbeddingCollection + } + trec_infer.modules.quantize_embeddings( + model, + dtype=torch.int8, + 
additional_qconfig_spec_keys=qconfig_spec_keys, + additional_mapping=quant_mapping, + inplace=True, + per_table_weight_dtype={"t1": torch.float16}, + ) + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. + configs = model.m.embedding_configs() + self.assertEqual(len(configs), 2) + self.assertNotEqual(configs[0].name, configs[1].name) + for config in configs: + if config.name == "t1": + self.assertEqual(config.data_type, DataType.FP16) + else: + self.assertEqual(config.name, "t2") + self.assertEqual(config.data_type, DataType.INT8) + + def test_different_quantization_dtype_per_ebc_table(self) -> None: + class TestModule(torch.nn.Module): + def __init__(self, m: torch.nn.Module) -> None: + super().__init__() + self.m = m + + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=16, num_embeddings=10, feature_names=["f1"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", embedding_dim=16, num_embeddings=10, feature_names=["f1"] + ) + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + model = TestModule(ebc) + trec_infer.modules.quantize_embeddings( + model, + dtype=torch.int8, + inplace=True, + per_table_weight_dtype={"t1": torch.float16}, + ) + # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. + configs = model.m.embedding_bag_configs() + self.assertEqual(len(configs), 2) + self.assertNotEqual(configs[0].name, configs[1].name) + for config in configs: + if config.name == "t1": + self.assertEqual(config.data_type, DataType.FP16) + else: + self.assertEqual(config.name, "t2") + self.assertEqual(config.data_type, DataType.INT8) + + def test_trace_and_script(self) -> None: + data_type = DataType.FP16 + quant_type = torch.half + output_type = torch.half + + ec1_config = EmbeddingConfig( + name="t1", + embedding_dim=16, + num_embeddings=10, + feature_names=["f1", "f2"], + data_type=data_type, + ) + ec2_config = EmbeddingConfig( + name="t2", + embedding_dim=16, + num_embeddings=10, + feature_names=["f3", "f4"], + data_type=data_type, + ) + + ec = EmbeddingCollection(tables=[ec1_config, ec2_config]) + # pyre-fixme[16]: `EmbeddingCollection` has no attribute `qconfig`. 
+ ec.qconfig = torch.quantization.QConfig( + activation=torch.quantization.PlaceholderObserver.with_args( + dtype=output_type + ), + weight=torch.quantization.PlaceholderObserver.with_args(dtype=quant_type), + ) + + qec = QuantEmbeddingCollection.from_float(ec) + + from torchrec.fx import symbolic_trace + + gm = symbolic_trace(qec) + + features = KeyedJaggedTensor( + keys=["f1", "f2", "f3", "f4"], + values=torch.as_tensor([0, 1, 2, 3, 4, 5, 6, 7]), + lengths=torch.as_tensor([1, 2, 3, 2]), + ) + + original_out = qec(features) + traced_out = gm(features) + + scripted_module = torch.jit.script(gm) + scripted_out = scripted_module(features) + self._comp_ec_output(original_out, traced_out, atol=0) + self._comp_ec_output(original_out, scripted_out, atol=0) + + @unittest.skipIf( + torch.cuda.device_count() < 1, + "Not enough GPUs available", + ) + # pyre-fixme[56] + @given( + offsets_dtype=st.sampled_from( + [ + torch.int32, + torch.int64, + ] + ), + indices_dtype=st.sampled_from( + [ + torch.int32, + torch.int64, + ] + ), + device=st.sampled_from( + [ + torch.device("cpu"), + torch.device("cuda"), + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_fx_unwrap_unsharded_vs_sharded_in_sync( + self, + offsets_dtype: torch.dtype, + indices_dtype: torch.dtype, + device: torch.device, + ) -> None: + features = KeyedJaggedTensor( + keys=["f1", "f2", "f3", "f4"], + values=torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7], dtype=indices_dtype, device=device + ), + offsets=torch.tensor([0, 2, 5, 7, 8], dtype=offsets_dtype, device=device), + ) + + indices, offsets = _fx_trec_unwrap_kjt(features) + self.assertEqual(indices.dtype, offsets.dtype) + if device.type == "cpu": + sharded_indices, sharded_offsets, _ = _unwrap_kjt_for_cpu( + features, weighted=False + ) + self.assertEqual(sharded_indices.dtype, indices_dtype) + else: # cuda + sharded_indices, sharded_offsets, _ = _unwrap_kjt(features) + self.assertEqual(sharded_indices.dtype, torch.int32) # only option! + + self.assertEqual(indices.dtype, sharded_indices.dtype) + self.assertEqual(offsets.dtype, sharded_offsets.dtype) + + def test_using_flattened_or_unflattened_length_rebatching(self) -> None: + data_type = DataType.FP16 + quant_type = torch.half + output_type = torch.half + + ec1_config = EmbeddingConfig( + name="t1", + embedding_dim=16, + num_embeddings=10, + feature_names=["f1", "f2"], + data_type=data_type, + ) + ec2_config = EmbeddingConfig( + name="t2", + embedding_dim=16, + num_embeddings=10, + feature_names=["f3", "f4"], + data_type=data_type, + ) + + ec = EmbeddingCollection(tables=[ec1_config, ec2_config]) + # pyre-fixme[16]: `EmbeddingCollection` has no attribute `qconfig`. 
+ ec.qconfig = torch.quantization.QConfig( + activation=torch.quantization.PlaceholderObserver.with_args( + dtype=output_type + ), + weight=torch.quantization.PlaceholderObserver.with_args(dtype=quant_type), + ) + + qec = QuantEmbeddingCollection.from_float(ec) + + import copy + + from torchrec.fx import symbolic_trace + + # test using flattened lengths for rebatching (default) + + gm = symbolic_trace(copy.deepcopy(qec)) + + found_get_unflattened_lengths_func = False + + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and node.name == _get_unflattened_lengths.__name__ + ): + found_get_unflattened_lengths_func = True + for user in node.users: + if ( + user.op == "call_function" + and user.name == _get_batching_hinted_output.__name__ + ): + self.assertTrue( + False, + "Should not call _get_batching_hinted_output after _get_unflattened_lengths", + ) + + self.assertTrue( + found_get_unflattened_lengths_func, + "_get_unflattened_lengths must exist in the graph", + ) + + # test using unflattened lengths for rebatching + + setattr(qec, MODULE_ATTR_USE_UNFLATTENED_LENGTHS_FOR_BATCHING, True) + + gm = symbolic_trace(qec) + + found_get_unflattened_lengths_func = False + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and node.name == _get_unflattened_lengths.__name__ + ): + found_get_unflattened_lengths_func = True + found_get_batching_hinted_output_func = False + for user in node.users: + if ( + user.op == "call_function" + and user.name == _get_batching_hinted_output.__name__ + ): + found_get_batching_hinted_output_func = True + self.assertTrue( + found_get_batching_hinted_output_func, + "Should call _get_batching_hinted_output after _get_unflattened_lengths", + ) + + self.assertTrue( + found_get_unflattened_lengths_func, + "_get_unflattened_lengths must exist in the graph", + ) diff --git a/torchrec/quant/tests/test_quant_utils.py b/torchrec/quant/tests/test_quant_utils.py index 59c823157..f07d67953 100644 --- a/torchrec/quant/tests/test_quant_utils.py +++ b/torchrec/quant/tests/test_quant_utils.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest from typing import List @@ -44,7 +46,7 @@ def _test_meta_to_cpu( ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta")) # test forward - # pyre-ignore [16] + # pyre-fixme[16]: `EmbeddingBagCollection` has no attribute `qconfig`. ebc.qconfig = torch.quantization.QConfig( activation=torch.quantization.PlaceholderObserver.with_args( dtype=output_type diff --git a/torchrec/quant/tests/test_tensor_types.py b/torchrec/quant/tests/test_tensor_types.py new file mode 100644 index 000000000..dc4dd3a4f --- /dev/null +++ b/torchrec/quant/tests/test_tensor_types.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
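The new test below exercises UInt4Tensor and UInt2Tensor, which present packed uint8 storage as 4-bit / 2-bit elements. A small sketch of the underlying packing arithmetic; the exact nibble ordering these classes use is not shown in this diff:

import torch

row = torch.tensor([0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], dtype=torch.uint8)

# Each uint8 byte carries two 4-bit values (or four 2-bit values), so a row of
# 8 bytes corresponds to 16 int4 / 32 int2 elements, consistent with the slice
# sizes asserted in the test below.
low_nibbles = row & 0x0F
high_nibbles = (row >> 4) & 0x0F
assert torch.equal((high_nibbles << 4) | low_nibbles, row)
print(low_nibbles.tolist(), high_nibbles.tolist())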
+ +# pyre-strict + +#!/usr/bin/env python3 + +import unittest + +import torch + +from torchrec.tensor_types import UInt2Tensor, UInt4Tensor + + +class QuantUtilsTest(unittest.TestCase): + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "Not enough GPUs available", + ) + def test_uint42_tensor(self) -> None: + t_u8 = torch.tensor( + [ + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + ], + dtype=torch.uint8, + ) + t_u4 = UInt4Tensor(t_u8) + t_u4.detach() + + t_u4.to(torch.device("cuda")) + assert torch.equal(t_u4.view(torch.uint8), t_u8) + t_u2 = UInt2Tensor(t_u8) + t_u2.to(torch.device("cuda")) + assert torch.equal(t_u2.view(torch.uint8), t_u8) + + for t in [t_u4[:, :8], t_u4[:, 8:]]: + assert t.size(1) == 8 + t_u4[:, :8].copy_(t_u4[:, 8:]) + + for t in [t_u2[:, 4:8], t_u2[:, 8:12]]: + assert t.size(1) == 4 + + t_u2[:, 4:8].copy_(t_u2[:, 8:12]) diff --git a/torchrec/quant/utils.py b/torchrec/quant/utils.py index 5497b9c7c..9161f84d5 100644 --- a/torchrec/quant/utils.py +++ b/torchrec/quant/utils.py @@ -5,31 +5,109 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + + +from typing import Optional, Union import torch from torch import nn from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.quant_embeddingbag import ShardedQuantEmbeddingBagCollection from torchrec.quant.embedding_modules import ( EmbeddingBagCollection as QuantEmbeddingBagCollection, + EmbeddingCollection as QuantEmbeddingCollection, ) -def meta_to_cpu_placement(module: DistributedModelParallel) -> None: - assert hasattr(module, "_dmp_wrapped_module") - _meta_to_cpu_placement(module.module, module, "_dmp_wrapped_module") +def populate_fx_names( + quant_ebc: Union[QuantEmbeddingBagCollection, ShardedQuantEmbeddingBagCollection] +) -> None: + """ + Assigns fx path to non registered lookup modules. This allows the Torchrec tracer to fallback to + emb_module._fx_path for table batched embeddings. + """ + if isinstance(quant_ebc, QuantEmbeddingBagCollection): + for emb_configs, emb_module in zip( + quant_ebc._key_to_tables, quant_ebc._emb_modules + ): + table_names = [] + for config in emb_configs: + table_names.append(config.name) + joined_table_names = ",".join(table_names) + # pyre-fixme[16]: `Module` has no attribute `_fx_path`. + emb_module._fx_path = f"emb_module.{joined_table_names}" + elif isinstance(quant_ebc, ShardedQuantEmbeddingBagCollection): + for i, (emb_module, emb_dist_module) in enumerate( + zip(quant_ebc._lookups, quant_ebc._output_dists) + ): + embedding_fx_path = f"embedding_lookup.sharding_{i}" + emb_module._fx_path = embedding_fx_path + emb_dist_module._fx_path = f"embedding_dist.{i}" + # pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got + # `Union[Module, Tensor]`. 
+ for rank, rank_module in enumerate(emb_module._embedding_lookups_per_rank): + rank_fx_path = f"{embedding_fx_path}.rank_{rank}" + rank_module._fx_path = rank_fx_path + for group, group_module in enumerate(rank_module._emb_modules): + group_module._fx_path = f"{rank_fx_path}.group_{group}" + group_module._emb_module._fx_path = ( + f"{rank_fx_path}.group_{group}.tbe" + ) + + +def recursive_populate_fx_names(module: nn.Module) -> None: + if isinstance(module, QuantEmbeddingBagCollection) or isinstance( + module, ShardedQuantEmbeddingBagCollection + ): + populate_fx_names(module) + return + for submodule in module.children(): + recursive_populate_fx_names(submodule) + + +def meta_to_cpu_placement(module: torch.nn.Module) -> None: + if hasattr(module, "_dmp_wrapped_module"): + # for placement update of dmp module, we need to fetch .module (read access) and write + # to .dmp_wrapped_module (write access) + assert type(module) == DistributedModelParallel + _meta_to_cpu_placement(module.module, module, "_dmp_wrapped_module") + else: + # shard module case + _meta_to_cpu_placement(module, module) def _meta_to_cpu_placement( - module: nn.Module, root_module: nn.Module, name: str + module: nn.Module, root_module: nn.Module, name: Optional[str] = None ) -> None: - if isinstance(module, QuantEmbeddingBagCollection) and module.device.type == "meta": + if ( + name is not None + and isinstance(module, QuantEmbeddingBagCollection) + and module.device.type == "meta" + ): qebc_cpu = QuantEmbeddingBagCollection( tables=module.embedding_bag_configs(), is_weighted=module.is_weighted(), device=torch.device("cpu"), output_dtype=module.output_dtype(), + register_tbes=module.register_tbes, + row_alignment=module.row_alignment, ) setattr(root_module, name, qebc_cpu) - return - for name, submodule in module.named_children(): - _meta_to_cpu_placement(submodule, module, name) + elif ( + name is not None + and isinstance(module, QuantEmbeddingCollection) + and module.device.type == "meta" + ): + qec_cpu = QuantEmbeddingCollection( + tables=module.embedding_configs(), + device=torch.device("cpu"), + need_indices=module.need_indices(), + output_dtype=module.output_dtype(), + register_tbes=module.register_tbes, + row_alignment=module.row_alignment, + ) + setattr(root_module, name, qec_cpu) + else: + for name, submodule in module.named_children(): + _meta_to_cpu_placement(submodule, module, name) diff --git a/torchrec/schema/__init__.py b/torchrec/schema/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/schema/api_tests/test_embedding_config_schema.py b/torchrec/schema/api_tests/test_embedding_config_schema.py new file mode 100644 index 000000000..c0ca41a5b --- /dev/null +++ b/torchrec/schema/api_tests/test_embedding_config_schema.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
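The schema tests that begin here all follow one pattern: a frozen "stable" signature is compared against the live API with is_signature_compatible. A minimal sketch of that pattern with plain functions; the exact compatibility rules live in torchrec/schema/utils.py and are not included in this diff:

import inspect
from torchrec.schema.utils import is_signature_compatible

def stable_fn(a: int, b: str = "x") -> None:
    ...

def current_fn(a: int, b: str = "x", c: float = 0.0) -> None:
    ...

# The stable signature is the contract; the live signature is checked against it.
# Whether a trailing defaulted argument like `c` is allowed depends on the rules
# inside is_signature_compatible, which this diff does not show.
print(
    is_signature_compatible(
        inspect.signature(stable_fn), inspect.signature(current_fn)
    )
)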
+ +# pyre-strict + +import inspect +import unittest +from dataclasses import dataclass, field +from typing import Callable, List, Optional + +import torch +from torchrec.modules.embedding_configs import ( + DataType, + EmbeddingBagConfig, + EmbeddingConfig, + PoolingType, +) + +from torchrec.schema.utils import is_signature_compatible + + +@dataclass +class StableEmbeddingBagConfig: + num_embeddings: int + embedding_dim: int + name: str = "" + data_type: DataType = DataType.FP32 + feature_names: List[str] = field(default_factory=list) + weight_init_max: Optional[float] = None + weight_init_min: Optional[float] = None + num_embeddings_post_pruning: Optional[int] = None + + init_fn: Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]] = None + # when the position_weighted feature is in this table config, + # enable this flag to support rw_sharding + need_pos: bool = False + input_dim: Optional[int] = None + pooling: PoolingType = PoolingType.SUM + + +@dataclass +class StableEmbeddingConfig: + num_embeddings: int + embedding_dim: int + name: str = "" + data_type: DataType = DataType.FP32 + feature_names: List[str] = field(default_factory=list) + weight_init_max: Optional[float] = None + weight_init_min: Optional[float] = None + num_embeddings_post_pruning: Optional[int] = None + + init_fn: Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]] = None + # when the position_weighted feature is in this table config, + # enable this flag to support rw_sharding + need_pos: bool = False + input_dim: Optional[int] = None + + +class TestEmbeddingConfigSchema(unittest.TestCase): + def test_embedding_bag_config(self) -> None: + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingBagConfig.__init__), + inspect.signature(EmbeddingBagConfig.__init__), + ) + ) + + def test_embedding_config(self) -> None: + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingConfig.__init__), + inspect.signature(EmbeddingConfig.__init__), + ) + ) diff --git a/torchrec/schema/api_tests/test_embedding_module_schema.py b/torchrec/schema/api_tests/test_embedding_module_schema.py new file mode 100644 index 000000000..528ac082e --- /dev/null +++ b/torchrec/schema/api_tests/test_embedding_module_schema.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import inspect +import unittest +from typing import Dict, List, Optional + +import torch +from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) + +from torchrec.schema.utils import is_signature_compatible +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor + + +class StableEmbeddingBagCollectionInterface: + """ + Stable Interface for `EmbeddingBagCollection`. 
+ """ + + def __init__( + self, + tables: List[EmbeddingBagConfig], + is_weighted: bool = False, + device: Optional[torch.device] = None, + ) -> None: + pass + + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedTensor: + return KeyedTensor( + keys=[], + length_per_key=[], + values=torch.empty(0), + ) + + def embedding_bag_configs( + self, + ) -> List[EmbeddingBagConfig]: + return [] + + def is_weighted(self) -> bool: + return False + + +class StableEmbeddingCollectionInterface: + """ + Stable Interface for `EmbeddingBagCollection`. + """ + + def __init__( + self, + tables: List[EmbeddingConfig], + device: Optional[torch.device] = None, + need_indices: bool = False, + ) -> None: + return + + def forward( + self, + features: KeyedJaggedTensor, + ) -> Dict[str, JaggedTensor]: + return {} + + def embedding_configs( + self, + ) -> List[EmbeddingConfig]: + return [] + + def need_indices(self) -> bool: + return False + + def embedding_dim(self) -> int: + return 0 + + def embedding_names_by_table(self) -> List[List[str]]: + return [] + + +class TestEmbeddingModuleSchema(unittest.TestCase): + def test_embedding_bag_collection(self) -> None: + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingBagCollectionInterface.__init__), + inspect.signature(EmbeddingBagCollection.__init__), + ) + ) + + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingBagCollectionInterface.forward), + inspect.signature(EmbeddingBagCollection.forward), + ) + ) + + self.assertTrue( + is_signature_compatible( + inspect.signature( + StableEmbeddingBagCollectionInterface.embedding_bag_configs + ), + inspect.signature(EmbeddingBagCollection.embedding_bag_configs), + ) + ) + + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingBagCollectionInterface.is_weighted), + inspect.signature(EmbeddingBagCollection.is_weighted), + ) + ) + + def test_embedding_collection(self) -> None: + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingCollectionInterface.__init__), + inspect.signature(EmbeddingCollection.__init__), + ) + ) + + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingCollectionInterface.forward), + inspect.signature(EmbeddingCollection.forward), + ) + ) + + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingCollectionInterface.embedding_configs), + inspect.signature(EmbeddingCollection.embedding_configs), + ) + ) + + self.assertTrue( + is_signature_compatible( + inspect.signature(StableEmbeddingCollectionInterface.embedding_dim), + inspect.signature(EmbeddingCollection.embedding_dim), + ) + ) + + self.assertTrue( + is_signature_compatible( + inspect.signature( + StableEmbeddingCollectionInterface.embedding_names_by_table + ), + inspect.signature(EmbeddingCollection.embedding_names_by_table), + ) + ) diff --git a/torchrec/schema/api_tests/test_inference_schema.py b/torchrec/schema/api_tests/test_inference_schema.py new file mode 100644 index 000000000..abbcc2039 --- /dev/null +++ b/torchrec/schema/api_tests/test_inference_schema.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +#!/usr/bin/env python3 + +import inspect +import unittest +from typing import Any, cast, Dict, List, Optional, Tuple, Type + +import torch +from torchrec.distributed.fused_params import ( + FUSED_PARAM_BOUNDS_CHECK_MODE, + FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, + FUSED_PARAM_REGISTER_TBE_BOOL, +) +from torchrec.distributed.planner import ParameterConstraints +from torchrec.distributed.quant_embedding import ( + QuantEmbeddingCollection, + QuantEmbeddingCollectionSharder, +) +from torchrec.distributed.quant_embeddingbag import ( + QuantEmbeddingBagCollection, + QuantEmbeddingBagCollectionSharder, + QuantFeatureProcessedEmbeddingBagCollectionSharder, +) +from torchrec.distributed.types import BoundsCheckMode, ModuleSharder, ShardingPlan +from torchrec.inference.modules import ( + DEFAULT_FUSED_PARAMS, + DEFAULT_QUANT_MAPPING, + DEFAULT_QUANTIZATION_DTYPE, + DEFAULT_SHARDERS, + quantize_inference_model, + shard_quant_model, + trim_torch_package_prefix_from_typename, +) +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) +from torchrec.schema.utils import is_signature_compatible + +STABLE_DEFAULT_QUANTIZATION_DTYPE: torch.dtype = torch.int8 + + +STABLE_DEFAULT_FUSED_PARAMS: Dict[str, Any] = { + FUSED_PARAM_REGISTER_TBE_BOOL: True, + FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: True, + FUSED_PARAM_BOUNDS_CHECK_MODE: BoundsCheckMode.NONE, +} + +STABLE_DEFAULT_SHARDERS: List[ModuleSharder[torch.nn.Module]] = [ + cast( + ModuleSharder[torch.nn.Module], + QuantEmbeddingBagCollectionSharder(fused_params=STABLE_DEFAULT_FUSED_PARAMS), + ), + cast( + ModuleSharder[torch.nn.Module], + QuantEmbeddingCollectionSharder(fused_params=STABLE_DEFAULT_FUSED_PARAMS), + ), + cast( + ModuleSharder[torch.nn.Module], + QuantFeatureProcessedEmbeddingBagCollectionSharder( + fused_params=STABLE_DEFAULT_FUSED_PARAMS + ), + ), +] + +STABLE_DEFAULT_QUANT_MAPPING: Dict[str, Type[torch.nn.Module]] = { + trim_torch_package_prefix_from_typename( + torch.typename(EmbeddingBagCollection) + ): QuantEmbeddingBagCollection, + trim_torch_package_prefix_from_typename( + torch.typename(EmbeddingCollection) + ): QuantEmbeddingCollection, +} + + +def stable_quantize_inference_model( + model: torch.nn.Module, + quantization_mapping: Optional[Dict[str, Type[torch.nn.Module]]] = None, + per_table_weight_dtype: Optional[Dict[str, torch.dtype]] = None, + fp_weight_dtype: torch.dtype = STABLE_DEFAULT_QUANTIZATION_DTYPE, + quantization_dtype: torch.dtype = STABLE_DEFAULT_QUANTIZATION_DTYPE, + output_dtype: torch.dtype = torch.float, +) -> torch.nn.Module: + return model + + +def stable_shard_quant_model( + model: torch.nn.Module, + world_size: int = 1, + compute_device: str = "cuda", + sharding_device: str = "meta", + sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None, + device_memory_size: Optional[int] = None, + constraints: Optional[Dict[str, ParameterConstraints]] = None, +) -> Tuple[torch.nn.Module, ShardingPlan]: + return (model, ShardingPlan(plan={})) + + +class TestInferenceSchema(unittest.TestCase): + def test_quantize_inference_model(self) -> None: + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_quantize_inference_model), + inspect.signature(quantize_inference_model), + ) + ) + + def test_shard_quant_model(self) -> None: + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_shard_quant_model), + inspect.signature(shard_quant_model), + ) + ) + + def test_default_mappings(self) -> None: + # check 
that the default mappings are a superset of the stable ones + for ( + name, + module_type, + ) in STABLE_DEFAULT_QUANT_MAPPING.items(): + self.assertTrue(name in DEFAULT_QUANT_MAPPING) + self.assertTrue(DEFAULT_QUANT_MAPPING[name] == module_type) + + # check that the fused params are a superset of the stable ones + for ( + name, + val, + ) in STABLE_DEFAULT_FUSED_PARAMS.items(): + self.assertTrue(name in DEFAULT_FUSED_PARAMS) + self.assertTrue(DEFAULT_FUSED_PARAMS[name] == val) + + # Check default quant type + self.assertTrue(DEFAULT_QUANTIZATION_DTYPE == STABLE_DEFAULT_QUANTIZATION_DTYPE) + + # Check default sharders are a superset of the stable ones + # and check fused_params are also a superset + for sharder in STABLE_DEFAULT_SHARDERS: + found = False + for default_sharder in DEFAULT_SHARDERS: + if isinstance(default_sharder, type(sharder)): + # pyre-ignore[16] + for key in sharder.fused_params.keys(): + self.assertTrue(key in default_sharder.fused_params) + self.assertTrue( + default_sharder.fused_params[key] + == sharder.fused_params[key] + ) + found = True + + self.assertTrue(found) diff --git a/torchrec/schema/api_tests/test_jagged_tensor_schema.py b/torchrec/schema/api_tests/test_jagged_tensor_schema.py new file mode 100644 index 000000000..eacb10d9e --- /dev/null +++ b/torchrec/schema/api_tests/test_jagged_tensor_schema.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import inspect +import unittest +from typing import Dict, List, Optional, Tuple + +import torch +from torchrec.schema.utils import is_signature_compatible +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor + + +class StableJaggedTensor: + def __init__( + self, + values: torch.Tensor, + weights: Optional[torch.Tensor] = None, + lengths: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> None: + pass + + @staticmethod + def empty( + is_weighted: bool = False, + device: Optional[torch.device] = None, + values_dtype: Optional[torch.dtype] = None, + weights_dtype: Optional[torch.dtype] = None, + lengths_dtype: torch.dtype = torch.int32, + ) -> "JaggedTensor": + return JaggedTensor(torch.empty(0)) + + @staticmethod + def from_dense_lengths( + values: torch.Tensor, + lengths: torch.Tensor, + weights: Optional[torch.Tensor] = None, + ) -> "JaggedTensor": + return JaggedTensor(torch.empty(0)) + + @staticmethod + def from_dense( + values: List[torch.Tensor], + weights: Optional[List[torch.Tensor]] = None, + ) -> "JaggedTensor": + return JaggedTensor(torch.empty(0)) + + def to_dense(self) -> List[torch.Tensor]: + return [] + + def to_dense_weights(self) -> Optional[List[torch.Tensor]]: + pass + + def to_padded_dense( + self, + desired_length: Optional[int] = None, + padding_value: float = 0.0, + ) -> torch.Tensor: + return torch.empty(0) + + def to_padded_dense_weights( + self, + desired_length: Optional[int] = None, + padding_value: float = 0.0, + ) -> Optional[torch.Tensor]: + pass + + def device(self) -> torch.device: + return torch.device("cpu") + + def lengths(self) -> torch.Tensor: + return torch.empty(0) + + def lengths_or_none(self) -> Optional[torch.Tensor]: + pass + + def offsets(self) -> torch.Tensor: + return torch.empty(0) + + def offsets_or_none(self) -> Optional[torch.Tensor]: + pass + + def values(self) -> torch.Tensor: + 
return torch.empty(0) + + def weights(self) -> torch.Tensor: + return torch.empty(0) + + def weights_or_none(self) -> Optional[torch.Tensor]: + pass + + def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor": + return JaggedTensor(torch.empty(0)) + + @torch.jit.unused + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + pass + + +class StableKeyedJaggedTensor: + def __init__( + self, + keys: List[str], + values: torch.Tensor, + weights: Optional[torch.Tensor] = None, + lengths: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + stride: Optional[int] = None, + stride_per_key_per_rank: Optional[List[List[int]]] = None, + # Below exposed to ensure torch.script-able + stride_per_key: Optional[List[int]] = None, + length_per_key: Optional[List[int]] = None, + lengths_offset_per_key: Optional[List[int]] = None, + offset_per_key: Optional[List[int]] = None, + index_per_key: Optional[Dict[str, int]] = None, + jt_dict: Optional[Dict[str, JaggedTensor]] = None, + inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, + ) -> None: + pass + + @staticmethod + def from_offsets_sync( + keys: List[str], + values: torch.Tensor, + offsets: torch.Tensor, + weights: Optional[torch.Tensor] = None, + stride: Optional[int] = None, + stride_per_key_per_rank: Optional[List[List[int]]] = None, + inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, + ) -> "KeyedJaggedTensor": + return KeyedJaggedTensor([], torch.empty(0)) + + @staticmethod + def from_lengths_sync( + keys: List[str], + values: torch.Tensor, + lengths: torch.Tensor, + weights: Optional[torch.Tensor] = None, + stride: Optional[int] = None, + stride_per_key_per_rank: Optional[List[List[int]]] = None, + inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, + ) -> "KeyedJaggedTensor": + return KeyedJaggedTensor([], torch.empty(0)) + + @staticmethod + def concat( + kjt_list: List["KeyedJaggedTensor"], + ) -> "KeyedJaggedTensor": + return KeyedJaggedTensor([], torch.empty(0)) + + @staticmethod + def empty( + is_weighted: bool = False, + device: Optional[torch.device] = None, + values_dtype: Optional[torch.dtype] = None, + weights_dtype: Optional[torch.dtype] = None, + lengths_dtype: torch.dtype = torch.int32, + ) -> "KeyedJaggedTensor": + return KeyedJaggedTensor([], torch.empty(0)) + + @staticmethod + def empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor": + return KeyedJaggedTensor([], torch.empty(0)) + + @staticmethod + def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor": + return KeyedJaggedTensor([], torch.empty(0)) + + def sync(self) -> "KeyedJaggedTensor": + return KeyedJaggedTensor([], torch.empty(0)) + + def unsync(self) -> "KeyedJaggedTensor": + return KeyedJaggedTensor([], torch.empty(0)) + + def device(self) -> torch.device: + return torch.device("cpu") + + def lengths(self) -> torch.Tensor: + return torch.empty(0) + + def lengths_or_none(self) -> Optional[torch.Tensor]: + pass + + def offsets(self) -> torch.Tensor: + return torch.empty(0) + + def offsets_or_none(self) -> Optional[torch.Tensor]: + pass + + def keys(self) -> List[str]: + return [] + + def values(self) -> torch.Tensor: + return torch.empty(0) + + def weights(self) -> torch.Tensor: + return torch.empty(0) + + def weights_or_none(self) -> Optional[torch.Tensor]: + pass + + def stride(self) -> int: + return 0 + + def stride_per_key(self) -> List[int]: + return [] + + def stride_per_key_per_rank(self) -> List[List[int]]: + return [] + + def 
variable_stride_per_key(self) -> bool: + return False + + def inverse_indices(self) -> Tuple[List[str], torch.Tensor]: + return ([], torch.empty(0)) + + def inverse_indices_or_none(self) -> Optional[Tuple[List[str], torch.Tensor]]: + pass + + def _key_indices(self) -> Dict[str, int]: + return {} + + def length_per_key(self) -> List[int]: + return [] + + def length_per_key_or_none(self) -> Optional[List[int]]: + pass + + def offset_per_key(self) -> List[int]: + return [] + + def offset_per_key_or_none(self) -> Optional[List[int]]: + pass + + def lengths_offset_per_key(self) -> List[int]: + return [] + + def index_per_key(self) -> Dict[str, int]: + return {} + + def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: + return [] + + def permute( + self, indices: List[int], indices_tensor: Optional[torch.Tensor] = None + ) -> "KeyedJaggedTensor": + return KeyedJaggedTensor([], torch.empty(0)) + + def flatten_lengths(self) -> "KeyedJaggedTensor": + return KeyedJaggedTensor([], torch.empty(0)) + + def __getitem__(self, key: str) -> JaggedTensor: + return JaggedTensor(torch.empty(0)) + + def to_dict(self) -> Dict[str, JaggedTensor]: + return {} + + @torch.jit.unused + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + pass + + def to( + self, + device: torch.device, + non_blocking: bool = False, + dtype: Optional[torch.dtype] = None, + ) -> "KeyedJaggedTensor": + return KeyedJaggedTensor([], torch.empty(0)) + + def pin_memory(self) -> "KeyedJaggedTensor": + return KeyedJaggedTensor([], torch.empty(0)) + + def dist_labels(self) -> List[str]: + return [] + + def dist_splits(self, key_splits: List[int]) -> List[List[int]]: + return [] + + def dist_tensors(self) -> List[torch.Tensor]: + return [] + + @staticmethod + def dist_init( + keys: List[str], + tensors: List[torch.Tensor], + variable_stride_per_key: bool, + num_workers: int, + recat: Optional[torch.Tensor], + stride_per_rank: Optional[List[int]], + stagger: int = 1, + ) -> "KeyedJaggedTensor": + return KeyedJaggedTensor([], torch.empty(0)) + + +class StableKeyedTensor: + def __init__( + self, + keys: List[str], + length_per_key: List[int], + values: torch.Tensor, + key_dim: int = 1, + # Below exposed to ensure torch.script-able + offset_per_key: Optional[List[int]] = None, + index_per_key: Optional[Dict[str, int]] = None, + ) -> None: + pass + + @staticmethod + def from_tensor_list( + keys: List[str], tensors: List[torch.Tensor], key_dim: int = 1, cat_dim: int = 1 + ) -> "KeyedTensor": + return KeyedTensor([], [], torch.empty(0)) + + def keys(self) -> List[str]: + return [] + + def values(self) -> torch.Tensor: + return torch.empty(0) + + def key_dim(self) -> int: + return 0 + + def device(self) -> torch.device: + return torch.device("cpu") + + def offset_per_key(self) -> List[int]: + return [] + + def length_per_key(self) -> List[int]: + return [] + + def _key_indices(self) -> Dict[str, int]: + return {} + + def __getitem__(self, key: str) -> torch.Tensor: + return torch.empty(0) + + def to_dict(self) -> Dict[str, torch.Tensor]: + return {} + + @staticmethod + def regroup( + keyed_tensors: List["KeyedTensor"], groups: List[List[str]] + ) -> List[torch.Tensor]: + return [] + + @staticmethod + def regroup_as_dict( + keyed_tensors: List["KeyedTensor"], groups: List[List[str]], keys: List[str] + ) -> Dict[str, torch.Tensor]: + return {} + + @torch.jit.unused + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + pass + + def to(self, device: torch.device, non_blocking: bool = False) -> 
"KeyedTensor": + return KeyedTensor([], [], torch.empty(0)) + + +class TestJaggedTensorSchema(unittest.TestCase): + def test_kjt(self) -> None: + stable_kjt_funcs = inspect.getmembers( + StableKeyedJaggedTensor, predicate=inspect.isfunction + ) + + for func_name, stable_func in stable_kjt_funcs: + self.assertTrue(getattr(KeyedJaggedTensor, func_name, None) is not None) + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_func), + inspect.signature(getattr(KeyedJaggedTensor, func_name)), + ) + ) + + def test_jt(self) -> None: + stable_jt_funcs = inspect.getmembers( + StableJaggedTensor, predicate=inspect.isfunction + ) + + for func_name, stable_func in stable_jt_funcs: + self.assertTrue(getattr(JaggedTensor, func_name, None) is not None) + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_func), + inspect.signature(getattr(JaggedTensor, func_name)), + ) + ) + + def test_kt(self) -> None: + stable_kt_funcs = inspect.getmembers( + StableKeyedTensor, predicate=inspect.isfunction + ) + + for func_name, stable_func in stable_kt_funcs: + self.assertTrue(getattr(KeyedTensor, func_name, None) is not None) + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_func), + inspect.signature(getattr(KeyedTensor, func_name)), + ) + ) diff --git a/torchrec/schema/api_tests/test_model_parallel_schema.py b/torchrec/schema/api_tests/test_model_parallel_schema.py new file mode 100644 index 000000000..ee5b92512 --- /dev/null +++ b/torchrec/schema/api_tests/test_model_parallel_schema.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import inspect +import unittest +from typing import Any, List, Optional + +import torch +from torch import nn + +from torchrec.distributed.model_parallel import ( + DataParallelWrapper, + DistributedModelParallel, +) +from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan +from torchrec.schema.utils import is_signature_compatible + + +def stable_dmp_init( + # pyre-ignore [2] + self, + module: nn.Module, + env: Optional[ShardingEnv] = None, + device: Optional[torch.device] = None, + plan: Optional[ShardingPlan] = None, + sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None, + init_data_parallel: bool = True, + init_parameters: bool = True, + data_parallel_wrapper: Optional[DataParallelWrapper] = None, +) -> None: + pass + + +# pyre-ignore [3] +def stable_dmp_forward( + # pyre-ignore [2] + self, + # pyre-ignore [2] + *args, + # pyre-ignore [2] + **kwargs, +) -> Any: + pass + + +class TestModelParallelSchema(unittest.TestCase): + def test_dmp_init(self) -> None: + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_dmp_init), + inspect.signature(DistributedModelParallel.__init__), + ) + ) + + def test_dmp_forward(self) -> None: + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_dmp_forward), + inspect.signature(DistributedModelParallel.forward), + ) + ) diff --git a/torchrec/schema/api_tests/test_optimizer_schema.py b/torchrec/schema/api_tests/test_optimizer_schema.py new file mode 100644 index 000000000..a204c67e7 --- /dev/null +++ b/torchrec/schema/api_tests/test_optimizer_schema.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import inspect +import unittest +from typing import Any, Collection, List, Mapping, Optional, Set, Tuple, Union + +import torch +from torch import optim + +from torchrec.distributed.types import ShardedTensor +from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer +from torchrec.schema.utils import is_signature_compatible + + +class StableKeyedOptimizer(optim.Optimizer): + def __init__( + self, + params: Mapping[str, Union[torch.Tensor, ShardedTensor]], + # pyre-ignore [2] + state: Mapping[Any, Any], + param_groups: Collection[Mapping[str, Any]], + ) -> None: + pass + + def init_state( + self, + sparse_grad_parameter_names: Optional[Set[str]] = None, + ) -> None: + pass + + def save_param_groups(self, save: bool) -> None: + pass + + # pyre-ignore [2] + def add_param_group(self, param_group: Any) -> None: + pass + + def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + pass + + +class StableCombinedOptimizer(KeyedOptimizer): + def __init__( + self, + optims: List[Union[KeyedOptimizer, Tuple[str, KeyedOptimizer]]], + ) -> None: + pass + + @property + def optimizers(self) -> List[Tuple[str, StableKeyedOptimizer]]: + return [] + + @staticmethod + def prepend_opt_key(name: str, opt_key: str) -> str: + return "" + + @property + def param_groups(self) -> Collection[Mapping[str, Any]]: + return [] + + @property + def params(self) -> Mapping[str, Union[torch.Tensor, ShardedTensor]]: + return {} + + def post_load_state_dict(self) -> None: + pass + + def save_param_groups(self, save: bool) -> None: + pass + + # pyre-ignore [2] + def step(self, closure: Any = None) -> None: + pass + + def zero_grad(self, set_to_none: bool = False) -> None: + pass + + +class TestOptimizerSchema(unittest.TestCase): + def test_keyed_optimizer(self) -> None: + stable_keyed_optimizer_funcs = inspect.getmembers( + StableKeyedOptimizer, predicate=inspect.isfunction + ) + + for func_name, stable_func in stable_keyed_optimizer_funcs: + self.assertTrue(getattr(KeyedOptimizer, func_name, None) is not None) + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_func), + inspect.signature(getattr(KeyedOptimizer, func_name)), + ) + ) + + def test_combined_optimizer(self) -> None: + stable_combined_optimizer_funcs = inspect.getmembers( + StableCombinedOptimizer, predicate=inspect.isfunction + ) + + for func_name, stable_func in stable_combined_optimizer_funcs: + self.assertTrue(getattr(CombinedOptimizer, func_name, None) is not None) + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_func), + inspect.signature(getattr(CombinedOptimizer, func_name)), + ) + ) diff --git a/torchrec/schema/api_tests/test_planner_schema.py b/torchrec/schema/api_tests/test_planner_schema.py new file mode 100644 index 000000000..cd1dcd7c5 --- /dev/null +++ b/torchrec/schema/api_tests/test_planner_schema.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
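
The schema api_tests added in this diff all follow one pattern: a frozen Stable* copy of each public class pins the signatures users rely on, and the test walks those functions with inspect and compares each against the live class via is_signature_compatible. A minimal sketch of the pattern, assuming a hypothetical StableFoo/Foo pair (illustrative names only, not part of TorchRec):

    import inspect
    import unittest

    from torchrec.schema.utils import is_signature_compatible


    class Foo:
        """Hypothetical 'live' class whose API may evolve."""

        def bar(self, x: int, y: float = 1.0, *, verbose: bool = False) -> int:
            return x


    class StableFoo:
        """Frozen copy of the API surface we promise not to break."""

        def bar(self, x: int, y: float = 1.0) -> int:
            return x


    class TestFooSchema(unittest.TestCase):
        def test_foo(self) -> None:
            # walk every function pinned on the stable copy ...
            for name, stable_func in inspect.getmembers(
                StableFoo, predicate=inspect.isfunction
            ):
                live_func = getattr(Foo, name, None)
                # ... the live class must still expose it ...
                self.assertIsNotNone(live_func)
                # ... with a backward compatible signature (the added
                # keyword-only `verbose` flag above is allowed).
                self.assertTrue(
                    is_signature_compatible(
                        inspect.signature(stable_func),
                        inspect.signature(live_func),
                    )
                )

In this sketch the keyword-only `verbose` flag added on Foo.bar is treated as compatible, which is exactly the kind of additive change these tests are meant to permit while flagging removed or reordered parameters.
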
+ +# pyre-strict + +import inspect +import unittest +from typing import Dict, List, Optional, Union + +import torch.distributed as dist +from torch import nn +from torchrec.distributed.planner.enumerators import EmbeddingEnumerator +from torchrec.distributed.planner.partitioners import GreedyPerfPartitioner, SortBy +from torchrec.distributed.planner.planners import EmbeddingShardingPlanner +from torchrec.distributed.planner.proposers import GreedyProposer +from torchrec.distributed.planner.shard_estimators import ( + EmbeddingPerfEstimator, + EmbeddingStorageEstimator, +) +from torchrec.distributed.planner.storage_reservations import ( + HeuristicalStorageReservation, +) +from torchrec.distributed.planner.types import ( + Enumerator, + ParameterConstraints, + Partitioner, + PerfModel, + Proposer, + ShardEstimator, + ShardingOption, + ShardingPlan, + Stats, + StorageReservation, + Topology, +) +from torchrec.distributed.types import ModuleSharder, PipelineType, ShardingPlan +from torchrec.schema.utils import is_signature_compatible + + +class StableEmbeddingShardingPlanner: + def __init__( + self, + topology: Optional[Topology] = None, + batch_size: Optional[int] = None, + enumerator: Optional[Enumerator] = None, + storage_reservation: Optional[StorageReservation] = None, + proposer: Optional[Union[Proposer, List[Proposer]]] = None, + partitioner: Optional[Partitioner] = None, + performance_model: Optional[PerfModel] = None, + stats: Optional[Union[Stats, List[Stats]]] = None, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + debug: bool = True, + ) -> None: + pass + + def collective_plan( + self, + module: nn.Module, + sharders: Optional[List[ModuleSharder[nn.Module]]] = None, + pg: Optional[dist.ProcessGroup] = None, + ) -> ShardingPlan: + return ShardingPlan(plan={}) + + def plan( + self, + module: nn.Module, + sharders: List[ModuleSharder[nn.Module]], + ) -> ShardingPlan: + return ShardingPlan(plan={}) + + +class StableEmbeddingEnumerator: + def __init__( + self, + topology: Topology, + batch_size: int, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + estimator: Optional[Union[ShardEstimator, List[ShardEstimator]]] = None, + use_exact_enumerate_order: Optional[bool] = False, + ) -> None: + pass + + def enumerate( + self, + module: nn.Module, + sharders: List[ModuleSharder[nn.Module]], + ) -> List[ShardingOption]: + return [] + + def populate_estimates(self, sharding_options: List[ShardingOption]) -> None: + pass + + +class StableGreedyPerfPartitioner: + def __init__( + self, sort_by: SortBy = SortBy.STORAGE, balance_modules: bool = False + ) -> None: + pass + + def partition( + self, + proposal: List[ShardingOption], + storage_constraint: Topology, + ) -> List[ShardingOption]: + return [] + + +class StableHeuristicalStorageReservation: + def __init__( + self, + percentage: float, + parameter_multiplier: float = 6.0, + dense_tensor_estimate: Optional[int] = None, + ) -> None: + pass + + def reserve( + self, + topology: Topology, + batch_size: int, + module: nn.Module, + sharders: List[ModuleSharder[nn.Module]], + constraints: Optional[Dict[str, ParameterConstraints]] = None, + ) -> Topology: + return Topology(world_size=0, compute_device="cuda") + + +class StableGreedyProposer: + def __init__(self, use_depth: bool = True, threshold: Optional[int] = None) -> None: + pass + + def load( + self, + search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, + ) -> None: + pass + + def propose(self) -> Optional[List[ShardingOption]]: + 
return [] + + def feedback( + self, + partitionable: bool, + plan: Optional[List[ShardingOption]] = None, + perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, + ) -> None: + pass + + +class StableEmbeddingPerfEstimator: + def __init__( + self, + topology: Topology, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + is_inference: bool = False, + ) -> None: + pass + + def estimate( + self, + sharding_options: List[ShardingOption], + sharder_map: Optional[Dict[str, ModuleSharder[nn.Module]]] = None, + ) -> None: + pass + + +class StableEmbeddingStorageEstimator: + def __init__( + self, + topology: Topology, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + pipeline_type: PipelineType = PipelineType.NONE, + run_embedding_at_peak_memory: bool = False, + is_inference: bool = False, + ) -> None: + pass + + def estimate( + self, + sharding_options: List[ShardingOption], + sharder_map: Optional[Dict[str, ModuleSharder[nn.Module]]] = None, + ) -> None: + pass + + +class TestPlanner(unittest.TestCase): + def test_planner(self) -> None: + stable_planner_funcs = inspect.getmembers( + StableEmbeddingShardingPlanner, predicate=inspect.isfunction + ) + + for func_name, stable_func in stable_planner_funcs: + self.assertTrue( + getattr(EmbeddingShardingPlanner, func_name, None) is not None + ) + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_func), + inspect.signature(getattr(EmbeddingShardingPlanner, func_name)), + ) + ) + + def test_enumerator(self) -> None: + stable_enumerator_funcs = inspect.getmembers( + StableEmbeddingEnumerator, predicate=inspect.isfunction + ) + + for func_name, stable_func in stable_enumerator_funcs: + self.assertTrue(getattr(EmbeddingEnumerator, func_name, None) is not None) + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_func), + inspect.signature(getattr(EmbeddingEnumerator, func_name)), + ) + ) + + def test_partitioner(self) -> None: + stable_partitioner_funcs = inspect.getmembers( + StableGreedyPerfPartitioner, predicate=inspect.isfunction + ) + + for func_name, stable_func in stable_partitioner_funcs: + self.assertTrue(getattr(GreedyPerfPartitioner, func_name, None) is not None) + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_func), + inspect.signature(getattr(GreedyPerfPartitioner, func_name)), + ) + ) + + def test_storage_reservation(self) -> None: + stable_storage_reservation_funcs = inspect.getmembers( + StableHeuristicalStorageReservation, predicate=inspect.isfunction + ) + + for func_name, stable_func in stable_storage_reservation_funcs: + self.assertTrue( + getattr(HeuristicalStorageReservation, func_name, None) is not None + ) + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_func), + inspect.signature( + getattr(HeuristicalStorageReservation, func_name) + ), + ) + ) + + def test_proposer(self) -> None: + stable_proposer_funcs = inspect.getmembers( + StableGreedyProposer, predicate=inspect.isfunction + ) + + for func_name, stable_func in stable_proposer_funcs: + self.assertTrue(getattr(GreedyProposer, func_name, None) is not None) + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_func), + inspect.signature(getattr(GreedyProposer, func_name)), + ) + ) + + def test_perf_estimator(self) -> None: + stable_perf_estimator_funcs = inspect.getmembers( + StableEmbeddingPerfEstimator, predicate=inspect.isfunction + ) + + for func_name, stable_func in stable_perf_estimator_funcs: + 
self.assertTrue( + getattr(EmbeddingPerfEstimator, func_name, None) is not None + ) + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_func), + inspect.signature(getattr(EmbeddingPerfEstimator, func_name)), + ) + ) + + def test_storage_estimator(self) -> None: + stable_storage_estimator_funcs = inspect.getmembers( + StableEmbeddingStorageEstimator, predicate=inspect.isfunction + ) + + for func_name, stable_func in stable_storage_estimator_funcs: + self.assertTrue( + getattr(EmbeddingStorageEstimator, func_name, None) is not None + ) + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_func), + inspect.signature(getattr(EmbeddingStorageEstimator, func_name)), + ) + ) diff --git a/torchrec/schema/test_schema_utils.py b/torchrec/schema/test_schema_utils.py new file mode 100644 index 000000000..61fe5ba59 --- /dev/null +++ b/torchrec/schema/test_schema_utils.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import inspect + +import unittest +from typing import Any, Dict, Tuple + +from .utils import is_signature_compatible + + +def stable_test_func( + a: int, b: float, *, c: int, d: float = 1.0, **kwargs: Dict[str, Any] +) -> int: + return a + + +def stable_test_func_basic( + a: int, + b: float, + c: bool = True, + d: float = 1.0, +) -> bool: + return c + + +class TestUtils(unittest.TestCase): + def test_is_not_backwards_compatible(self) -> None: + ## stable_test_func tests + def test_func_positional_arg_removed( + a: int, *, c: int, d: float = 1.0, **kwargs: Dict[str, Any] + ) -> int: + return a + + def test_func_positional_arg_added( + a: int, + b: float, + z: float, + *, + c: int, + d: float = 1.0, + **kwargs: Dict[str, Any], + ) -> int: + return a + + def test_func_keyword_arg_removed( + a: int, b: float, *, d: float = 1.0, **kwargs: Dict[str, Any] + ) -> int: + return a + + def test_func_var_kwargs_removed( + a: int, b: float, z: float, *, d: float = 1.0 + ) -> int: + return a + + def test_func_var_args_removed( + a: int, b: float, z: float, d: float = 1.0, **kwargs: Dict[str, Any] + ) -> int: + return a + + # stable_test_func_basic tests + def test_func_basic_keyword_or_pos_arg_shifted( + a: int, + b: float, + d: float = 1.0, + c: bool = True, + ) -> bool: + return c + + def test_func_basic_add_arg_in_middle( + a: int, + b: float, + d: float = 1.0, + z: float = 1.0, + c: bool = True, + ) -> bool: + return c + + def test_func_basic_default_arg_changed( + a: int, + b: float, + c: bool = True, + d: float = 2.0, + ) -> bool: + return c + + def test_func_basic_default_arg_removed( + a: int, + b: float, + c: bool, + d: float = 1.0, + ) -> bool: + return c + + def test_func_basic_arg_type_change( + a: int, + b: bool, + c: bool = True, + d: float = 1.0, + ) -> bool: + return c + + def test_func_basic_return_type_changed( + a: int, + b: float, + c: bool = True, + d: float = 1.0, + ) -> int: + return a + + local_funcs = locals() + for name, func in local_funcs.items(): + if name.startswith("test_func_basic"): + self.assertFalse( + is_signature_compatible( + inspect.signature(stable_test_func_basic), + inspect.signature(func), + ), + f"{name} is backwards compatible with stable_test_func_basic when it shouldn't be.", + ) + elif name.startswith("test_func"): + self.assertFalse( + is_signature_compatible( + 
inspect.signature(stable_test_func), inspect.signature(func) + ), + f"{name} is not backwards compatible with stable_test_func when it shouldn't be.", + ) + else: + continue + + def test_is_backwards_compatible(self) -> None: + # stable_test_func tests + def test_func_keyword_arg_added( + a: int, + b: float, + *, + c: int, + d: float = 1.0, + e: float = 1.0, + **kwargs: Dict[str, Any], + ) -> int: + return a + + def test_func_keyword_arg_added_in_middle( + a: int, + b: float, + *, + c: int, + e: float = 1.0, + d: float = 1.0, + **kwargs: Dict[str, Any], + ) -> int: + return a + + def test_func_keyword_arg_shifted( + a: int, b: float, *, d: float = 1.0, c: int, **kwargs: Dict[str, Any] + ) -> int: + return a + + # stable_test_func_basic tests + def test_func_basic_add_arg_at_end( + a: int, + b: float, + c: bool = True, + d: float = 1.0, + e: float = 1.0, + ) -> bool: + return c + + def test_func_basic_add_var_args_at_end( + a: int, + b: float, + c: bool = True, + d: float = 1.0, + *args: Tuple[Any], + ) -> bool: + return c + + def test_func_basic_add_var_kwargs_at_end( + a: int, + b: float, + c: bool = True, + d: float = 1.0, + **kwargs: Dict[str, Any], + ) -> bool: + return c + + local_funcs = locals() + for name, func in local_funcs.items(): + if name.startswith("test_func_basic"): + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_test_func_basic), + inspect.signature(func), + ), + f"{name} is supposed to be backwards compatible with stable_test_func_basic", + ) + elif name.startswith("test_func"): + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_test_func), inspect.signature(func) + ), + f"{name} is supposed to be backwards compatible with stable_test_func", + ) + else: + continue diff --git a/torchrec/schema/utils.py b/torchrec/schema/utils.py new file mode 100644 index 000000000..b4f8a6075 --- /dev/null +++ b/torchrec/schema/utils.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import inspect +import typing +from typing import Any + + +def _is_annot_compatible(prev: object, curr: object) -> bool: + if prev == curr: + return True + + if not (prev_origin := typing.get_origin(prev)): + return False + if not (curr_origin := typing.get_origin(curr)): + return False + + if prev_origin != curr_origin: + return False + + prev_args = typing.get_args(prev) + curr_args = typing.get_args(curr) + if len(prev_args) != len(curr_args): + return False + + for prev_arg, curr_arg in zip(prev_args, curr_args): + if not _is_annot_compatible(prev_arg, curr_arg): + return False + + return True + + +def is_signature_compatible( + previous_signature: inspect.Signature, + current_signature: inspect.Signature, +) -> bool: + """Check if two signatures are compatible. + + Args: + sig1: The first signature. + sig2: The second signature. + + Returns: + True if the signatures are compatible, False otherwise. 
+ + """ + + # If current signature has less parameters than expected signature + # BC is automatically broken, no need to check further + if len(previous_signature.parameters) > len(current_signature.parameters): + return False + + # Check order of positional arguments + expected_args = list(previous_signature.parameters.values()) + current_args = list(current_signature.parameters.values()) + + # Store the names of all keyword only arguments + # to check if all expected keyword only arguments + # are present in current signature + expected_keyword_only_args = set() + current_keyword_only_args = set() + + expected_args_len = len(expected_args) + + for i in range(len(current_args)): + current_arg = current_args[i] + if current_arg.kind == current_arg.KEYWORD_ONLY: + current_keyword_only_args.add(current_arg.name) + + if i >= expected_args_len: + continue + + expected_arg = expected_args[i] + + # If the kinds of arguments are different, BC is broken + # unless current arg is a keyword argument + if expected_arg.kind != current_arg.kind: + if expected_arg.kind == expected_arg.VAR_KEYWORD: + # Any arg can be inserted before **kwargs and still maintain BC + continue + else: + return False + + # Potential positional arguments need to have the same name + # keyword only arguments can be mixed up + if expected_arg.kind == expected_arg.POSITIONAL_OR_KEYWORD: + if expected_arg.name != current_arg.name: + return False + + # Positional arguments need to have the same type annotation + # TODO: Account for Union Types? + if expected_arg.annotation != current_arg.annotation: + return False + + # Positional arguments need to have the same default value + if expected_arg.default != current_arg.default: + return False + elif expected_arg.kind == expected_arg.KEYWORD_ONLY: + expected_keyword_only_args.add(expected_arg.name) + + # All kwargs in expected signature must be present in current signature + for kwarg in expected_keyword_only_args: + if kwarg not in current_keyword_only_args: + return False + + # TODO: Account for Union Types? + if not _is_annot_compatible( + previous_signature.return_annotation, current_signature.return_annotation + ): + return False + return True diff --git a/torchrec/sparse/__init__.py b/torchrec/sparse/__init__.py index 3e3232d5d..a3aa92962 100644 --- a/torchrec/sparse/__init__.py +++ b/torchrec/sparse/__init__.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + """Torchrec Jagged Tensors It has 3 classes: JaggedTensor, KeyedJaggedTensor, KeyedTensor. diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 988a7f6e9..2bbe09149 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -5,27 +5,71 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import abc -from typing import Dict, List, Optional, Tuple +import logging + +import operator + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -import torch.fx +from torch.autograd.profiler import record_function +from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec +from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node +from torchrec.pt2.checks import ( + is_non_strict_exporting, + is_pt2_compiling, + is_torchdynamo_compiling, + pt2_check_size_nonzero, + pt2_checks_all_is_size, + pt2_checks_tensor_slice, + pt2_guard_size_oblivious, +) from torchrec.streamable import Pipelineable try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu" + ) + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu" + ) + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu" + ) + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu" + ) except OSError: pass -# OSS -try: - import fbgemm_gpu # @manual # noqa -except ImportError: - pass + +logger: logging.Logger = logging.getLogger() + + +def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor: + """ + moving a tensor from cpu to cuda using pinned memory (non_blocking) is generally faster + """ + if is_torchdynamo_compiling(): + # TODO: remove once FakeTensor supports pin_memory() and to(..., non_blocking=True) + return tensor.to(device=device) + + return ( + tensor.pin_memory().to(device=device, non_blocking=True) + if device.type == "cuda" and tensor.device.type == "cpu" + else tensor.to(device=device, non_blocking=True) + ) def _cumsum(o: List[int]) -> List[int]: + """ + python-list version of converting lengths --> offsets + """ ret = [0] * (len(o) + 1) for i in range(len(o)): ret[i + 1] = ret[i] + o[i] @@ -40,6 +84,7 @@ def _to_lengths(offsets: torch.Tensor) -> torch.Tensor: return offsets[1:] - offsets[:-1] +@torch.jit.script_if_tracing def _batched_lengths_to_offsets(lengths: torch.Tensor) -> torch.Tensor: (f, b) = lengths.shape offsets_0 = lengths.new_zeros((f, 1)) @@ -53,7 +98,7 @@ def _maybe_compute_lengths( ) -> torch.Tensor: if lengths is None: assert offsets is not None - lengths = _to_lengths(offsets) + lengths = torch.diff(offsets) return lengths @@ -71,6 +116,29 @@ def _get_weights_or_throw(weights: Optional[torch.Tensor]) -> torch.Tensor: return weights +def _get_lengths_offset_per_key_or_throw( + lengths_offset_per_key: Optional[List[int]], +) -> List[int]: + assert ( + lengths_offset_per_key is not None + ), "This (Keyed)JaggedTensor doesn't have lengths_offset_per_key." + return lengths_offset_per_key + + +def _get_stride_per_key_or_throw(stride_per_key: Optional[List[int]]) -> List[int]: + assert ( + stride_per_key is not None + ), "This (Keyed)JaggedTensor doesn't have stride_per_key." + return stride_per_key + + +def _get_inverse_indices_or_throw( + inverse_indices: Optional[Tuple[List[str], torch.Tensor]], +) -> Tuple[List[str], torch.Tensor]: + assert inverse_indices is not None, "This KJT doesn't have inverse indices." 
+ return inverse_indices + + def _assert_offsets_or_lengths_is_provided( offsets: Optional[torch.Tensor], lengths: Optional[torch.Tensor] ) -> None: @@ -78,18 +146,10 @@ def _assert_offsets_or_lengths_is_provided( @torch.fx.wrap +# keep for legacy use cases def _regroup_keyed_tensors( keyed_tensors: List["KeyedTensor"], groups: List[List[str]] ) -> List[torch.Tensor]: - # Shortcut for no re-grouping - if len(keyed_tensors) == len(groups): - match = True - for kt, group in zip(keyed_tensors, groups): - if kt.keys() != group: - match = False - break - if match: - return [kt.values() for kt in keyed_tensors] embedding_dicts = [keyed_tensor.to_dict() for keyed_tensor in keyed_tensors] lengths = [keyed_tensor.length_per_key() for keyed_tensor in keyed_tensors] @@ -97,7 +157,7 @@ def _regroup_keyed_tensors( key_dim = keyed_tensors[0].key_dim() key_to_idx: dict[str, int] = {} - for (i, keyed_tensor) in enumerate(keyed_tensors): + for i, keyed_tensor in enumerate(keyed_tensors): for key in keyed_tensor.keys(): key_to_idx[key] = i @@ -115,6 +175,213 @@ def _regroup_keyed_tensors( return list(rearranged_values.split(split_lengths, dim=key_dim)) +@torch.fx.wrap +def _all_keys_used_once( + keyed_tensors: List["KeyedTensor"], groups: List[List["str"]] +) -> bool: + flat_keys: List[str] = [] + flat_groups: List[str] = [] + for keyed_tensor in keyed_tensors: + flat_keys.extend(keyed_tensor.keys()) + for sub_group in groups: + flat_groups.extend(sub_group) + # jit.script does not support set, so we use a dict to represent the set + key_set: Dict[str, int] = {key: 1 for key in flat_keys} + group_set: Dict[str, int] = {key: 1 for key in flat_groups} + return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups) + + +@torch.fx.wrap +def permute_multi_embedding( + keyed_tensors: List["KeyedTensor"], groups: List[List["str"]] +) -> List[torch.Tensor]: + keys, lengths, values = _desugar_keyed_tensors(keyed_tensors) + permutes, in_shape, out_shape, out_lengths = torch.ops.fbgemm.kt_regroup_arguments( + values[0], keys, lengths, groups + ) + permuted_values = torch.ops.fbgemm.permute_multi_embedding( + values, + permutes, + in_shape, + out_shape, + out_lengths, + ) + return permuted_values + + +@torch.fx.wrap +def regroup_kts( + keyed_tensors: List["KeyedTensor"], groups: List[List["str"]] +) -> List[torch.Tensor]: + keys, lengths, values = _desugar_keyed_tensors(keyed_tensors) + return torch.ops.fbgemm.regroup_keyed_tensor( + values, + keys, + lengths, + groups, + ) + + +@torch.fx.wrap +def _fbgemm_permute_pooled_embs( + keyed_tensors: List["KeyedTensor"], groups: List[List["str"]] +) -> List[torch.Tensor]: + keys, lengths, values = _desugar_keyed_tensors(keyed_tensors) + permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups( + keys, lengths, groups + ) + values = torch.concat(values, dim=1) + device = values.device + permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad( + values, + _pin_and_move(offsets, device), + _pin_and_move(permute, device), + _pin_and_move(inv_offsets, device), + _pin_and_move(inv_permute, device), + ) + return list(torch.split(permuted_values, splits, dim=1)) + + +@torch.fx.wrap +def _desugar_keyed_tensors( + kts: List["KeyedTensor"], +) -> Tuple[List[List[str]], List[List[int]], List[torch.Tensor]]: + """ + Desugar a list of KeyedTensors into basic data structure + """ + return ( + [kt.keys() for kt in kts], + [kt.length_per_key() for kt in kts], + [kt.values() for kt in kts], + ) + + +@torch.fx.wrap +def _remap_to_groups( + keys: 
List[List[str]], + key_lengths: List[List[int]], + groups: List[List[str]], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + """ + Given a list of keys and lengths per key for each group, return the permute indices, inverse_permute indices, offsets, inv_offsets, splits. + The output is used to re-arrange values based on groups with a single cat operation. + """ + + lengths: List[int] = [] + flat_keys: List[str] = [] + flat_groups: List[str] = [] + + for sub_keys_length in key_lengths: + lengths.extend(sub_keys_length) + for sub_keys in keys: + flat_keys.extend(sub_keys) + + for sub_group in groups: + flat_groups.extend(sub_group) + + key_splits = [len(sub_group) for sub_group in groups] + + index_map = {key: idx for idx, key in enumerate(flat_keys)} + permute = [index_map[key] for key in flat_groups] + inv_lengths = [lengths[i] for i in permute] + splits = _sum_by_splits(inv_lengths, key_splits) + + inv_permute = [0] * len(permute) + for i, p in enumerate(permute): + inv_permute[p] = i + + offsets = torch.tensor(_cumsum(lengths), dtype=torch.int64) + inv_offsets = torch.tensor(_cumsum(inv_lengths), dtype=torch.int64) + permute = torch.tensor(permute, dtype=torch.int64) + inv_permute = torch.tensor(inv_permute, dtype=torch.int64) + + return permute, inv_permute, offsets, inv_offsets, splits + + +def _kt_regroup_arguments( + value: torch.Tensor, + keys: List[List[str]], + key_lengths: List[List[int]], + groups: List[List[str]], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + """ + returns: permutes, in_shapes, out_shapes, out_lengths + """ + # key => (tensor_idx, key_index) + key_map: Dict[str, Tuple[int, int]] = { + key: (tensor_idx, key_idx) + for tensor_idx, tensor in enumerate(keys) + for key_idx, key in enumerate(tensor) + } + + # [offsets per tensor] + in_offsets: List[List[int]] = [[] for _ in key_lengths] + for i, tensor in enumerate(key_lengths): + in_offsets[i] = _cumsum(tensor) + in_lengths: List[int] = [sum(lengths) for lengths in key_lengths] + + # set total_permutes as the jump stop sign + total_permutes: int = sum(len(tensor) for tensor in groups) + out_lengths: List[int] = [0] * len(groups) + + # [input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump] + permute_param = 6 + permutes: List[List[int]] = [[0] * permute_param for _ in range(total_permutes)] + + # record the last seen index, so that can make the jump from last_seen to current + last_seen: Dict[str, int] = {} + permute_idx = 0 + for output_tensor_idx, output_tenser in enumerate(groups): + output_start = 0 + for output_key in output_tenser: + input_tensor_idx, input_key_idx = key_map[output_key] + input_start = in_offsets[input_tensor_idx][input_key_idx] + length = key_lengths[input_tensor_idx][input_key_idx] + + # add jump data + if output_key not in last_seen: + jump = 0 # don't need to jump yet + # positive as a potential jump start + last_seen[output_key] = permute_idx + else: + prev = last_seen[output_key] + if prev >= 0: # positive ==> it's a jump start + # jump to current idx, positive as the jump start + permutes[prev][5] = permute_idx + else: # it's already in a jump sequence, mark as negative + permutes[-prev][5] = -permute_idx + # mark last_seen negative since it's already in jump + last_seen[output_key] = -permute_idx + # it's a potential jump stop + jump = -total_permutes + + permutes[permute_idx][:] = [ + input_tensor_idx, + output_tensor_idx, + input_start, + output_start, + length, + jump, + ] + permute_idx += 1 + output_start 
+= length + out_lengths[output_tensor_idx] = output_start + + permute_tensor = torch.tensor(permutes, dtype=torch.int32) + in_shapes = torch.tensor(in_lengths, dtype=torch.int32) + out_shapes = torch.tensor(out_lengths, dtype=torch.int32) + device = value.device + permute_tensor = _pin_and_move(permute_tensor, device) + in_shapes = _pin_and_move(in_shapes, device) + out_shapes = _pin_and_move(out_shapes, device) + return ( + permute_tensor, + in_shapes, + out_shapes, + out_lengths, + ) + + def _values_string(values: torch.Tensor, start: int, end: int) -> str: size = values.size() if len(size) == 1: @@ -164,8 +431,135 @@ def _arange(*args, **kwargs) -> torch.Tensor: return torch.arange(*args, **kwargs) -# pyre-fixme[11]: Annotation `ProxyableClassMeta` is not defined as a type. -class JaggedTensorMeta(abc.ABCMeta, torch.fx.ProxyableClassMeta): +def _permute_tensor_by_segments( + tensor: torch.Tensor, + segment_sizes: torch.Tensor, + recat: torch.Tensor, + weights: Optional[torch.Tensor] = None, + output_size: Optional[int] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Permutes a tensor by segments according to recat tensor. + + For variable stride tensors we permute across length per key, which reduces the + number of permute indices and lengthens each sequence. + `keyed_jagged_index_select_dim1` more efficiently parallelizes work for each permute + index and sequence across multiple thread blocks. + + For permuting KJT with weights that are not of float type (i.e. storing + bucketization position tensor of longs in weights), `permute_1D_sparse_data` is used + instead of `keyed_jagged_index_select_dim1` which doesn't support non float weights. + + NOTE: + `keyed_jagged_index_select_dim1` is only supported for CUDA. + """ + if tensor.device.type == "cuda" and ( + weights is None or weights.dtype == torch.float32 + ): + output = torch.ops.fbgemm.keyed_jagged_index_select_dim1( + values=tensor, + lengths=segment_sizes, + offsets=_to_offsets(segment_sizes), + indices=recat, + batch_size=segment_sizes.numel(), + weights=weights, + selected_lengths_sum=output_size, + ) + permuted_tensor = output[0] + permuted_weights = output[2] if weights is not None else None + else: + ( + _, + permuted_tensor, + permuted_weights, + ) = torch.ops.fbgemm.permute_1D_sparse_data( + recat, + segment_sizes, + tensor, + weights, + output_size, + ) + return permuted_tensor, permuted_weights + + +@torch.fx.wrap +def _kjt_concat( + kjt_list: List["KeyedJaggedTensor"], +) -> "KeyedJaggedTensor": + if len(kjt_list) == 0: + raise ValueError("Can't concat empty KJT list") + + is_weighted: bool = kjt_list[0].weights_or_none() is not None + has_length_per_key: bool = True + + length_per_key: List[int] = [] + keys: List[str] = [] + value_list: List[torch.Tensor] = [] + weight_list: List[torch.Tensor] = [] + length_list: List[torch.Tensor] = [] + stride_per_key_per_rank: List[List[int]] = [] + stride: Optional[int] = None + inv_idx_keys: List[str] = [] + inv_idx_tensors: List[torch.Tensor] = [] + + variable_stride_per_key_list = [kjt.variable_stride_per_key() for kjt in kjt_list] + assert all(variable_stride_per_key_list) or not any( + variable_stride_per_key_list + ), "variable stride per key must be consistent for all KJTs" + variable_stride_per_key = all(variable_stride_per_key_list) + + for i, kjt in enumerate(kjt_list): + curr_is_weighted: bool = kjt.weights_or_none() is not None + if is_weighted != curr_is_weighted: + raise ValueError("Can't merge weighted KJT with unweighted KJT") + _length_per_key: 
Optional[List[int]] = None + if kjt._length_per_key is None: + has_length_per_key = False + else: + _length_per_key = kjt._length_per_key + if has_length_per_key and _length_per_key is not None: + length_per_key += _length_per_key + keys += kjt.keys() + value_list.append(kjt.values()) + if is_weighted: + weight_list.append(kjt.weights()) + length_list.append(kjt.lengths()) + if variable_stride_per_key: + stride_per_key_per_rank += kjt.stride_per_key_per_rank() + elif stride is None: + stride = kjt.stride() + else: + assert stride == kjt.stride(), "strides must be consistent for all KJTs" + if kjt.inverse_indices_or_none() is not None: + assert ( + len(inv_idx_tensors) == i + ), "inverse indices must be consistent for all KJTs" + inv_idx_keys += kjt.inverse_indices()[0] + inv_idx_tensors.append(kjt.inverse_indices()[1]) + else: + assert ( + len(inv_idx_tensors) == 0 + ), "inverse indices must be consistent for all KJTs" + + return KeyedJaggedTensor( + keys=keys, + values=torch.cat(value_list, dim=0), + weights=torch.cat(weight_list, dim=0) if is_weighted else None, + lengths=torch.cat(length_list, dim=0), + stride=stride, + stride_per_key_per_rank=( + stride_per_key_per_rank if variable_stride_per_key else None + ), + length_per_key=length_per_key if has_length_per_key else None, + inverse_indices=( + (inv_idx_keys, torch.cat(inv_idx_tensors)) + if len(inv_idx_tensors) == len(kjt_list) + else None + ), + ) + + +class JaggedTensorMeta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta): pass @@ -191,6 +585,8 @@ class JaggedTensor(Pipelineable, metaclass=JaggedTensorMeta): offsets. """ + _fields = ["_values", "_weights", "_lengths", "_offsets"] + def __init__( self, values: torch.Tensor, @@ -198,6 +594,7 @@ def __init__( lengths: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, ) -> None: + self._values: torch.Tensor = values self._weights: Optional[torch.Tensor] = weights _assert_offsets_or_lengths_is_provided(offsets, lengths) @@ -209,12 +606,33 @@ def __init__( self._offsets: Optional[torch.Tensor] = offsets @staticmethod - def empty(is_weighted: bool = False) -> "JaggedTensor": - weights = torch.tensor([]) if is_weighted else None + def empty( + is_weighted: bool = False, + device: Optional[torch.device] = None, + values_dtype: Optional[torch.dtype] = None, + weights_dtype: Optional[torch.dtype] = None, + lengths_dtype: torch.dtype = torch.int32, + ) -> "JaggedTensor": + """ + Constructs an empty JaggedTensor. + + Args: + is_weighted (bool): whether the JaggedTensor has weights. + device (Optional[torch.device]): device for JaggedTensor. + values_dtype (Optional[torch.dtype]): dtype for values. + weights_dtype (Optional[torch.dtype]): dtype for weights. + lengths_dtype (torch.dtype): dtype for lengths. + + Returns: + JaggedTensor: empty JaggedTensor. + """ + weights = ( + torch.empty(0, dtype=weights_dtype, device=device) if is_weighted else None + ) return JaggedTensor( - values=torch.tensor([]), - offsets=torch.tensor([]), - lengths=torch.tensor([]), + values=torch.empty(0, dtype=values_dtype, device=device), + offsets=torch.empty(0, dtype=lengths_dtype, device=device), + lengths=torch.empty(0, dtype=lengths_dtype, device=device), weights=weights, ) @@ -225,9 +643,17 @@ def from_dense_lengths( weights: Optional[torch.Tensor] = None, ) -> "JaggedTensor": """ - Constructs `JaggedTensor` from dense values/weights of shape (B, N,). + Constructs `JaggedTensor` from values and lengths tensors, with optional weights. 
+ Note that `lengths` is still of shape (B,), where B is the batch size. + + Args: + values (torch.Tensor): dense representation of values. + lengths (torch.Tensor): jagged slices, represented as lengths. + weights (Optional[torch.Tensor]): if values have weights, tensor with + the same shape as values. - Note that `lengths` is still of shape (B,). + Returns: + JaggedTensor: JaggedTensor created from 2D dense tensor. """ mask2d = ( @@ -245,9 +671,9 @@ def from_dense( weights: Optional[List[torch.Tensor]] = None, ) -> "JaggedTensor": """ - Constructs `JaggedTensor` from dense values/weights of shape (B, N,). - - Note that `lengths` and `offsets` are still of shape (B,). + Constructs `JaggedTensor` from list of tensors as values, with optional weights. + `lengths` will be computed, of shape (B,), where B is `len(values)` which + represents the batch size. Args: values (List[torch.Tensor]): a list of tensors for dense representation @@ -276,26 +702,26 @@ def from_dense( weights=weights, ) - # j1 = [[1.0], [], [7.0], [8.0], [10.0, 11.0, 12.0]] + # j1 = [[1.0], [], [7.0, 8.0], [10.0, 11.0, 12.0]] """ - lengths = torch.IntTensor([value.size(0) for value in values]) - # pyre-ignore [9]: values is declared to have type `List[Tensor]` but is used as type `Tensor`. - values = torch.cat(values, dim=0) - # pyre-ignore [9]: weights is declared to have type `Optional[List[Tensor]]` but is used as type `Optional[Tensor]`. - weights = torch.cat(weights, dim=0) if weights is not None else None + + values_tensor = torch.cat(values, dim=0) + lengths = torch.tensor( + [value.size(0) for value in values], + dtype=torch.int32, + device=values_tensor.device, + ) + weights_tensor = torch.cat(weights, dim=0) if weights is not None else None return JaggedTensor( - # pyre-fixme[6]: For 1st param expected `Tensor` but got `List[Tensor]`. - values=values, - # pyre-fixme[6]: For 2nd param expected `Optional[Tensor]` but got - # `Optional[List[Tensor]]`. - weights=weights, + values=values_tensor, + weights=weights_tensor, lengths=lengths, ) def to_dense(self) -> List[torch.Tensor]: """ - Constructs dense-reprensentation tensor from JT. + Constructs a dense-representation of the JT's values. Returns: List[torch.Tensor]: list of tensors. @@ -306,9 +732,9 @@ def to_dense(self) -> List[torch.Tensor]: offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) jt = JaggedTensor(values=values, offsets=offsets) - torch_list = jt.to_dense() + values_list = jt.to_dense() - # torch_list = [ + # values_list = [ # torch.tensor([1.0, 2.0]), # torch.tensor([]), # torch.tensor([3.0]), @@ -324,16 +750,50 @@ def to_dense(self) -> List[torch.Tensor]: tensor_list.append(self.values()[offset:next_offset]) return tensor_list + def to_dense_weights(self) -> Optional[List[torch.Tensor]]: + """ + Constructs a dense-representation of the JT's weights. + + Returns: + Optional[List[torch.Tensor]]: list of tensors, `None` if no weights. 
+ + Example:: + + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]) + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + jt = JaggedTensor(values=values, weights=weights, offsets=offsets) + + weights_list = jt.to_dense_weights() + + # weights_list = [ + # torch.tensor([0.1, 0.2]), + # torch.tensor([]), + # torch.tensor([0.3]), + # torch.tensor([0.4]), + # torch.tensor([0.5]), + # torch.tensor([0.6, 0.7, 0.8]), + # ] + """ + if self.weights_or_none() is None: + return None + tensor_list = [] + for index in range(self.offsets().size(0) - 1): + offset = self.offsets()[index].item() + next_offset = self.offsets()[index + 1].item() + tensor_list.append(self.weights()[offset:next_offset]) + return tensor_list + def to_padded_dense( self, desired_length: Optional[int] = None, padding_value: float = 0.0, ) -> torch.Tensor: """ - Constructs 2D dense Tensor from JT to shape (B, N,). + Constructs a 2D dense tensor from the JT's values of shape (B, N,). - Note that `B` is the length of self.lengths() and `N` is the longest feature - length or `desired_length`. + Note that `B` is the length of self.lengths() and + `N` is the longest feature length or `desired_length`. If `desired_length` > `length` we will pad with `padding_value`, otherwise we will select the last value at `desired_length`. @@ -362,72 +822,191 @@ def to_padded_dense( # [3.0, 10.0], # [4.0, 10.0], # [5.0, 10.0], - # [7.0, 8.0], + # [6.0, 7.0], # ] """ - lengths_list: List[int] = self.lengths().tolist() - N = max(lengths_list) if desired_length is None else desired_length + if desired_length is None: + N = int(torch.max(self.lengths()).item()) + else: + N = desired_length return torch.ops.fbgemm.jagged_to_padded_dense( self.values(), [self.offsets()], [N], padding_value ) + def to_padded_dense_weights( + self, + desired_length: Optional[int] = None, + padding_value: float = 0.0, + ) -> Optional[torch.Tensor]: + """ + Constructs a 2D dense tensor from the JT's weights of shape (B, N,). + + Note that `B` (batch size) is the length of self.lengths() and + `N` is the longest feature length or `desired_length`. + + If `desired_length` > `length` we will pad with `padding_value`, otherwise we + will select the last value at `desired_length`. + + Like `to_padded_dense` but for the JT's weights instead of values. + + Args: + desired_length (int): the length of the tensor. + padding_value (float): padding value if we need to pad. + + Returns: + Optional[torch.Tensor]: 2d dense tensor, `None` if no weights. + + Example:: + + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]) + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + jt = JaggedTensor(values=values, weights=weights, offsets=offsets) + + d_wt = jt.to_padded_dense_weights( + desired_length=2, + padding_value=1.0, + ) + + # d_wt = [ + # [0.1, 0.2], + # [1.0, 1.0], + # [0.3, 1.0], + # [0.4, 1.0], + # [0.5, 1.0], + # [0.6, 0.7], + # ] + """ + if self.weights_or_none() is None: + return None + if desired_length is None: + N = int(torch.max(self.lengths()).item()) + else: + N = desired_length + return torch.ops.fbgemm.jagged_to_padded_dense( + self.weights(), [self.offsets()], [N], padding_value + ) + + def device(self) -> torch.device: + """ + Get JaggedTensor device. + + Returns: + torch.device: the device of the values tensor. + """ + return self._values.device + def lengths(self) -> torch.Tensor: + """ + Get JaggedTensor lengths. 
If not computed, compute it from offsets. + + Returns: + torch.Tensor: the lengths tensor. + """ _lengths = _maybe_compute_lengths(self._lengths, self._offsets) self._lengths = _lengths return _lengths def lengths_or_none(self) -> Optional[torch.Tensor]: + """ + Get JaggedTensor lengths. If not computed, return None. + + Returns: + Optional[torch.Tensor]: the lengths tensor. + """ return self._lengths def offsets(self) -> torch.Tensor: + """ + Get JaggedTensor offsets. If not computed, compute it from lengths. + + Returns: + torch.Tensor: the offsets tensor. + """ _offsets = _maybe_compute_offsets(self._lengths, self._offsets) self._offsets = _offsets return _offsets def offsets_or_none(self) -> Optional[torch.Tensor]: + """ + Get JaggedTensor offsets. If not computed, return None. + + Returns: + Optional[torch.Tensor]: the offsets tensor. + """ return self._offsets def values(self) -> torch.Tensor: + """ + Get JaggedTensor values. + + Returns: + torch.Tensor: the values tensor. + """ return self._values def weights(self) -> torch.Tensor: + """ + Get JaggedTensor weights. If None, throw an error. + + Returns: + torch.Tensor: the weights tensor. + """ return _get_weights_or_throw(self._weights) def weights_or_none(self) -> Optional[torch.Tensor]: + """ + Get JaggedTensor weights. If None, return None. + + Returns: + Optional[torch.Tensor]: the weights tensor. + """ return self._weights def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor": + """ + Move the JaggedTensor to the specified device. + + Args: + device (torch.device): the device to move to. + non_blocking (bool): whether to perform the copy asynchronously. + + Returns: + JaggedTensor: the moved JaggedTensor. + """ weights = self._weights lengths = self._lengths offsets = self._offsets return JaggedTensor( values=self._values.to(device, non_blocking=non_blocking), - weights=weights.to(device, non_blocking=non_blocking) - if weights is not None - else None, - lengths=lengths.to(device, non_blocking=non_blocking) - if lengths is not None - else None, - offsets=offsets.to(device, non_blocking=non_blocking) - if offsets is not None - else None, + weights=( + weights.to(device, non_blocking=non_blocking) + if weights is not None + else None + ), + lengths=( + lengths.to(device, non_blocking=non_blocking) + if lengths is not None + else None + ), + offsets=( + offsets.to(device, non_blocking=non_blocking) + if offsets is not None + else None + ), ) @torch.jit.unused def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self._values.record_stream(stream) weights = self._weights lengths = self._lengths offsets = self._offsets if weights is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. weights.record_stream(stream) if lengths is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. lengths.record_stream(stream) if offsets is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. 
offsets.record_stream(stream) def __str__(self) -> str: @@ -452,10 +1031,49 @@ def __str__(self) -> str: ) +def _jt_flatten( + t: JaggedTensor, +) -> Tuple[List[Optional[torch.Tensor]], None]: + return [getattr(t, a) for a in JaggedTensor._fields], None + + +def _jt_flatten_with_keys( + t: JaggedTensor, +) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], None]: + values, context = _jt_flatten(t) + # pyre can't tell that GetAttrKey implements the KeyEntry protocol + return [ # pyre-ignore[7] + (GetAttrKey(k), v) for k, v in zip(JaggedTensor._fields, values) + ], context + + +def _jt_unflatten(values: List[Optional[torch.Tensor]], context: None) -> JaggedTensor: + return JaggedTensor(*values) + + +def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Tensor]]: + return [getattr(t, a) for a in JaggedTensor._fields] + + +register_pytree_node( + JaggedTensor, + _jt_flatten, + _jt_unflatten, + flatten_with_keys_fn=_jt_flatten_with_keys, + serialized_type_name="torchrec.sparse.jagged_tensor.JaggedTensor", +) +register_pytree_flatten_spec(JaggedTensor, _jt_flatten_spec) + + def _assert_tensor_has_no_elements_or_has_integers( - tensor: torch.Tensor, tensor_name: str + tensor: Optional[torch.Tensor], tensor_name: str ) -> None: - assert tensor.numel() == 0 or tensor.dtype in [ + if is_torchdynamo_compiling() or tensor is None: + # Skipping the check tensor.numel() == 0 to not guard on pt2 symbolic shapes. + # TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable. + return + + assert pt2_guard_size_oblivious(tensor.numel() == 0) or tensor.dtype in [ torch.long, torch.int, torch.short, @@ -478,10 +1096,13 @@ def _maybe_compute_stride_kjt( stride: Optional[int], lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], + stride_per_key_per_rank: Optional[List[List[int]]], ) -> int: if stride is None: if len(keys) == 0: stride = 0 + elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0: + stride = max([sum(s) for s in stride_per_key_per_rank]) elif offsets is not None and offsets.numel() > 0: stride = (offsets.numel() - 1) // len(keys) elif lengths is not None: @@ -491,59 +1112,167 @@ def _maybe_compute_stride_kjt( return stride -# Specialization of _maybe_compute_stride_kjt that is scripted, so it will produce -# correct results in case of usage with jit.tracing. -# This module is returning torch.Tensor instead of int, because ji.trace doesn't -# support int type at the current moment. -@torch.jit.script -def _maybe_compute_stride_kjt_scripted( - keys: List[str], - stride: Optional[int], - lengths: Optional[torch.Tensor], - offsets: Optional[torch.Tensor], -) -> torch.Tensor: - return torch.tensor([_maybe_compute_stride_kjt(keys, stride, lengths, offsets)]) +def _use_segment_sum_csr(stride_per_key: List[int]) -> bool: + """ + `segment_sum_csr` performs poorly for small number of segments and many elements + in each segment to sum. This function uses an empirically calculated equation, + derived from fitting a quadratic regression to an interval of elements and elements + per segment that match performance between the kernel and PyTorch solution, to + determine the threshold of when to use `segment_sum_csr`. 
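    A worked instance of the heuristic may help; the segment counts below are made up
    purely for illustration and are not measured data::

        stride_per_key = [100_000] * 4
        elements_per_segment = sum(stride_per_key) / len(stride_per_key)  # 100_000.0
        segment_threshold = int(
            1.39771
            + 0.0000312222 * elements_per_segment
            + 1.63949e-10 * elements_per_segment**2
        )  # int(6.159...) == 6
        # len(stride_per_key) == 4 < 6, so the plain torch.sum path is chosen here;
        # with stride_per_key = [512] * 4 the threshold evaluates to 1 and the
        # `segment_sum_csr` kernel is used instead.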
+ """ + if is_torchdynamo_compiling(): + # dynamo symbolic shapes can not pass this condition without concrete stride values + return False + + elements_per_segment = sum(stride_per_key) / len(stride_per_key) + segment_threshold = int( + 1.39771 + + 0.0000312222 * elements_per_segment + + 1.63949e-10 * elements_per_segment**2 + ) + return len(stride_per_key) >= segment_threshold + + +def _length_per_key_from_stride_per_key( + lengths: torch.Tensor, stride_per_key: List[int] +) -> List[int]: + ret: List[int] = [] + if _use_segment_sum_csr(stride_per_key): + stride_per_key_offsets = _to_offsets( + _pin_and_move( + torch.tensor(stride_per_key, dtype=torch.int32), lengths.device + ) + ) + ret = torch.jit.annotate( + List[int], + torch.ops.fbgemm.segment_sum_csr( + 1, stride_per_key_offsets, lengths + ).tolist(), + ) + else: + tensor_list: List[torch.Tensor] = [ + torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key) + ] + if len(tensor_list) == 0: + return [] + + ret = torch.jit.annotate(List[int], torch.cat(tensor_list).tolist()) + + pt2_checks_all_is_size(ret) + return ret def _maybe_compute_length_per_key( keys: List[str], stride: int, + stride_per_key: List[int], + variable_stride_per_key: bool, length_per_key: Optional[List[int]], lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], + values: Optional[torch.Tensor], ) -> List[int]: if length_per_key is None: - if len(keys) and offsets is not None and len(offsets) > 0: - _length: List[int] = torch.sum( - torch.diff(offsets).view(-1, stride), dim=1 - ).tolist() + if ( + len(keys) + and values is not None + and values.is_meta + and not is_non_strict_exporting() + ): + # create dummy lengths per key when on meta device + total_length = values.numel() + _length = [total_length // len(keys)] * len(keys) + _length[0] += total_length % len(keys) elif len(keys) and lengths is not None: _length: List[int] = ( - torch.sum(lengths.view(-1, stride), dim=1).tolist() - if lengths.numel() != 0 - else [0] * len(keys) + _length_per_key_from_stride_per_key(lengths, stride_per_key) + if variable_stride_per_key + else ( + torch.sum( + pt2_check_size_nonzero(lengths.view(len(keys), stride)), dim=1 + ).tolist() + if pt2_guard_size_oblivious(lengths.numel() != 0) + else [0] * len(keys) + ) + ) + elif len(keys) and offsets is not None and len(offsets) > 0: + _length: List[int] = ( + _length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key) + if variable_stride_per_key + else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist() ) else: _length: List[int] = [] length_per_key = _length + pt2_checks_all_is_size(length_per_key) + return length_per_key def _maybe_compute_offset_per_key( keys: List[str], stride: int, + stride_per_key: List[int], + variable_stride_per_key: bool, length_per_key: Optional[List[int]], offset_per_key: Optional[List[int]], lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], + values: Optional[torch.Tensor], ) -> Tuple[List[int], List[int]]: if length_per_key is None: _length_per_key: List[int] = _maybe_compute_length_per_key( - keys, stride, length_per_key, lengths, offsets + keys=keys, + stride=stride, + stride_per_key=stride_per_key, + variable_stride_per_key=variable_stride_per_key, + length_per_key=length_per_key, + lengths=lengths, + offsets=offsets, + values=values, ) - return _length_per_key, _cumsum(_length_per_key) + + if not torch.jit.is_scripting() and is_non_strict_exporting(): + # only torch.export non-strict case + return ( + _length_per_key, + ( + 
torch.ops.fbgemm.asynchronous_complete_cumsum( + torch._refs.tensor( + _length_per_key, + dtype=torch.int32, + device=torch.device("cpu"), + pin_memory=False, + requires_grad=False, + ) + ).tolist() + if len(_length_per_key) > 0 + else [] + ), + ) + else: + return _length_per_key, _cumsum(_length_per_key) elif offset_per_key is None: - return length_per_key, _cumsum(length_per_key) + if not torch.jit.is_scripting() and is_non_strict_exporting(): + # only torch.export non-strict case + return ( + length_per_key, + ( + torch.ops.fbgemm.asynchronous_complete_cumsum( + torch._refs.tensor( + length_per_key, + dtype=torch.int32, + device=torch.device("cpu"), + pin_memory=False, + requires_grad=False, + ) + ).tolist() + if len(length_per_key) > 0 + else [] + ), + ) + else: + return length_per_key, _cumsum(length_per_key) else: return length_per_key, offset_per_key @@ -576,6 +1305,8 @@ def _jagged_tensor_string( class ComputeKJTToJTDict(torch.nn.Module): """Converts a KeyedJaggedTensor to a dict of JaggedTensors. + Args: + Example:: # 0 1 2 <-- dim_1 # "Feature0" [V0,V1] None [V2] @@ -596,6 +1327,7 @@ def forward( ) -> Dict[str, JaggedTensor]: """ Converts a KeyedJaggedTensor into a dict of JaggedTensors. + Args: keyed_jagged_tensor (KeyedJaggedTensor): tensor to convert Returns: @@ -603,59 +1335,119 @@ def forward( """ return _maybe_compute_kjt_to_jt_dict( stride=keyed_jagged_tensor.stride(), + stride_per_key=keyed_jagged_tensor.stride_per_key(), keys=keyed_jagged_tensor.keys(), length_per_key=keyed_jagged_tensor.length_per_key(), values=keyed_jagged_tensor.values(), lengths=keyed_jagged_tensor.lengths(), + variable_stride_per_key=keyed_jagged_tensor.variable_stride_per_key(), weights=keyed_jagged_tensor.weights_or_none(), jt_dict=keyed_jagged_tensor._jt_dict, ) +class ComputeJTDictToKJT(torch.nn.Module): + """Converts a dict of JaggedTensors to KeyedJaggedTensor. 
+ Args: + + Example: + passing in jt_dict + { + "Feature0": JaggedTensor([[V0,V1],None,V2]), + "Feature1": JaggedTensor([V3,V4,[V5,V6,V7]]), + } + Returns:: + kjt with content: + # 0 1 2 <-- dim_1 + # "Feature0" [V0,V1] None [V2] + # "Feature1" [V3] [V4] [V5,V6,V7] + # ^ + # dim_0 + + """ + + def forward(self, jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor": + """ + Args: + jt_dict: a dict of JaggedTensor + Returns: + KeyedJaggedTensor + """ + return KeyedJaggedTensor.from_jt_dict(jt_dict) + + +@torch.fx.wrap def _maybe_compute_kjt_to_jt_dict( stride: int, + stride_per_key: List[int], keys: List[str], length_per_key: List[int], values: torch.Tensor, lengths: torch.Tensor, + variable_stride_per_key: bool, weights: Optional[torch.Tensor], jt_dict: Optional[Dict[str, JaggedTensor]], ) -> Dict[str, JaggedTensor]: - if jt_dict is None: - _jt_dict: Dict[str, JaggedTensor] = {} - values_list = torch.split(values, length_per_key) - lengths_tuple = torch.unbind( - lengths.view(-1, stride) if lengths.numel() != 0 else lengths, dim=0 + if not length_per_key: + return {} + + if jt_dict is not None: + return jt_dict + + _jt_dict: Dict[str, JaggedTensor] = {} + if not torch.jit.is_scripting() and is_pt2_compiling(): + cat_size = 0 + total_size = values.size(0) + for i in length_per_key: + cat_size += i + torch._check(cat_size <= total_size) + torch._check(cat_size == total_size) + torch._check_is_size(stride) + values_list = torch.split(values, length_per_key) + if variable_stride_per_key: + split_lengths = torch.split(lengths, stride_per_key) + split_offsets = [ + torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + for lengths in split_lengths + ] + elif pt2_guard_size_oblivious(lengths.numel() > 0): + strided_lengths = lengths.view(len(keys), stride) + if not torch.jit.is_scripting() and is_torchdynamo_compiling(): + torch._check(strided_lengths.size(0) > 0) + torch._check(strided_lengths.size(1) > 0) + split_lengths = torch.unbind( + strided_lengths, + dim=0, ) - offsets_tuple = torch.unbind( - _batched_lengths_to_offsets(lengths.view(-1, stride)) - if lengths.numel() != 0 - else lengths, + split_offsets = torch.unbind( + _batched_lengths_to_offsets(strided_lengths), dim=0, ) - - if weights is not None: - weights_list = torch.split(weights, length_per_key) - for idx, key in enumerate(keys): - length = lengths_tuple[idx] - offset = offsets_tuple[idx] - _jt_dict[key] = JaggedTensor( - lengths=length, - offsets=offset, - values=values_list[idx], - weights=weights_list[idx], - ) - else: - for idx, key in enumerate(keys): - length = lengths_tuple[idx] - offset = offsets_tuple[idx] - _jt_dict[key] = JaggedTensor( - lengths=length, - offsets=offset, - values=values_list[idx], - ) - jt_dict = _jt_dict - return jt_dict + else: + split_lengths = torch.unbind(lengths, dim=0) + split_offsets = torch.unbind(lengths, dim=0) + + if weights is not None: + weights_list = torch.split(weights, length_per_key) + for idx, key in enumerate(keys): + length = split_lengths[idx] + offset = split_offsets[idx] + _jt_dict[key] = JaggedTensor( + lengths=length, + offsets=offset, + values=values_list[idx], + weights=weights_list[idx], + ) + else: + for idx, key in enumerate(keys): + length = split_lengths[idx] + offset = split_offsets[idx] + _jt_dict[key] = JaggedTensor( + lengths=length, + offsets=offset, + values=values_list[idx], + ) + return _jt_dict @torch.fx.wrap @@ -672,6 +1464,236 @@ def _merge_weights_or_none( return torch.cat([a_weights, b_weights], dim=0) +@torch.fx.wrap +def _strides_from_kjt( + kjt: 
"KeyedJaggedTensor", +) -> Tuple[Optional[int], Optional[List[List[int]]]]: + stride, stride_per_key_per_rank = ( + (None, kjt.stride_per_key_per_rank()) + if kjt.variable_stride_per_key() + else (kjt.stride(), None) + ) + + return stride, stride_per_key_per_rank + + +@torch.fx.wrap +def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor": + # empty like function fx wrapped, also avoids device hardcoding + stride, stride_per_key_per_rank = ( + (None, kjt.stride_per_key_per_rank()) + if kjt.variable_stride_per_key() + else (kjt.stride(), None) + ) + + return KeyedJaggedTensor( + keys=[], + values=torch.empty(0, device=kjt.device(), dtype=kjt.values().dtype), + weights=( + None + if kjt.weights_or_none() is None + else torch.empty(0, device=kjt.device(), dtype=kjt.weights().dtype) + ), + lengths=torch.empty(0, device=kjt.device(), dtype=kjt.lengths().dtype), + stride=stride, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + + +def _sum_by_splits(input_list: List[int], splits: List[int]) -> List[int]: + return [ + sum(input_list[sum(splits[:i]) : sum(splits[:i]) + n]) + for i, n in enumerate(splits) + ] + + +@torch.fx.wrap +def jt_is_equal(jt_1: "JaggedTensor", jt_2: "JaggedTensor") -> bool: + """This function checks if two JaggedTensors are equal by comparing their internal representations. + The comparison is done by comparing the values of the internal representations themselves. + For optional fields, None values are treated as equal. + + Args: + jt_1 (JaggedTensor): the first JaggedTensor + jt_2 (JaggedTensor): the second JaggedTensor + + Returns: + bool: True if both JaggedTensors have the same values + """ + + if not isinstance(jt_1, JaggedTensor) or not isinstance(jt_2, JaggedTensor): + return False + + if not _check_attributes(jt_1.values(), jt_2.values(), torch.allclose): + return False + + _force_length_offset_computation(jt_1) + _force_length_offset_computation(jt_2) + + attributes_to_check = [ + (jt_1.weights_or_none(), jt_2.weights_or_none()), + (jt_1.lengths_or_none(), jt_2.lengths_or_none()), + (jt_1.offsets_or_none(), jt_2.offsets_or_none()), + ] + + for attr_1, attr_2 in attributes_to_check: + if not _check_attributes( + attr_1, + attr_2, + torch.allclose if isinstance(attr_1, torch.Tensor) else operator.eq, + ): + return False + + return True + + +@torch.fx.wrap +def kjt_is_equal(kjt_1: "KeyedJaggedTensor", kjt_2: "KeyedJaggedTensor") -> bool: + """This function checks if two KeyedJaggedTensors are equal by comparing their internal representations. + The comparison is done by comparing the values of the internal representations themselves. + For optional fields, None values are treated as equal. + We compare the keys by ensuring that they have the same length and that the corresponding keys are the same order and same values. 
+ + Args: + kjt_1 (KeyedJaggedTensor): the first KeyedJaggedTensor + kjt_2 (KeyedJaggedTensor): the second KeyedJaggedTensor + + Returns: + bool: True if both KeyedJaggedTensors have the same values + """ + if not isinstance(kjt_1, KeyedJaggedTensor) or not isinstance( + kjt_2, KeyedJaggedTensor + ): + return False + + # check for missing/extra keys + if len(kjt_1.keys()) != len(kjt_2.keys()): + return False + + # check if all keys are equal and in same order + for a, b in zip(kjt_1.keys(), kjt_2.keys()): + if a != b: + return False + + if not _check_attributes(kjt_1.values(), kjt_2.values(), torch.allclose): + return False + + _force_length_offset_computation(kjt_1) + _force_length_offset_computation(kjt_2) + # sync length and offset per key as well + kjt_1.sync() + kjt_2.sync() + + attributes_to_check = [ + (kjt_1.lengths_or_none(), kjt_2.lengths_or_none()), + (kjt_1.weights_or_none(), kjt_2.weights_or_none()), + (kjt_1.offsets_or_none(), kjt_2.offsets_or_none()), + (kjt_1.length_per_key_or_none(), kjt_2.length_per_key_or_none()), + (kjt_1.offset_per_key_or_none(), kjt_2.offset_per_key_or_none()), + (kjt_1.stride(), kjt_2.stride()), + ] + + for attr_1, attr_2 in attributes_to_check: + if not _check_attributes( + attr_1, + attr_2, + torch.allclose if isinstance(attr_1, torch.Tensor) else operator.eq, + ): + return False + + return True + + +def _force_length_offset_computation( + kjt: Union["KeyedJaggedTensor", "JaggedTensor"] +) -> None: + """Helper function to force length/offset computation for KJT or JT + Mainly used for testing equality, as equal KJT's/JT's can be formed from just using lengths or offsets. + One can be derived from the other so to ensure properly equality checking we force the computation of + the other attribute if it can be done. + """ + offsets = kjt.offsets_or_none() + lengths = kjt.lengths_or_none() + if offsets is not None and lengths is None: + kjt.lengths() + elif lengths is not None and offsets is None: + kjt.offsets() + + +def _check_attributes( + attr_1: Union[torch.Tensor, List[int], List[str], int, None], + attr_2: Union[torch.Tensor, List[int], List[str], int, None], + comparison_func: Callable[[Any, Any], bool], # pyre-ignore[2] +) -> bool: + """Helper function to check if two attributes are equal. + + Args: + attr_1: The first attribute. + attr_2: The second attribute. + comparison_func (function): Function to compare the attributes. + + Returns: + bool: False if the attributes are not equal or one is None while the other isn't, otherwise True. 
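    A minimal sketch of the comparison semantics (the inputs below are illustrative
    and assume `torch` and `operator` are in scope, as they are in this module)::

        _check_attributes(torch.Tensor([1.0, 2.0]), torch.Tensor([1.0, 2.0]), torch.allclose)  # True
        _check_attributes([0, 2, 4], [0, 2, 4], operator.eq)    # True
        _check_attributes(None, None, operator.eq)              # True, both missing
        _check_attributes(None, [0, 2, 4], operator.eq)         # False, only one missing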
+ """ + if attr_1 is not None and attr_2 is not None: + # allclose throws error for different tensor sizes, we check manually for this + if ( + comparison_func == torch.allclose + and attr_1.size() != attr_2.size() # pyre-ignore[16] + ): + return False + if not comparison_func(attr_1, attr_2): + return False + elif attr_1 is not None or attr_2 is not None: + return False + + return True + + +def _maybe_compute_lengths_offset_per_key( + lengths_offset_per_key: Optional[List[int]], + stride_per_key: Optional[List[int]], + stride: Optional[int], + keys: List[str], +) -> Optional[List[int]]: + if lengths_offset_per_key is not None: + return lengths_offset_per_key + elif stride_per_key is not None: + return _cumsum(stride_per_key) + elif stride is not None: + return _cumsum([stride] * len(keys)) + else: + return None + + +def _maybe_compute_stride_per_key( + stride_per_key: Optional[List[int]], + stride_per_key_per_rank: Optional[List[List[int]]], + stride: Optional[int], + keys: List[str], +) -> Optional[List[int]]: + if stride_per_key is not None: + return stride_per_key + elif stride_per_key_per_rank is not None: + return [sum(s) for s in stride_per_key_per_rank] + elif stride is not None: + return [stride] * len(keys) + else: + return None + + +def _maybe_compute_variable_stride_per_key( + variable_stride_per_key: Optional[bool], + stride_per_key_per_rank: Optional[List[List[int]]], +) -> bool: + if variable_stride_per_key is not None: + return variable_stride_per_key + elif stride_per_key_per_rank is not None: + return True + else: + return False + + class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta): """Represents an (optionally weighted) keyed jagged tensor. @@ -690,11 +1712,19 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta): offsets (Optional[torch.Tensor]): jagged slices, represented as cumulative offsets. stride (Optional[int]): number of examples per batch. + stride_per_key_per_rank (Optional[List[List[int]]]): batch size + (number of examples) per key per rank, with the outer list representing the + keys and the inner list representing the values. + Each value in the inner list represents the number of examples in the batch + from the rank of its index in a distributed context. length_per_key (Optional[List[int]]): start length for each key. offset_per_key (Optional[List[int]]): start offset for each key and final offset. index_per_key (Optional[Dict[str, int]]): index for each key. - jt_dict (Optional[Dict[str, JaggedTensor]]): + jt_dict (Optional[Dict[str, JaggedTensor]]): dictionary of keys to JaggedTensors. + Allow ability to make to_dict() lazy/cacheable. + inverse_indices (Optional[Tuple[List[str], torch.Tensor]]): inverse indices to + expand deduplicated embedding output for variable stride per key. 
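    As an illustration of the variable-batch case, a hypothetical configuration with
    two keys and two ranks (the numbers are made up for this sketch) could look like::

        stride_per_key_per_rank = [[2, 3], [4, 1]]
        # key 0 receives 2 examples from rank 0 and 3 from rank 1;
        # key 1 receives 4 examples from rank 0 and 1 from rank 1.
        # stride_per_key() is then [5, 5] and variable_stride_per_key() is True.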
Example:: @@ -719,6 +1749,15 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta): offset_per_key: List[int] = [0, 3, 8] # start offset for each key and final offset """ + # This is the subset of fields on KJT which are required (all other fields + # can be derived from these fields, and are only cached) + _fields = [ + "_values", + "_weights", + "_lengths", + "_offsets", + ] + def __init__( self, keys: List[str], @@ -727,35 +1766,59 @@ def __init__( lengths: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, stride: Optional[int] = None, + stride_per_key_per_rank: Optional[List[List[int]]] = None, # Below exposed to ensure torch.script-able + stride_per_key: Optional[List[int]] = None, length_per_key: Optional[List[int]] = None, + lengths_offset_per_key: Optional[List[int]] = None, offset_per_key: Optional[List[int]] = None, index_per_key: Optional[Dict[str, int]] = None, jt_dict: Optional[Dict[str, JaggedTensor]] = None, + inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, ) -> None: + """ + This is the constructor for KeyedJaggedTensor is jit.scriptable and PT2 compatible. + It is important only to assign attributes here or do input checks to support various + internal inference optimizations. By convention the attirbute is named same as input arg, just + with leading underscore + """ self._keys: List[str] = keys self._values: torch.Tensor = values self._weights: Optional[torch.Tensor] = weights - if offsets is not None: - _assert_tensor_has_no_elements_or_has_integers(offsets, "offsets") - if lengths is not None: - _assert_tensor_has_no_elements_or_has_integers(lengths, "lengths") self._lengths: Optional[torch.Tensor] = lengths self._offsets: Optional[torch.Tensor] = offsets - if torch.jit.is_tracing(): - stride = _maybe_compute_stride_kjt_scripted(keys, stride, lengths, offsets)[ - 0 - ] - else: - stride = _maybe_compute_stride_kjt(keys, stride, lengths, offsets) - - self._stride: int = stride - - # lazy fields + self._stride: Optional[int] = stride + self._stride_per_key_per_rank: Optional[List[List[int]]] = ( + stride_per_key_per_rank + ) + self._stride_per_key: Optional[List[int]] = stride_per_key self._length_per_key: Optional[List[int]] = length_per_key self._offset_per_key: Optional[List[int]] = offset_per_key + self._lengths_offset_per_key: Optional[List[int]] = lengths_offset_per_key self._index_per_key: Optional[Dict[str, int]] = index_per_key self._jt_dict: Optional[Dict[str, JaggedTensor]] = jt_dict + self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = ( + inverse_indices + ) + + # legacy attribute, for backward compatabilibity + self._variable_stride_per_key: Optional[bool] = None + + # validation logic + if not torch.jit.is_scripting(): + _assert_tensor_has_no_elements_or_has_integers(offsets, "offsets") + _assert_tensor_has_no_elements_or_has_integers(lengths, "lengths") + self._init_pt2_checks() + + def _init_pt2_checks(self) -> None: + if torch.jit.is_scripting() or not is_torchdynamo_compiling(): + return + if self._stride_per_key is not None: + pt2_checks_all_is_size(self._stride_per_key) + if self._stride_per_key_per_rank is not None: + # pyre-ignore [16] + for s in self._stride_per_key_per_rank: + pt2_checks_all_is_size(s) @staticmethod def from_offsets_sync( @@ -764,13 +1827,36 @@ def from_offsets_sync( offsets: torch.Tensor, weights: Optional[torch.Tensor] = None, stride: Optional[int] = None, + stride_per_key_per_rank: Optional[List[List[int]]] = None, + inverse_indices: Optional[Tuple[List[str], 
torch.Tensor]] = None, ) -> "KeyedJaggedTensor": + """ + Constructs a KeyedJaggedTensor from a list of keys, values, and offsets. + + Args: + keys (List[str]): list of keys. + values (torch.Tensor): values tensor in dense representation. + offsets (torch.Tensor): jagged slices, represented as cumulative offsets. + weights (Optional[torch.Tensor]): if the values have weights. Tensor with the + same shape as values. + stride (Optional[int]): number of examples per batch. + stride_per_key_per_rank (Optional[List[List[int]]]): batch size + (number of examples) per key per rank, with the outer list representing the + keys and the inner list representing the values. + inverse_indices (Optional[Tuple[List[str], torch.Tensor]]): inverse indices to + expand deduplicated embedding output for variable stride per key. + + Returns: + KeyedJaggedTensor: constructed KeyedJaggedTensor. + """ kjt = KeyedJaggedTensor( keys=keys, values=values, weights=weights, offsets=offsets, stride=stride, + stride_per_key_per_rank=stride_per_key_per_rank, + inverse_indices=inverse_indices, ) return kjt.sync() @@ -781,13 +1867,37 @@ def from_lengths_sync( lengths: torch.Tensor, weights: Optional[torch.Tensor] = None, stride: Optional[int] = None, + stride_per_key_per_rank: Optional[List[List[int]]] = None, + inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, ) -> "KeyedJaggedTensor": + """ + Constructs a KeyedJaggedTensor from a list of keys, lengths, and offsets. + Same as `from_offsets_sync` except lengths are used instead of offsets. + + Args: + keys (List[str]): list of keys. + values (torch.Tensor): values tensor in dense representation. + lengths (torch.Tensor): jagged slices, represented as lengths. + weights (Optional[torch.Tensor]): if the values have weights. Tensor with the + same shape as values. + stride (Optional[int]): number of examples per batch. + stride_per_key_per_rank (Optional[List[List[int]]]): batch size + (number of examples) per key per rank, with the outer list representing the + keys and the inner list representing the values. + inverse_indices (Optional[Tuple[List[str], torch.Tensor]]): inverse indices to + expand deduplicated embedding output for variable stride per key. + + Returns: + KeyedJaggedTensor: constructed KeyedJaggedTensor. 
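        Example (a minimal sketch showing how stride and per-key lengths are derived;
        the keys, values, and lengths are made up for illustration and assume `torch`
        is imported)::

            kjt = KeyedJaggedTensor.from_lengths_sync(
                keys=["f1", "f2"],
                values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0]),
                lengths=torch.IntTensor([2, 0, 1, 1, 1, 0]),
            )

            # kjt.stride() == 3
            # kjt.length_per_key() == [3, 2]
            # kjt.offset_per_key() == [0, 3, 5]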
+ """ kjt = KeyedJaggedTensor( keys=keys, values=values, weights=weights, lengths=lengths, stride=stride, + stride_per_key_per_rank=stride_per_key_per_rank, + inverse_indices=inverse_indices, ) return kjt.sync() @@ -795,114 +1905,326 @@ def from_lengths_sync( def concat( kjt_list: List["KeyedJaggedTensor"], ) -> "KeyedJaggedTensor": - if len(kjt_list) == 0: - raise ValueError("Can't concat empty KJT list") - stride: int = kjt_list[0].stride() - is_weighted: bool = kjt_list[0].weights_or_none() is not None - has_length_per_key: bool = True - - length_per_key: List[int] = [] - keys: List[str] = [] - value_list: List[torch.Tensor] = [] - weight_list: List[torch.Tensor] = [] - length_list: List[torch.Tensor] = [] - - for kjt in kjt_list: - if kjt.stride() != stride: - raise ValueError( - f"Can only merge KJTs of the same stride ({stride} != kjt.stride())" - ) - curr_is_weighted: bool = kjt.weights_or_none() is not None - if is_weighted != curr_is_weighted: - raise ValueError("Can't merge weighted KJT with unweighted KJT") - _length_per_key: Optional[List[int]] = None - if kjt._length_per_key is None: - has_length_per_key = False - else: - _length_per_key = kjt._length_per_key - if has_length_per_key and _length_per_key is not None: - length_per_key += _length_per_key - keys += kjt.keys() - value_list.append(kjt.values()) - if is_weighted: - weight_list.append(kjt.weights()) - length_list.append(kjt.lengths()) + """ + Concatenates a list of KeyedJaggedTensors into a single KeyedJaggedTensor. - return KeyedJaggedTensor( - keys=keys, - values=torch.cat(value_list, dim=0), - weights=torch.cat(weight_list, dim=0) if is_weighted else None, - lengths=torch.cat(length_list, dim=0), - stride=stride, - length_per_key=length_per_key if has_length_per_key else None, - ) + Args: + kjt_list (List[KeyedJaggedTensor]): list of KeyedJaggedTensors to be concatenated. + + Returns: + KeyedJaggedTensor: concatenated KeyedJaggedTensor. + """ + return _kjt_concat(kjt_list) @staticmethod def empty( - is_weighted: bool = False, device: Optional[torch.device] = None + is_weighted: bool = False, + device: Optional[torch.device] = None, + values_dtype: Optional[torch.dtype] = None, + weights_dtype: Optional[torch.dtype] = None, + lengths_dtype: torch.dtype = torch.int32, ) -> "KeyedJaggedTensor": - weights = None - if is_weighted is True: - weights = torch.tensor([], device=device) if device else torch.tensor([]) + """ + Constructs an empty KeyedJaggedTensor. + + Args: + is_weighted (bool): whether the KeyedJaggedTensor is weighted or not. + device (Optional[torch.device]): device on which the KeyedJaggedTensor will be placed. + values_dtype (Optional[torch.dtype]): dtype of the values tensor. + weights_dtype (Optional[torch.dtype]): dtype of the weights tensor. + lengths_dtype (torch.dtype): dtype of the lengths tensor. + Returns: + KeyedJaggedTensor: empty KeyedJaggedTensor. 
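        Example (illustrative usage; the chosen dtype is arbitrary)::

            kjt = KeyedJaggedTensor.empty(is_weighted=True, lengths_dtype=torch.int64)

            # kjt.keys() == []
            # kjt.values().numel() == 0 and kjt.weights().numel() == 0
            # kjt.lengths().dtype == torch.int64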
+ """ + weights = ( + torch.empty(0, dtype=weights_dtype, device=device) if is_weighted else None + ) return KeyedJaggedTensor( - keys=[], - values=torch.tensor([], device=device) if device else torch.tensor([]), + keys=torch.jit.annotate(List[str], []), + values=torch.empty(0, dtype=values_dtype, device=device), weights=weights, - lengths=torch.tensor([], device=device) if device else torch.tensor([]), + lengths=torch.empty(0, dtype=lengths_dtype, device=device), stride=0, ) @staticmethod def empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor": - return KeyedJaggedTensor( - keys=[], - values=torch.tensor([], device=kjt.device(), dtype=kjt.values().dtype), - weights=None - if kjt.weights_or_none() is None - else torch.tensor([], device=kjt.device(), dtype=kjt.weights().dtype), - lengths=torch.tensor([], device=kjt.device(), dtype=kjt.lengths().dtype), - stride=kjt.stride(), + """ + Constructs an empty KeyedJaggedTensor with the same device and dtypes as the input KeyedJaggedTensor. + + Args: + kjt (KeyedJaggedTensor): input KeyedJaggedTensor. + + Returns: + KeyedJaggedTensor: empty KeyedJaggedTensor. + """ + return _kjt_empty_like(kjt) + + @staticmethod + def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor": + """ + Constructs a KeyedJaggedTensor from a dictionary of JaggedTensors. + Automatically calls `kjt.sync()` on newly created KJT. + + NOTE: + This function will ONLY work if the JaggedTensors all + have the same "implicit" batch_size dimension. + + Basically, we can visualize JaggedTensors as 2-D tensors + of the format of [batch_size x variable_feature_dim]. + In the case, we have some batch without a feature value, + the input JaggedTensor could just not include any values. + + But KeyedJaggedTensor (by default) typically pad "None" + so that all the JaggedTensors stored in the KeyedJaggedTensor + have the same batch_size dimension. That is, in the case, + the JaggedTensor input didn't automatically pad + for the empty batches, this function would error / not work. + + Consider the visualization of the following KeyedJaggedTensor: + # 0 1 2 <-- dim_1 + # "Feature0" [V0,V1] None [V2] + # "Feature1" [V3] [V4] [V5,V6,V7] + # ^ + # dim_0 + + Now if the input jt_dict = { + # "Feature0" [V0,V1] [V2] + # "Feature1" [V3] [V4] [V5,V6,V7] + } and the "None" is left out from each JaggedTensor, + then this function would fail as we would not correctly + be able to pad "None" as it does not technically know + the correct batch / place to pad within the JaggedTensor. + + Essentially, the lengths Tensor inferred by this function + would be [2, 1, 1, 1, 3] indicating variable batch_size + dim_1 violates the existing assumption / precondition + that KeyedJaggedTensor's should have fixed batch_size dimension. + + Args: + jt_dict (Dict[str, JaggedTensor]): dictionary of JaggedTensors. + + Returns: + KeyedJaggedTensor: constructed KeyedJaggedTensor. 
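        Example (a minimal sketch mirroring the visualization above; the values are
        made up for illustration and assume `torch` is imported)::

            jt_dict = {
                "Feature0": JaggedTensor(
                    values=torch.Tensor([1.0, 2.0, 3.0]),
                    lengths=torch.IntTensor([2, 0, 1]),
                ),
                "Feature1": JaggedTensor(
                    values=torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0]),
                    lengths=torch.IntTensor([1, 1, 3]),
                ),
            }
            kjt = KeyedJaggedTensor.from_jt_dict(jt_dict)

            # kjt.keys() == ["Feature0", "Feature1"]
            # kjt.stride() == 3  (both JaggedTensors carry 3 lengths, including the
            #                     explicit 0 for the empty batch)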
+ """ + kjt_keys = list(jt_dict.keys()) + kjt_vals_list: List[torch.Tensor] = [] + kjt_lens_list: List[torch.Tensor] = [] + kjt_weights_list: List[torch.Tensor] = [] + stride_per_key: List[int] = [] + for jt in jt_dict.values(): + stride_per_key.append(len(jt.lengths())) + kjt_vals_list.append(jt.values()) + kjt_lens_list.append(jt.lengths()) + weight = jt.weights_or_none() + if weight is not None: + kjt_weights_list.append(weight) + kjt_vals = torch.concat(kjt_vals_list) + kjt_lens = torch.concat(kjt_lens_list) + kjt_weights = ( + torch.concat(kjt_weights_list) if len(kjt_weights_list) > 0 else None + ) + kjt_stride, kjt_stride_per_key_per_rank = ( + (stride_per_key[0], None) + if all(s == stride_per_key[0] for s in stride_per_key) + else (None, [[stride] for stride in stride_per_key]) ) + kjt = KeyedJaggedTensor( + keys=kjt_keys, + values=kjt_vals, + weights=kjt_weights, + lengths=kjt_lens, + stride=kjt_stride, + stride_per_key_per_rank=kjt_stride_per_key_per_rank, + ).sync() + return kjt def sync(self) -> "KeyedJaggedTensor": - self.length_per_key() - self.offset_per_key() + """ + Synchronizes the KeyedJaggedTensor by computing the offset_per_key and length_per_key. + + Returns: + KeyedJaggedTensor: synced KeyedJaggedTensor. + """ + if not is_torchdynamo_compiling(): + self.length_per_key() + self.offset_per_key() + return self + + def unsync(self) -> "KeyedJaggedTensor": + """ + Unsyncs the KeyedJaggedTensor by clearing the offset_per_key and length_per_key. + + Returns: + KeyedJaggedTensor: unsynced KeyedJaggedTensor. + """ + self._length_per_key = None + self._offset_per_key = None return self def device(self) -> torch.device: + """ + Returns the device of the KeyedJaggedTensor. + + Returns: + torch.device: device of the KeyedJaggedTensor. + """ return self._values.device def lengths(self) -> torch.Tensor: + """ + Returns the lengths of the KeyedJaggedTensor. + If the lengths are not computed yet, it will compute them. + + Returns: + torch.Tensor: lengths of the KeyedJaggedTensor. + """ _lengths = _maybe_compute_lengths(self._lengths, self._offsets) self._lengths = _lengths return _lengths def lengths_or_none(self) -> Optional[torch.Tensor]: + """ + Returns the lengths of the KeyedJaggedTensor or None if they are not computed yet. + + Returns: + torch.Tensor: lengths of the KeyedJaggedTensor. + """ return self._lengths def offsets(self) -> torch.Tensor: + """ + Returns the offsets of the KeyedJaggedTensor. + If the offsets are not computed yet, it will compute them. + + Returns: + torch.Tensor: offsets of the KeyedJaggedTensor. + """ _offsets = _maybe_compute_offsets(self._lengths, self._offsets) self._offsets = _offsets return _offsets def offsets_or_none(self) -> Optional[torch.Tensor]: + """ + Returns the offsets of the KeyedJaggedTensor or None if they are not computed yet. + + Returns: + torch.Tensor: offsets of the KeyedJaggedTensor. + """ return self._offsets def keys(self) -> List[str]: + """ + Returns the keys of the KeyedJaggedTensor. + + Returns: + List[str]: keys of the KeyedJaggedTensor. + """ return self._keys def values(self) -> torch.Tensor: + """ + Returns the values of the KeyedJaggedTensor. + + Returns: + torch.Tensor: values of the KeyedJaggedTensor. + """ return self._values def weights(self) -> torch.Tensor: + """ + Returns the weights of the KeyedJaggedTensor. + If weights is None, this will throw an error. + + Returns: + torch.Tensor: weights of the KeyedJaggedTensor. 
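        Example (illustrative tensors; an unweighted KeyedJaggedTensor would throw
        here instead of returning)::

            kjt = KeyedJaggedTensor(
                keys=["f1"],
                values=torch.Tensor([1.0, 2.0, 3.0]),
                weights=torch.Tensor([0.5, 1.0, 1.5]),
                lengths=torch.IntTensor([3]),
            )
            kjt.weights()  # tensor([0.5000, 1.0000, 1.5000])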
+ """ return _get_weights_or_throw(self._weights) def weights_or_none(self) -> Optional[torch.Tensor]: + """ + Returns the weights of the KeyedJaggedTensor or None if they don't exist. + + Returns: + torch.Tensor: weights of the KeyedJaggedTensor. + """ return self._weights def stride(self) -> int: - return self._stride + """ + Returns the stride of the KeyedJaggedTensor. + If stride is None, this will compute it. + + Returns: + int: stride of the KeyedJaggedTensor. + """ + stride = _maybe_compute_stride_kjt( + self._keys, + self._stride, + self._lengths, + self._offsets, + self._stride_per_key_per_rank, + ) + self._stride = stride + return stride + + def stride_per_key(self) -> List[int]: + """ + Returns the stride per key of the KeyedJaggedTensor. + If stride per key is None, this will compute it. + + Returns: + List[int]: stride per key of the KeyedJaggedTensor. + """ + stride_per_key = _maybe_compute_stride_per_key( + self._stride_per_key, + self._stride_per_key_per_rank, + self.stride(), + self._keys, + ) + self._stride_per_key = stride_per_key + return _get_stride_per_key_or_throw(stride_per_key) + + def stride_per_key_per_rank(self) -> List[List[int]]: + """ + Returns the stride per key per rank of the KeyedJaggedTensor. + + Returns: + List[List[int]]: stride per key per rank of the KeyedJaggedTensor. + """ + stride_per_key_per_rank = self._stride_per_key_per_rank + return stride_per_key_per_rank if stride_per_key_per_rank is not None else [] + + def variable_stride_per_key(self) -> bool: + """ + Returns whether the KeyedJaggedTensor has variable stride per key. + + Returns: + bool: whether the KeyedJaggedTensor has variable stride per key. + """ + if self._variable_stride_per_key is not None: + return self._variable_stride_per_key + return self._stride_per_key_per_rank is not None + + def inverse_indices(self) -> Tuple[List[str], torch.Tensor]: + """ + Returns the inverse indices of the KeyedJaggedTensor. + If inverse indices are None, this will throw an error. + + Returns: + Tuple[List[str], torch.Tensor]: inverse indices of the KeyedJaggedTensor. + """ + return _get_inverse_indices_or_throw(self._inverse_indices) + + def inverse_indices_or_none(self) -> Optional[Tuple[List[str], torch.Tensor]]: + """ + Returns the inverse indices of the KeyedJaggedTensor or None if they don't exist. + + Returns: + Optional[Tuple[List[str], torch.Tensor]]: inverse indices of the KeyedJaggedTensor. + """ + return self._inverse_indices def _key_indices(self) -> Dict[str, int]: _index_per_key: Dict[str, int] = _maybe_compute_index_per_key( @@ -913,36 +2235,109 @@ def _key_indices(self) -> Dict[str, int]: return _index_per_key def length_per_key(self) -> List[int]: + """ + Returns the length per key of the KeyedJaggedTensor. + If length per key is None, this will compute it. + + Returns: + List[int]: length per key of the KeyedJaggedTensor. + """ _length_per_key = _maybe_compute_length_per_key( - self._keys, - self.stride(), - self._length_per_key, - self._lengths, - self._offsets, + keys=self._keys, + stride=self.stride(), + stride_per_key=self.stride_per_key(), + variable_stride_per_key=self.variable_stride_per_key(), + length_per_key=self._length_per_key, + lengths=self._lengths, + offsets=self._offsets, + values=self._values, ) self._length_per_key = _length_per_key return _length_per_key def length_per_key_or_none(self) -> Optional[List[int]]: + """ + Returns the length per key of the KeyedJaggedTensor or None if it hasn't been computed. 
+ + Returns: + List[int]: length per key of the KeyedJaggedTensor. + """ return self._length_per_key def offset_per_key(self) -> List[int]: + """ + Returns the offset per key of the KeyedJaggedTensor. + If offset per key is None, this will compute it. + + Returns: + List[int]: offset per key of the KeyedJaggedTensor. + """ _length_per_key, _offset_per_key = _maybe_compute_offset_per_key( - self._keys, - self.stride(), - self._length_per_key, - self._offset_per_key, - self._lengths, - self._offsets, + keys=self._keys, + stride=self.stride(), + stride_per_key=self.stride_per_key(), + variable_stride_per_key=self.variable_stride_per_key(), + length_per_key=self._length_per_key, + offset_per_key=self._offset_per_key, + lengths=self._lengths, + offsets=self._offsets, + values=self._values, ) self._length_per_key = _length_per_key self._offset_per_key = _offset_per_key return _offset_per_key def offset_per_key_or_none(self) -> Optional[List[int]]: + """ + Returns the offset per key of the KeyedJaggedTensor or None if it hasn't been computed. + + Returns: + List[int]: offset per key of the KeyedJaggedTensor. + """ return self._offset_per_key + def lengths_offset_per_key(self) -> List[int]: + """ + Returns the lengths offset per key of the KeyedJaggedTensor. + If lengths offset per key is None, this will compute it. + + Returns: + List[int]: lengths offset per key of the KeyedJaggedTensor. + """ + if self.variable_stride_per_key(): + _lengths_offset_per_key = _maybe_compute_lengths_offset_per_key( + self._lengths_offset_per_key, + self.stride_per_key(), + None, + self._keys, + ) + else: + _lengths_offset_per_key = _maybe_compute_lengths_offset_per_key( + self._lengths_offset_per_key, None, self.stride(), self._keys + ) + + self._lengths_offset_per_key = _lengths_offset_per_key + return _get_lengths_offset_per_key_or_throw(_lengths_offset_per_key) + + def index_per_key(self) -> Dict[str, int]: + """ + Returns the index per key of the KeyedJaggedTensor. + + Returns: + Dict[str, int]: index per key of the KeyedJaggedTensor. + """ + return self._key_indices() + def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: + """ + Splits the KeyedJaggedTensor into a list of KeyedJaggedTensor. + + Args: + segments (List[int]): list of segments. + + Returns: + List[KeyedJaggedTensor]: list of KeyedJaggedTensor. 
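        Example (a minimal sketch; the keys, values, and segment sizes are made up
        for illustration)::

            kjt = KeyedJaggedTensor.from_lengths_sync(
                keys=["f1", "f2", "f3"],
                values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
                lengths=torch.IntTensor([1, 1, 1, 1, 1, 1]),
            )
            first, second = kjt.split([2, 1])

            # first.keys() == ["f1", "f2"], first.values() == tensor([1., 2., 3., 4.])
            # second.keys() == ["f3"],      second.values() == tensor([5., 6.])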
+ """ split_list: List[KeyedJaggedTensor] = [] start = 0 start_offset = 0 @@ -952,6 +2347,11 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: end = start + segment end_offset = _offset_per_key[end] keys: List[str] = self._keys[start:end] + stride_per_key_per_rank = ( + self.stride_per_key_per_rank()[start:end] + if self.variable_stride_per_key() + else None + ) if segment == len(self._keys): # no torch slicing required split_list.append( @@ -962,55 +2362,134 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: lengths=self._lengths, offsets=self._offsets, stride=self._stride, + stride_per_key_per_rank=stride_per_key_per_rank, + stride_per_key=None, length_per_key=self._length_per_key, + lengths_offset_per_key=None, offset_per_key=self._offset_per_key, index_per_key=self._index_per_key, jt_dict=self._jt_dict, + inverse_indices=None, ) ) elif segment == 0: + empty_int_list: List[int] = torch.jit.annotate(List[int], []) split_list.append( KeyedJaggedTensor( keys=keys, values=torch.tensor( - [], device=self.device(), dtype=self._values.dtype - ), - weights=None - if self.weights_or_none() is None - else torch.tensor( - [], + empty_int_list, device=self.device(), - dtype=self.weights().dtype, + dtype=self._values.dtype, + ), + weights=( + None + if self.weights_or_none() is None + else torch.tensor( + empty_int_list, + device=self.device(), + dtype=self.weights().dtype, + ) + ), + lengths=torch.tensor( + empty_int_list, device=self.device(), dtype=torch.int + ), + offsets=torch.tensor( + empty_int_list, device=self.device(), dtype=torch.int ), - lengths=torch.tensor([], device=self.device(), dtype=torch.int), - offsets=torch.tensor([], device=self.device(), dtype=torch.int), stride=self._stride, + stride_per_key_per_rank=stride_per_key_per_rank, + stride_per_key=None, length_per_key=None, + lengths_offset_per_key=None, offset_per_key=None, index_per_key=None, jt_dict=None, + inverse_indices=None, ) ) else: split_length_per_key = _length_per_key[start:end] - split_list.append( - KeyedJaggedTensor( - keys=keys, - values=self._values[start_offset:end_offset], - weights=None - if self.weights_or_none() is None - else self.weights()[start_offset:end_offset], - lengths=self.lengths()[ - start * self._stride : end * self._stride - ], - offsets=None, - stride=self._stride, - length_per_key=split_length_per_key, - offset_per_key=None, - index_per_key=None, - jt_dict=None, + + if not torch.jit.is_scripting() and is_non_strict_exporting(): + sz = sum(split_length_per_key) + + [torch._check_is_size(length) for length in split_length_per_key] + torch._check(start_offset <= self._values.size(0)) + torch._check(sz <= self._values.size(0)) + torch._check_is_size(start_offset) + + torch._check(start_offset + sz <= self._values.size(0)) + + lengths_start = self.lengths_offset_per_key()[start] + lengths_sz = self.lengths_offset_per_key()[end] - lengths_start + + _lengths = torch.narrow( + self.lengths(), 0, lengths_start, lengths_sz + ) + + if self.weights_or_none() is not None: + torch._check(start_offset + sz <= self.weights().size(0)) + torch._check(start_offset <= self.weights().size(0)) + + split_list.append( + KeyedJaggedTensor( + keys=keys, + values=torch.narrow(self._values, 0, start_offset, sz), + weights=( + None + if self.weights_or_none() is None + else torch.narrow(self.weights(), 0, start_offset, sz) + ), + lengths=_lengths, + offsets=None, + stride=self._stride, + stride_per_key_per_rank=stride_per_key_per_rank, + stride_per_key=None, + 
length_per_key=split_length_per_key, + lengths_offset_per_key=None, + offset_per_key=None, + index_per_key=None, + jt_dict=None, + inverse_indices=None, + ) + ) + else: + pt2_checks_tensor_slice(self._values, start_offset, end_offset) + + lengths_offset_per_key: List[int] = self.lengths_offset_per_key() + pt2_checks_tensor_slice( + self.lengths(), + lengths_offset_per_key[start], + lengths_offset_per_key[end], + ) + + split_list.append( + KeyedJaggedTensor( + keys=keys, + values=self._values[start_offset:end_offset], + weights=( + None + if self.weights_or_none() is None + else self.weights()[start_offset:end_offset] + ), + lengths=self.lengths()[ + lengths_offset_per_key[start] : lengths_offset_per_key[ + end + ] + ], + offsets=None, + stride=self._stride, + stride_per_key_per_rank=stride_per_key_per_rank, + stride_per_key=None, + length_per_key=split_length_per_key, + lengths_offset_per_key=None, + offset_per_key=None, + index_per_key=None, + jt_dict=None, + inverse_indices=None, + ) ) - ) start = end start_offset = end_offset return split_list @@ -1018,7 +2497,16 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: def permute( self, indices: List[int], indices_tensor: Optional[torch.Tensor] = None ) -> "KeyedJaggedTensor": + """ + Permutes the KeyedJaggedTensor. + + Args: + indices (List[int]): list of indices. + indices_tensor (Optional[torch.Tensor]): tensor of indices. + Returns: + KeyedJaggedTensor: permuted KeyedJaggedTensor. + """ if indices_tensor is None: indices_tensor = torch.tensor( indices, dtype=torch.int, device=self.device() @@ -1026,25 +2514,71 @@ def permute( length_per_key = self.length_per_key() permuted_keys: List[str] = [] + permuted_stride_per_key_per_rank: List[List[int]] = [] permuted_length_per_key: List[int] = [] - permuted_lengths_sum = 0 + permuted_length_per_key_sum = 0 for index in indices: - key = self._keys[index] + key = self.keys()[index] permuted_keys.append(key) - permuted_lengths_sum += length_per_key[index] permuted_length_per_key.append(length_per_key[index]) - ( - permuted_lengths, - permuted_values, - permuted_weights, - ) = torch.ops.fbgemm.permute_2D_sparse_data( - indices_tensor, - self.lengths().view(len(self._keys), -1), - self.values(), - self.weights_or_none(), - permuted_lengths_sum, - ) + if self.variable_stride_per_key(): + permuted_stride_per_key_per_rank.append( + self.stride_per_key_per_rank()[index] + ) + + permuted_length_per_key_sum = sum(permuted_length_per_key) + if not torch.jit.is_scripting() and is_non_strict_exporting(): + torch._check_is_size(permuted_length_per_key_sum) + torch._check(permuted_length_per_key_sum != -1) + torch._check(permuted_length_per_key_sum != 0) + if self.variable_stride_per_key(): + length_per_key_tensor = _pin_and_move( + torch.tensor(self.length_per_key()), self.device() + ) + stride_per_key_tensor = _pin_and_move( + torch.tensor(self.stride_per_key()), self.device() + ) + permuted_lengths, _ = _permute_tensor_by_segments( + self.lengths(), + stride_per_key_tensor, + indices_tensor, + None, + ) + permuted_values, permuted_weights = _permute_tensor_by_segments( + self.values(), + length_per_key_tensor, + indices_tensor, + self.weights_or_none(), + ) + elif is_torchdynamo_compiling() and not torch.jit.is_scripting(): + ( + permuted_lengths, + permuted_values, + permuted_weights, + ) = torch.ops.fbgemm.permute_2D_sparse_data_input1D( + indices_tensor, + self.lengths(), + self.values(), + self.stride(), + self.weights_or_none(), + permuted_length_per_key_sum, + ) + else: + ( + 
permuted_lengths, + permuted_values, + permuted_weights, + ) = torch.ops.fbgemm.permute_2D_sparse_data( + indices_tensor, + self.lengths().view(len(self._keys), -1), + self.values(), + self.weights_or_none(), + permuted_length_per_key_sum, + ) + stride_per_key_per_rank = ( + permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None + ) kjt = KeyedJaggedTensor( keys=permuted_keys, values=permuted_values, @@ -1052,14 +2586,48 @@ def permute( lengths=permuted_lengths.view(-1), offsets=None, stride=self._stride, + stride_per_key_per_rank=stride_per_key_per_rank, + stride_per_key=None, length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None, + lengths_offset_per_key=None, offset_per_key=None, index_per_key=None, jt_dict=None, + inverse_indices=None, ) return kjt + def flatten_lengths(self) -> "KeyedJaggedTensor": + stride_per_key_per_rank = ( + self._stride_per_key_per_rank if self.variable_stride_per_key() else None + ) + return KeyedJaggedTensor( + keys=self._keys, + values=self._values, + weights=self._weights, + lengths=self.lengths().view(-1), + offsets=None, + stride=self._stride, + stride_per_key_per_rank=stride_per_key_per_rank, + stride_per_key=None, + length_per_key=self.length_per_key(), + lengths_offset_per_key=None, + offset_per_key=None, + index_per_key=None, + jt_dict=None, + inverse_indices=None, + ) + def __getitem__(self, key: str) -> JaggedTensor: + """ + Returns the JaggedTensor for the given key. + + Args: + key (str): key. + + Returns: + JaggedTensor: JaggedTensor for the given key. + """ offset_per_key = self.offset_per_key() index = self._key_indices()[key] start_offset = offset_per_key[index] @@ -1068,73 +2636,170 @@ def __getitem__(self, key: str) -> JaggedTensor: if index + 1 < len(offset_per_key) else start_offset ) - return JaggedTensor( - values=self._values[start_offset:end_offset], - weights=None - if self.weights_or_none() is None - else self.weights()[start_offset:end_offset], - lengths=self.lengths()[index * self._stride : (index + 1) * self._stride], - offsets=None, - ) + + if not torch.jit.is_scripting() and is_non_strict_exporting(): + length_per_key = self.length_per_key() + _lengths = torch.narrow( + self.lengths(), + 0, + self.lengths_offset_per_key()[index], + self.lengths_offset_per_key()[index + 1] + - self.lengths_offset_per_key()[index], + ) + sz = length_per_key[index] + + torch._check_is_size(start_offset) + torch._check_is_size(sz) + torch._check(start_offset <= self.values().size(0)) + torch._check(sz <= self.values().size(0)) + + if self.weights_or_none() is not None: + torch._check(start_offset <= self.weights().size(0)) + torch._check(sz <= self.weights().size(0)) + + return JaggedTensor( + values=torch.narrow( + self.values(), + 0, + start_offset, + sz, + ), + weights=( + None + if self.weights_or_none() is None + else torch.narrow( + self.weights(), + 0, + start_offset, + sz, + ) + ), + lengths=_lengths, + offsets=None, + ) + else: + pt2_checks_tensor_slice(self._values, start_offset, end_offset) + + return JaggedTensor( + values=self._values[start_offset:end_offset], + weights=( + None + if self.weights_or_none() is None + else self.weights()[start_offset:end_offset] + ), + lengths=self.lengths()[ + self.lengths_offset_per_key()[ + index + ] : self.lengths_offset_per_key()[index + 1] + ], + offsets=None, + ) def to_dict(self) -> Dict[str, JaggedTensor]: + """ + Returns a dictionary of JaggedTensor for each key. + Will cache result in self._jt_dict. 
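        For instance (an illustrative sketch; the keys and tensors below are
        hypothetical and assume `torch` is imported)::

            kjt = KeyedJaggedTensor.from_lengths_sync(
                keys=["f1", "f2"],
                values=torch.Tensor([1.0, 2.0, 3.0, 4.0]),
                lengths=torch.IntTensor([2, 0, 1, 1]),
            )
            jt_dict = kjt.to_dict()

            # jt_dict["f1"].to_dense() == [tensor([1., 2.]), tensor([])]
            # jt_dict["f2"].to_dense() == [tensor([3.]), tensor([4.])]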
+ + Returns: + Dict[str, JaggedTensor]: dictionary of JaggedTensor for each key. + """ + if not torch.jit.is_scripting() and is_non_strict_exporting(): + logger.warn( + "Trying to non-strict torch.export KJT to_dict, which is extremely slow and not recommended!" + ) _jt_dict = _maybe_compute_kjt_to_jt_dict( - self.stride(), - self.keys(), - self.length_per_key(), - self.values(), - self.lengths(), - self.weights_or_none(), - self._jt_dict, + stride=self.stride(), + stride_per_key=self.stride_per_key(), + keys=self.keys(), + length_per_key=self.length_per_key(), + lengths=self.lengths(), + values=self.values(), + variable_stride_per_key=self.variable_stride_per_key(), + weights=self.weights_or_none(), + jt_dict=self._jt_dict, ) self._jt_dict = _jt_dict return _jt_dict @torch.jit.unused def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self._values.record_stream(stream) weights = self._weights lengths = self._lengths offsets = self._offsets if weights is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. weights.record_stream(stream) if lengths is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. lengths.record_stream(stream) if offsets is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. offsets.record_stream(stream) def to( - self, device: torch.device, non_blocking: bool = False + self, + device: torch.device, + non_blocking: bool = False, + dtype: Optional[torch.dtype] = None, ) -> "KeyedJaggedTensor": + """ + Returns a copy of KeyedJaggedTensor in the specified device and dtype. + + Args: + device (torch.device): the desired device of the copy. + non_blocking (bool): whether to copy the tensors in a non-blocking fashion. + dtype (Optional[torch.dtype]): the desired data type of the copy. + + Returns: + KeyedJaggedTensor: the copied KeyedJaggedTensor. 
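        Example (an illustrative sketch, assuming `kjt` is a weighted
        KeyedJaggedTensor already constructed elsewhere)::

            kjt_cpu = kjt.to(torch.device("cpu"), dtype=torch.float16)

            # Only the weights are cast to float16; values, lengths, and offsets keep
            # their original dtypes and are simply moved to the target device.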
+ """ weights = self._weights lengths = self._lengths offsets = self._offsets + stride_per_key_per_rank = ( + self._stride_per_key_per_rank if self.variable_stride_per_key() else None + ) length_per_key = self._length_per_key + lengths_offset_per_key = self._lengths_offset_per_key offset_per_key = self._offset_per_key index_per_key = self._index_per_key + stride_per_key = self._stride_per_key jt_dict = self._jt_dict + inverse_indices = self._inverse_indices + if inverse_indices is not None: + inverse_indices = ( + inverse_indices[0], + inverse_indices[1].to(device, non_blocking=non_blocking), + ) + if weights is not None: + if dtype is not None: + weights = weights.to( + dtype=dtype, device=device, non_blocking=non_blocking + ) + else: + weights = weights.to(device=device, non_blocking=non_blocking) return KeyedJaggedTensor( keys=self._keys, values=self._values.to(device, non_blocking=non_blocking), - weights=weights.to(device, non_blocking=non_blocking) - if weights is not None - else None, - lengths=lengths.to(device, non_blocking=non_blocking) - if lengths is not None - else None, - offsets=offsets.to(device, non_blocking=non_blocking) - if offsets is not None - else None, + weights=weights, + lengths=( + lengths.to(device, non_blocking=non_blocking) + if lengths is not None + else None + ), + offsets=( + offsets.to(device, non_blocking=non_blocking) + if offsets is not None + else None + ), stride=self._stride, + stride_per_key_per_rank=stride_per_key_per_rank, + stride_per_key=stride_per_key, length_per_key=length_per_key, + lengths_offset_per_key=lengths_offset_per_key, offset_per_key=offset_per_key, index_per_key=index_per_key, jt_dict=jt_dict, + inverse_indices=inverse_indices, ) def __str__(self) -> str: @@ -1142,7 +2807,6 @@ def __str__(self) -> str: return "KeyedJaggedTensor()\n" offsets = self.offsets() - step = (len(offsets) - 1) // len(self._keys) return ( "KeyedJaggedTensor({\n" + ",\n".join( @@ -1153,8 +2817,8 @@ def __str__(self) -> str: self._values, self._weights, offsets, - index * step, - (index + 1) * step, + sum(self.stride_per_key()[:index]), + sum(self.stride_per_key()[: index + 1]), ) for index in range(len(self._keys)) ] @@ -1166,6 +2830,12 @@ def pin_memory(self) -> "KeyedJaggedTensor": weights = self._weights lengths = self._lengths offsets = self._offsets + stride_per_key_per_rank = ( + self._stride_per_key_per_rank if self.variable_stride_per_key() else None + ) + inverse_indices = self._inverse_indices + if inverse_indices is not None: + inverse_indices = (inverse_indices[0], inverse_indices[1].pin_memory()) return KeyedJaggedTensor( keys=self._keys, @@ -1174,12 +2844,255 @@ def pin_memory(self) -> "KeyedJaggedTensor": lengths=lengths.pin_memory() if lengths is not None else None, offsets=offsets.pin_memory() if offsets is not None else None, stride=self._stride, + stride_per_key_per_rank=stride_per_key_per_rank, + stride_per_key=self._stride_per_key, length_per_key=self._length_per_key, + lengths_offset_per_key=self._lengths_offset_per_key, offset_per_key=self._offset_per_key, index_per_key=self._index_per_key, jt_dict=None, + inverse_indices=inverse_indices, + ) + + def dist_labels(self) -> List[str]: + labels = ["lengths", "values"] + if self.variable_stride_per_key(): + labels.append("strides") + if self.weights_or_none() is not None: + labels.append("weights") + return labels + + def dist_splits(self, key_splits: List[int]) -> List[List[int]]: + batch_size_per_split = _sum_by_splits(self.stride_per_key(), key_splits) + length_per_split = 
_sum_by_splits(self.length_per_key(), key_splits) + splits = [batch_size_per_split, length_per_split] + if self.variable_stride_per_key(): + splits.append(key_splits) + if self.weights_or_none() is not None: + splits.append(length_per_split) + return splits + + def dist_tensors(self) -> List[torch.Tensor]: + tensors = [self.lengths(), self.values()] + if self.variable_stride_per_key(): + strides = _pin_and_move(torch.tensor(self.stride_per_key()), self.device()) + tensors.append(strides) + if self.weights_or_none() is not None: + tensors.append(self.weights()) + return tensors + + @staticmethod + def dist_init( + keys: List[str], + tensors: List[torch.Tensor], + variable_stride_per_key: bool, + num_workers: int, + recat: Optional[torch.Tensor], + stride_per_rank: Optional[List[int]], + stagger: int = 1, + ) -> "KeyedJaggedTensor": + assert len(tensors) in [2, 3, 4] + lengths = tensors[0] + values = tensors[1] + stride_per_rank_per_key = tensors[2] if variable_stride_per_key else None + weights = ( + tensors[-1] + if (variable_stride_per_key and len(tensors) == 4) + or (not variable_stride_per_key and len(tensors) == 3) + else None ) + if variable_stride_per_key: + assert stride_per_rank_per_key is not None + stride_per_key_per_rank_tensor: torch.Tensor = stride_per_rank_per_key.view( + num_workers, len(keys) + ).T.cpu() + + strides_cumsum: torch.Tensor = ( + torch.ops.fbgemm.asynchronous_complete_cumsum(stride_per_rank_per_key) + ).cpu() + + cumsum_lengths = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + + n = strides_cumsum.size(0) + strides_cumsum_from_1 = torch.narrow( + strides_cumsum, dim=0, start=1, length=n - 1 + ) + strides_cumsum_to_minus_1 = torch.narrow( + strides_cumsum, dim=0, start=0, length=n - 1 + ) + length_per_key_tensor = ( + cumsum_lengths[strides_cumsum_from_1] + - cumsum_lengths[strides_cumsum_to_minus_1] + ) + + with record_function("## all2all_data:recat_values ##"): + if recat is not None: + lengths, _ = _permute_tensor_by_segments( + lengths, + stride_per_rank_per_key, + torch.jit._unwrap_optional(recat), + None, + ) + values, weights = _permute_tensor_by_segments( + values, + length_per_key_tensor, + torch.jit._unwrap_optional(recat), + weights, + ) + + stride_per_key_per_rank = torch.jit.annotate( + List[List[int]], stride_per_key_per_rank_tensor.tolist() + ) + + if not stride_per_key_per_rank: + stride_per_key_per_rank = [[0]] * len(keys) + if stagger > 1: + stride_per_key_per_rank_stagger: List[List[int]] = [] + local_world_size = num_workers // stagger + for i in range(len(keys)): + stride_per_rank_stagger: List[int] = [] + for j in range(local_world_size): + stride_per_rank_stagger.extend( + stride_per_key_per_rank[i][j::local_world_size] + ) + stride_per_key_per_rank_stagger.append(stride_per_rank_stagger) + stride_per_key_per_rank = stride_per_key_per_rank_stagger + + kjt = KeyedJaggedTensor( + keys=keys, + values=values, + weights=weights, + lengths=lengths, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + return kjt.sync() + else: + assert stride_per_rank is not None + with record_function("## all2all_data:recat_values ##"): + if recat is not None: + stride = stride_per_rank[0] + + single_batch_per_rank = True + if not is_torchdynamo_compiling(): + single_batch_per_rank = all( + s == stride for s in stride_per_rank + ) + if ( + single_batch_per_rank + and is_torchdynamo_compiling() + and not torch.jit.is_scripting() + ): + ( + lengths, + values, + weights, + ) = torch.ops.fbgemm.permute_2D_sparse_data_input1D( + 
torch.jit._unwrap_optional(recat), + lengths, + values, + stride, + weights, + values.numel(), + ) + elif single_batch_per_rank: + ( + lengths, + values, + weights, + ) = torch.ops.fbgemm.permute_2D_sparse_data( + torch.jit._unwrap_optional(recat), + lengths.view(-1, stride), + values, + weights, + values.numel(), + ) + lengths = lengths.view(-1) + else: # variable batch size per rank + ( + lengths, + values, + weights, + ) = torch.ops.fbgemm.permute_1D_sparse_data( + torch.jit._unwrap_optional(recat), + lengths.view(-1), + values, + weights, + values.numel(), + ) + kjt = KeyedJaggedTensor( + keys=keys, + values=values, + weights=weights, + lengths=lengths, + stride=sum(stride_per_rank), + ) + return kjt.sync() + + +def _kjt_flatten( + t: KeyedJaggedTensor, +) -> Tuple[List[Optional[torch.Tensor]], List[str]]: + return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys + + +def _kjt_flatten_with_keys( + t: KeyedJaggedTensor, +) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], List[str]]: + values, context = _kjt_flatten(t) + # pyre can't tell that GetAttrKey implements the KeyEntry protocol + return [ # pyre-ignore[7] + (GetAttrKey(k), v) for k, v in zip(KeyedJaggedTensor._fields, values) + ], context + + +def _kjt_unflatten( + values: List[Optional[torch.Tensor]], context: List[str] # context is the _keys +) -> KeyedJaggedTensor: + return KeyedJaggedTensor(context, *values) + + +def _kjt_flatten_spec( + t: KeyedJaggedTensor, spec: TreeSpec +) -> List[Optional[torch.Tensor]]: + return [getattr(t, a) for a in KeyedJaggedTensor._fields] + + +register_pytree_node( + KeyedJaggedTensor, + _kjt_flatten, + _kjt_unflatten, + flatten_with_keys_fn=_kjt_flatten_with_keys, + serialized_type_name="torchrec.sparse.jagged_tensor.KeyedJaggedTensor", +) +register_pytree_flatten_spec(KeyedJaggedTensor, _kjt_flatten_spec) + + +def flatten_kjt_list( + kjt_arr: List[KeyedJaggedTensor], +) -> Tuple[List[Optional[torch.Tensor]], List[List[str]]]: + _flattened_data = [] + _flattened_context = [] + for t in kjt_arr: + _values, _context = _kjt_flatten(t) + _flattened_data.extend(_values) + _flattened_context.append(_context) + return _flattened_data, _flattened_context + + +def unflatten_kjt_list( + values: List[Optional[torch.Tensor]], contexts: List[List[str]] +) -> List[KeyedJaggedTensor]: + num_kjt_fields = len(KeyedJaggedTensor._fields) + length = len(values) + return [ + _kjt_unflatten( + values[j * num_kjt_fields : (j + 1) * num_kjt_fields], + contexts[j], + ) + for j in range(length // num_kjt_fields) + ] + def _maybe_compute_offset_per_key_kt( length_per_key: List[int], @@ -1235,16 +3148,16 @@ class KeyedTensor(Pipelineable, metaclass=JaggedTensorMeta): kt = KeyedTensor.from_tensor_list(keys, tensor_list) kt.values() - # tensor( - # [ - # [1, 1, 2, 1, 2, 3, 1, 2, 3], - # [1, 1, 2, 1, 2, 3, 1, 2, 3], - # [1, 1, 2, 1, 2, 3, 1, 2, 3], - # ] - # ) + # torch.Tensor( + # [ + # [1, 1, 2, 1, 2, 3, 1, 2, 3], + # [1, 1, 2, 1, 2, 3, 1, 2, 3], + # [1, 1, 2, 1, 2, 3, 1, 2, 3], + # ] + # ) kt["Embedding B"] - # tensor([[2, 1, 2], [2, 1, 2], [2, 1, 2]]) + # torch.Tensor([[2, 1, 2], [2, 1, 2], [2, 1, 2]]) """ def __init__( @@ -1269,6 +3182,20 @@ def __init__( def from_tensor_list( keys: List[str], tensors: List[torch.Tensor], key_dim: int = 1, cat_dim: int = 1 ) -> "KeyedTensor": + """ + Create a KeyedTensor from a list of tensors. The tensors are concatenated + along the cat_dim. The keys are used to index the tensors. + + Args: + keys (List[str]): list of keys. 
+ tensors (List[torch.Tensor]): list of tensors. + key_dim (int): key dimension, zero indexed - defaults to 1 + (typically B is 0-dimension). + cat_dim (int): dimension along which to concatenate the tensors - defaults + + Returns: + KeyedTensor: keyed tensor. + """ length_per_key = [tensor.shape[key_dim] for tensor in tensors] return KeyedTensor( keys=keys, @@ -1278,15 +3205,43 @@ def from_tensor_list( ) def keys(self) -> List[str]: + """ + Returns: + List[str]: list of keys. + """ return self._keys def values(self) -> torch.Tensor: + """ + Get the values tensor. + + Returns: + torch.Tensor: dense tensor, concatenated typically along key dimension. + """ return self._values def key_dim(self) -> int: + """ + Returns: + int: key dimension, zero indexed - typically B is 0-dimension. + """ return self._key_dim + def device(self) -> torch.device: + """ + Returns: + torch.device: device of the values tensor. + """ + return self._values.device + def offset_per_key(self) -> List[int]: + """ + Get the offset of each key along key dimension. + Compute and cache if not already computed. + + Returns: + List[int]: offset of each key along key dimension. + """ _offset_per_key = _maybe_compute_offset_per_key_kt( self._length_per_key, self._offset_per_key, @@ -1295,9 +3250,20 @@ def offset_per_key(self) -> List[int]: return _offset_per_key def length_per_key(self) -> List[int]: + """ + Returns: + List[int]: length of each key along key dimension. + """ return self._length_per_key def _key_indices(self) -> Dict[str, int]: + """ + Get the indices of each key. + Compute and cache if not already computed. + + Returns: + Dict[str, int]: indices of each key. + """ _index_per_key = _maybe_compute_index_per_key( self._keys, self._index_per_key, @@ -1306,12 +3272,20 @@ def _key_indices(self) -> Dict[str, int]: return _index_per_key def __getitem__(self, key: str) -> torch.Tensor: + """ + Returns: + torch.Tensor: tensor for the given key. + """ index = self._key_indices()[key] start = self.offset_per_key()[index] length = self._length_per_key[index] return self._values.narrow(dim=self._key_dim, start=start, length=length) def to_dict(self) -> Dict[str, torch.Tensor]: + """ + Returns: + Dict[str, torch.Tensor]: dictionary of tensors keyed by the keys. + """ indices = self._key_indices() lengths = self._length_per_key split_values = self._values.split(lengths, dim=self._key_dim) @@ -1321,25 +3295,60 @@ def to_dict(self) -> Dict[str, torch.Tensor]: def regroup( keyed_tensors: List["KeyedTensor"], groups: List[List[str]] ) -> List[torch.Tensor]: - return _regroup_keyed_tensors(keyed_tensors, groups) + """ + Regroup a list of KeyedTensors into a list of tensors. + + Args: + keyed_tensors (List[KeyedTensor]): list of KeyedTensors. + groups (List[List[str]]): list of groups of keys. + + Returns: + List[torch.Tensor]: list of tensors. + """ + # Fast path, one-to-one correspondence between keyed_tensors and groups + if _all_keys_used_once(keyed_tensors, groups) is True: + return _fbgemm_permute_pooled_embs(keyed_tensors, groups) + else: # Fallback to slow path otherwise + return _regroup_keyed_tensors(keyed_tensors, groups) @staticmethod def regroup_as_dict( keyed_tensors: List["KeyedTensor"], groups: List[List[str]], keys: List[str] ) -> Dict[str, torch.Tensor]: + """ + Regroup a list of KeyedTensors into a dictionary of tensors. + + Args: + keyed_tensors (List[KeyedTensor]): list of KeyedTensors. + groups (List[List[str]]): list of groups of keys. + keys (List[str]): list of keys. 
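+
+        Example (a minimal sketch; shapes, key names and group names are made up
+        for illustration)::
+
+            kt = KeyedTensor.from_tensor_list(
+                keys=["f1", "f2"], tensors=[torch.randn(2, 4), torch.randn(2, 8)]
+            )
+            out = KeyedTensor.regroup_as_dict([kt], [["f2"], ["f1"]], ["g0", "g1"])
+            # out["g0"] has shape (2, 8) and out["g1"] has shape (2, 4)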
+ + Returns: + Dict[str, torch.Tensor]: dictionary of tensors. + """ + ret: Dict[str, torch.Tensor] = {} assert len(groups) == len(keys), "Groups and keys should have same length" - embeddings_list = _regroup_keyed_tensors(keyed_tensors, groups) - embeddings_dict: Dict[str, torch.Tensor] = {} + tensor_list = KeyedTensor.regroup(keyed_tensors, groups) for i, key in enumerate(keys): - embeddings_dict[key] = embeddings_list[i] - return embeddings_dict + ret[key] = tensor_list[i] + return ret @torch.jit.unused def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self._values.record_stream(stream) def to(self, device: torch.device, non_blocking: bool = False) -> "KeyedTensor": + """ + Moves the values tensor to the specified device. + + Args: + device (torch.device): device to move the values tensor to. + non_blocking (bool): whether to perform the operation asynchronously + (default: False). + + Returns: + KeyedTensor: keyed tensor with values tensor moved to the specified device. + """ return KeyedTensor( keys=self._keys, length_per_key=self._length_per_key, @@ -1363,3 +3372,34 @@ def __str__(self) -> str: ) + "\n})\n" ) + + +def _kt_flatten( + kt: KeyedTensor, +) -> Tuple[List[torch.Tensor], Tuple[List[str], List[int]]]: + return [kt._values], (kt._keys, kt._length_per_key) + + +def _kt_unflatten( + values: List[torch.Tensor], context: Tuple[List[str], List[int]] +) -> KeyedTensor: + return KeyedTensor(context[0], context[1], values[0]) + + +def _kt_flatten_spec(kt: KeyedTensor, spec: TreeSpec) -> List[torch.Tensor]: + _keys, _length_per_key = spec.context + # please read https://fburl.com/workplace/8bei5iju for more context, + # you can also consider use short_circuit_pytree_ebc_regroup with KTRegroupAsDict + logger.warning( + "KT's key order might change from spec from the torch.export, this could have perf impact. " + f"{kt.keys()} vs {_keys}" + ) + res = permute_multi_embedding([kt], [_keys]) + return [res[0]] + + +# The assumption here in torch.exporting KeyedTensor is that _length_per_key is static +register_pytree_node( + KeyedTensor, _kt_flatten, _kt_unflatten, serialized_type_name="KeyedTensor" +) +register_pytree_flatten_spec(KeyedTensor, _kt_flatten_spec) diff --git a/torchrec/sparse/tensor_dict.py b/torchrec/sparse/tensor_dict.py new file mode 100644 index 000000000..3f00d5275 --- /dev/null +++ b/torchrec/sparse/tensor_dict.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
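+
+# `maybe_td_to_kjt` below is a small conversion shim: given a TensorDict of jagged
+# per-key features it concatenates their values and lengths into a single
+# KeyedJaggedTensor; given a KeyedJaggedTensor it returns the input unchanged.
+# Usage sketch (illustrative only; `features` stands in for either input type):
+#
+#   kjt = maybe_td_to_kjt(features)           # convert all keys
+#   kjt = maybe_td_to_kjt(features, ["f1"])   # convert only the selected keys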
+ +# pyre-strict + +from typing import List, Optional + +import torch +from tensordict import TensorDict + +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +def maybe_td_to_kjt( + features: KeyedJaggedTensor, keys: Optional[List[str]] = None +) -> KeyedJaggedTensor: + if torch.jit.is_scripting(): + assert isinstance(features, KeyedJaggedTensor) + return features + if isinstance(features, TensorDict): + if keys is None: + keys = list(features.keys()) + values = torch.cat([features[key]._values for key in keys], dim=0) + lengths = torch.cat( + [ + ( + (features[key]._lengths) + if features[key]._lengths is not None + else torch.diff(features[key]._offsets) + ) + for key in keys + ], + dim=0, + ) + return KeyedJaggedTensor( + keys=keys, + values=values, + lengths=lengths, + ) + else: + return features diff --git a/torchrec/sparse/test_utils/__init__.py b/torchrec/sparse/test_utils/__init__.py index e48f430f4..b8c68ca16 100644 --- a/torchrec/sparse/test_utils/__init__.py +++ b/torchrec/sparse/test_utils/__init__.py @@ -5,41 +5,83 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + +import math from typing import Optional import torch from torchrec.sparse.jagged_tensor import KeyedJaggedTensor -def keyed_jagged_tensor_equals( - kjt1: Optional[KeyedJaggedTensor], kjt2: Optional[KeyedJaggedTensor] +def _tensor_eq_or_none( + t1: Optional[torch.Tensor], + t2: Optional[torch.Tensor], + out_of_order: bool = False, + length: Optional[torch.Tensor] = None, ) -> bool: - def _tensor_eq_or_none( - t1: Optional[torch.Tensor], t2: Optional[torch.Tensor] - ) -> bool: - if t1 is None and t2 is None: - return True - elif t1 is None and t2 is not None: - return False - elif t1 is not None and t2 is None: - return False + if t1 is None and t2 is None: + return True + elif t1 is None and t2 is not None: + return False + elif t1 is not None and t2 is None: + return False + + assert t1 is not None + assert t2 is not None + + if t1.dtype != t2.dtype: + return False + + if not out_of_order: + return torch.equal(t1, t2) + + assert length is not None + is_int = not torch.is_floating_point(t1) + vals_1 = t1.tolist() + vals_2 = t2.tolist() + current_offset = 0 + for i in length.tolist(): + if i == 0: + continue + sorted_vals_1 = sorted(vals_1[current_offset : current_offset + i]) + sorted_vals_2 = sorted(vals_2[current_offset : current_offset + i]) + if is_int: + if sorted_vals_1 != sorted_vals_2: + return False else: - assert t1 is not None - assert t2 is not None - return torch.equal(t1, t2) and t1.dtype == t2.dtype + for left, right in zip( + sorted_vals_1, + sorted_vals_2, + ): + if not math.isclose(left, right): + return False + current_offset += i + return True + +def keyed_jagged_tensor_equals( + kjt1: Optional[KeyedJaggedTensor], + kjt2: Optional[KeyedJaggedTensor], + is_pooled_features: bool = False, +) -> bool: if kjt1 is None and kjt2 is None: return True elif kjt1 is None and kjt2 is not None: return False elif kjt1 is not None and kjt2 is None: return False - else: - assert kjt1 is not None - assert kjt2 is not None - return ( - kjt1.keys() == kjt2.keys() - and _tensor_eq_or_none(kjt1.lengths(), kjt2.lengths()) - and _tensor_eq_or_none(kjt1.values(), kjt2.values()) - and _tensor_eq_or_none(kjt1._weights, kjt2._weights) - ) + + assert kjt1 is not None + assert kjt2 is not None + if not ( + kjt1.keys() == kjt2.keys() + and _tensor_eq_or_none(kjt1.lengths(), kjt2.lengths()) + ): + return False + + 
return _tensor_eq_or_none( + kjt1.values(), kjt2.values(), is_pooled_features, kjt1.lengths() + ) and _tensor_eq_or_none( + kjt1._weights, kjt2._weights, is_pooled_features, kjt1.lengths() + ) diff --git a/torchrec/sparse/tests/jagged_tensor_benchmark.py b/torchrec/sparse/tests/jagged_tensor_benchmark.py new file mode 100644 index 000000000..34862e380 --- /dev/null +++ b/torchrec/sparse/tests/jagged_tensor_benchmark.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +import functools +import timeit +from typing import Any, Callable, Dict, List + +import click + +import torch +from torchrec.distributed.benchmark.benchmark_utils import ( + benchmark, + BenchmarkResult, + MemoryStats, +) +from torchrec.modules.regroup import KTRegroupAsDict +from torchrec.sparse.jagged_tensor import ( + _fbgemm_permute_pooled_embs, + _regroup_keyed_tensors, + KeyedJaggedTensor, + KeyedTensor, + permute_multi_embedding, + regroup_kts, +) +from torchrec.sparse.tests.utils import build_groups, build_kts + + +class DummyModel(torch.nn.Module): + # pyre-ignore + def forward(self, *args, **kwargs) -> None: + pass + + +def bench( + name: str, + labels: torch.Tensor, + batch_size: int, + feature_count: int, + device_type: str, + run_backward: bool, + fn: Callable[..., List[torch.Tensor]], + fn_kwargs: Dict[str, Any], + output_dir: str = "", +) -> None: + + # initial call + fn(**fn_kwargs) + + def wrapped_func( + model: torch.nn.Module, # not used + bench_inputs: List[KeyedJaggedTensor], # not used + fn: Callable[..., List[torch.Tensor]], + run_backward: bool, + **kwargs: Dict[str, Any], + ) -> None: + result = fn(**fn_kwargs) + if run_backward: + if isinstance(result, dict): + vectors = [tensor.sum(dim=1) for tensor in result.values()] + else: + vectors = [tensor.sum(dim=1) for tensor in result] + pred = vectors[0] + for vector in vectors[1:]: + pred.mul(vector) + loss = torch.nn.functional.l1_loss(pred, labels) + loss.sum().backward() + + model = DummyModel() + setattr(model, "forward", lambda kwargs: fn(**kwargs)) + prof_num = 10 + if device_type == "cuda": + result = benchmark( + name=name, + model=model, + warmup_inputs=[], + bench_inputs=[], + prof_inputs=[fn_kwargs] * prof_num, + world_size=1, + output_dir=output_dir, + num_benchmarks=20, + func_to_benchmark=functools.partial( + wrapped_func, fn=fn, run_backward=run_backward, fn_kwargs=fn_kwargs + ), + benchmark_func_kwargs={}, + rank=0, + enable_logging=True, + ) + + else: # cpu + times = timeit.repeat( + lambda: wrapped_func( + model=model, + bench_inputs=[], + fn=fn, + fn_kwargs=fn_kwargs, + run_backward=run_backward, + ), + number=1, + repeat=20, + ) + result = BenchmarkResult( + short_name=name, + elapsed_time=torch.tensor(times) * 1e3, + mem_stats=[MemoryStats(0, 0, 0, 0)], + ) + + mem_alloc = f"Memory alloc (P90): {result.max_mem_alloc_percentile(90):5.1f}" + mem_reserved = f"Memory alloc (P90): {result.max_mem_reserved_percentile(90):5.1f}" + print( + f" {name : <{30}} | B: {batch_size : <{8}} | F: {feature_count : <{8}} | device: {device_type : <{8}} | Runtime (P90): {result.runtime_percentile(90):5.2f} ms | {mem_alloc} | {mem_reserved}" + ) + + +@click.command() +@click.option( + "--cuda_matrix", + type=bool, + default=False, + help="Run a full GPU matrix, overrides relevant settings", +) +@click.option( + "--run_backward", + 
type=bool, + default=False, + help="run backward (forward always runs)", +) +@click.option( + "--device_type", + type=str, + default="cuda", + help="device type", +) +@click.option( + "--n_dense", + type=int, + default=20, + help="Total number of dense embeddings.", +) +@click.option( + "--dim_dense", + type=int, + default=64, + help="Dim dense embedding.", +) +@click.option( + "--n_sparse", + default=1000, + help="Total number of sparse embeddings to be used.", +) +@click.option( + "--dim_sparse", + type=int, + default=128, + help="Dim dense embedding.", +) +@click.option( + "--batch_size", + type=int, + default=1024, + help="Batch size.", +) +@click.option( + "--n_groups", + type=int, + default=2, + help="Total num of regrouping", +) +@click.option( + "--profile", + type=str, + default="", + help="profile output directory", +) +def main( + cuda_matrix: bool, + run_backward: bool, + device_type: str, + n_dense: int, + n_sparse: int, + dim_dense: int, + dim_sparse: int, + batch_size: int, + n_groups: int, + profile: str, +) -> None: + if cuda_matrix: + n_denses = [64, 128, 256, 512, 1024] + n_sparses = [16, 32, 64, 128, 256] + batch_sizes = [512, 1024, 2048, 4096] + device_types = ["cuda"] + else: + n_denses = [n_dense] + n_sparses = [n_sparse] + batch_sizes = [batch_size] + device_types = [device_type] + + for device_type in device_types: + for batch_size in batch_sizes: + for duplicates in [False, True]: + for n_dense, n_sparse in zip(n_denses, n_sparses): + dup = "_dup" if duplicates else "" + device = torch.device(device_type) + kts = build_kts( + n_dense, + n_sparse, + dim_dense, + dim_sparse, + batch_size, + device, + run_backward, + ) + labels = torch.randint( + 0, 1, (batch_size,), device=torch.device(device_type) + ).float() + groups = build_groups(kts, n_groups, duplicates=duplicates) + bench( + "[pytorch generic] fallback" + dup, + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + _regroup_keyed_tensors, + {"keyed_tensors": kts, "groups": groups}, + profile, + ) + bench( + "[Prod] KeyedTensor.regroup" + dup, + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + KeyedTensor.regroup, + {"keyed_tensors": kts, "groups": groups}, + profile, + ) + bench( + "[Module] KTRegroupAsDict" + dup, + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + KTRegroupAsDict( + groups=groups, keys=[str(i) for i in range(n_groups)] + ), + {"keyed_tensors": kts}, + profile, + ) + bench( + "[2 Ops] permute_multi_embs" + dup, + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + permute_multi_embedding, + {"keyed_tensors": kts, "groups": groups}, + profile, + ) + bench( + "[1 Op] KT_regroup" + dup, + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + regroup_kts, + {"keyed_tensors": kts, "groups": groups}, + profile, + ) + if not duplicates: + bench( + "[Old Prod] permute_pooled_embs" + dup, + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + _fbgemm_permute_pooled_embs, + {"keyed_tensors": kts, "groups": groups}, + profile, + ) + + +if __name__ == "__main__": + main() diff --git a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark.py b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark.py new file mode 100644 index 000000000..7513195f3 --- /dev/null +++ b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import logging +import sys + +import click + +# Need this for PT2 compile +# Otherwise will get error +# NotImplementedError: fbgemm::permute_1D_sparse_data: We could not find the abstract impl for this operator. +from fbgemm_gpu import sparse_ops # noqa: F401, E402 +from torchrec.sparse.tests.keyed_jagged_tensor_benchmark_lib import ( + bench, + DEFTAULT_BENCHMARK_FUNCS, + TransformType, +) + +logger: logging.Logger = logging.getLogger(__name__) +logging.basicConfig(format="%(message)s", stream=sys.stdout) +logger.setLevel(logging.DEBUG) + + +@click.command() +@click.option( + "--num-repeat", + default=20, + help="Number of times method under test is run", +) +@click.option( + "--num-warmup", + default=10, + help="Number of times method under test is run for warmup", +) +@click.option( + "--num-features", + default=128, + help="Total number of sparse features per KJT", +) +@click.option( + "--batch-size", + default=4096, + help="Batch size per KJT (assumes non-VBE)", +) +@click.option( + "--mean-pooling-factor", + default=100, + help="Avg pooling factor for KJT", +) +@click.option( + "--num-workers", + default=4, + help="World size to simulate for dist_init", +) +@click.option( + "--test-pt2/--no-test-pt2", + type=bool, + default=False, + help="Whether to benchmark PT2 Eager", +) +@click.option( + "--kjt-funcs", + type=str, + default=",".join(DEFTAULT_BENCHMARK_FUNCS), + help="kjt functions to benchmark", +) +# pyre-ignore [56] +@click.option( + "--run-modes", + type=str, + default=",".join([member.name for member in TransformType]), + help="kjt functions to benchmark", +) +def main( + num_repeat: int, + num_warmup: int, + num_features: int, + batch_size: int, + mean_pooling_factor: int, + num_workers: int, + test_pt2: bool, + kjt_funcs: str, + run_modes: str, +) -> None: + bench( + num_repeat, + num_warmup, + num_features, + batch_size, + mean_pooling_factor, + num_workers, + test_pt2, + kjt_funcs, + run_modes, + ) + + +if __name__ == "__main__": + main() diff --git a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py new file mode 100644 index 000000000..1c409fcf2 --- /dev/null +++ b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py @@ -0,0 +1,522 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import logging +import random +import sys +import time +import timeit +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +# Need this for PT2 compile +# Otherwise will get error +# NotImplementedError: fbgemm::permute_1D_sparse_data: We could not find the abstract impl for this operator. 
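+# Importing the module is enough: it registers fbgemm's sparse operators (including
+# the abstract/meta implementations PT2 needs), which is why the otherwise-unused
+# import is kept under `noqa`.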
+from fbgemm_gpu import sparse_ops # noqa: F401, E402 +from torchrec.distributed.benchmark.benchmark_utils import BenchmarkResult, MemoryStats +from torchrec.distributed.dist_data import _get_recat + +from torchrec.distributed.test_utils.test_model import ModelInput +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor + +logger: logging.Logger = logging.getLogger(__name__) +logging.basicConfig(format="%(message)s", stream=sys.stdout) +logger.setLevel(logging.DEBUG) + +DEFTAULT_BENCHMARK_FUNCS = [ + "permute", + "to_dict", + "split", + "concat", + "__getitem__", + "dist_splits", + "dist_init", +] + + +class TransformType(Enum): + DEFAULT = "DEFAULT" + JIT_SCRIPT = "JIT_SCRIPT" + VBE = "VBE" + PT2_AOT_EAGER = "AOT_EAGER" + + +class BenchmarkFormatter: + TABLE_HEADERS: List[List[str]] = [ + [ + "Method Name", + "Transform", + "Variable Batch", + "Batch Size", + "# Features", + "Avg Pooling Factor", + "Runtime (P50)", + "Runtime (P90)", + ], + [ + "-------------", + "-------------", + "----------------", + "------------", + "------------", + "--------------------", + "------------------------------", + "------------------------------", + ], + ] + + def __init__( + self, + batch_size: int, + num_features: int, + mean_pooling_factor: int, + ) -> None: + self._batch_size = batch_size + self._num_features = num_features + self._mean_pooling_factor = mean_pooling_factor + self._delimiter = "|" + # map method name -> (p50 runtime, p90 runtime) in ms + self._runtime_baseline: Dict[str, Tuple[float, float]] = {} + self._divider_widths: List[int] = [ + len(divider) for divider in BenchmarkFormatter.TABLE_HEADERS[1] + ] + + def set_baseline( + self, method_name: str, p50_runtime: float, p90_runtime: float + ) -> None: + self._runtime_baseline[method_name] = (p50_runtime, p90_runtime) + + def print_headers(self) -> None: + row_format = "|".join( + [" {:<" + str(w - 2) + "} " for w in self._divider_widths] + ) + # headers + logger.info(row_format.format(*self.TABLE_HEADERS[0])) + # dividers + logger.info("+".join(self.TABLE_HEADERS[1])) + + def format_width(self, s: str, col_idx: int) -> str: + return f"{s:<{self._divider_widths[col_idx]-2}}" + + def get_runtime_delta(self, duration: float, baseline: float) -> str: + if duration <= baseline: + delta_pct = (baseline - duration) / duration + direction = "faster" + else: + delta_pct = (duration - baseline) / baseline + direction = "slower" + return f"{delta_pct * 100.0:.1f}% {direction}" + + def format_runtime( + self, duration: float, baseline_duration: float, is_baseline: bool + ) -> str: + return f"{duration * 1000:<8.3g} ms ({'baseline' if is_baseline else self.get_runtime_delta(duration, baseline_duration)})" + + def print_formatted( + self, + method_name: str, + transform_type: TransformType, + is_vb: bool, + p50_runtime: float, + p90_runtime: float, + is_baseline: bool, + ) -> None: + cols = [ + method_name, + transform_type.value, + "Yes" if is_vb else "No", + self._batch_size, + self._num_features, + self._mean_pooling_factor, + self.format_runtime( + p50_runtime, self._runtime_baseline[method_name][0], is_baseline + ), + self.format_runtime( + p90_runtime, self._runtime_baseline[method_name][1], is_baseline + ), + ] + + row_format = "|".join( + [" {:<" + str(w - 2) + "} " for w in self._divider_widths] + ) + logger.info(row_format.format(*cols)) + + +def generate_kjt( + tables: List[EmbeddingBagConfig], + batch_size: int, + mean_pooling_factor: int, + device: 
torch.device, +) -> KeyedJaggedTensor: + global_input = ModelInput.generate( + batch_size=batch_size, + world_size=1, # 1 for cpu + num_float_features=0, + tables=tables, + weighted_tables=[], + # mean pooling factor per feature + tables_pooling=[mean_pooling_factor] * len(tables), + # returns KJTs with values all set to 0 + # we don't care about KJT values for benchmark, and this saves time + randomize_indices=True, + device=device, + )[0] + assert isinstance(global_input.idlist_features, KeyedJaggedTensor) + return global_input.idlist_features + + +def build_kjts( + tables: List[EmbeddingBagConfig], + batch_size: int, + mean_pooling_factor: int, + device: torch.device, +) -> KeyedJaggedTensor: + start = time.perf_counter() + logger.info("Starting to build KJTs") + + kjt = generate_kjt( + tables, + batch_size, + mean_pooling_factor, + device, + ) + + end = time.perf_counter() + time_taken_s = end - start + logger.info(f"Took {time_taken_s * 1000:.1f}ms to build KJT\n") + return kjt + + +def benchmark_kjt( + test_name: str, + # pyre-ignore[2] + test_module: Union[torch.nn.Module, Callable[..., Any]], + kjt: KeyedJaggedTensor, + num_repeat: int, + num_warmup: int, + bench_formatter: BenchmarkFormatter, + fn_kwargs: Dict[str, Any], + transform_type: TransformType, + is_vb: bool = False, + is_baseline: bool = False, + print_formatted: bool = True, +) -> BenchmarkResult: + for _ in range(num_warmup): + # Reset cached states + kjt.unsync() + kjt._jt_dict = None + test_module(**fn_kwargs) + + times = [] + for _ in range(num_repeat): + # Reset cached states + kjt.unsync() + kjt._jt_dict = None + + time_elapsed = timeit.timeit(lambda: test_module(**fn_kwargs), number=1) + # remove length_per_key and offset_per_key cache for fairer comparison + times.append(time_elapsed) + + result = BenchmarkResult( + short_name=f"{test_name}-{transform_type.name}", + elapsed_time=torch.tensor(times), + mem_stats=[MemoryStats(0, 0, 0, 0)], + ) + + p50_runtime = result.runtime_percentile(50, interpolation="linear").item() + p90_runtime = result.runtime_percentile(90, interpolation="linear").item() + + if is_baseline: + bench_formatter.set_baseline(test_name, p50_runtime, p90_runtime) + + if print_formatted: + bench_formatter.print_formatted( + method_name=test_name, + transform_type=transform_type, + is_vb=is_vb, + p50_runtime=p50_runtime, + p90_runtime=p90_runtime, + is_baseline=is_baseline, + ) + + return result + + +def get_k_splits(n: int, k: int) -> List[int]: + split_size, _ = divmod(n, k) + splits = [split_size] * (k - 1) + [n - split_size * (k - 1)] + return splits + + +def gen_dist_split_input( + tables: List[EmbeddingBagConfig], + batch_size: int, + num_workers: int, + num_features: int, + mean_pooling_factor: int, + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor, List[int], Optional[torch.Tensor]]: + batch_size_per_rank = get_k_splits(n=batch_size, k=num_workers) + kjts = [ + generate_kjt(tables, batch_size_rank, mean_pooling_factor, device) + for batch_size_rank in batch_size_per_rank + ] + kjt_lengths = torch.cat([kjt.lengths() for kjt in kjts]) + kjt_values = torch.cat([kjt.values() for kjt in kjts]) + recat = _get_recat( + local_split=num_features, + num_splits=num_workers, + device=device, + batch_size_per_rank=batch_size_per_rank, + ) + + return (kjt_lengths, kjt_values, batch_size_per_rank, recat) + + +class KJTPermute(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor, indices: List[int]) -> KeyedJaggedTensor: + return kjt.permute(indices) + + +class 
KJTToDict(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor) -> Dict[str, JaggedTensor]: + return kjt.to_dict() + + +class KJTSplit(torch.nn.Module): + def forward( + self, kjt: KeyedJaggedTensor, segments: List[int] + ) -> List[KeyedJaggedTensor]: + return kjt.split(segments) + + +class KJTConcat(torch.nn.Module): + def forward( + self, + inputs: List[KeyedJaggedTensor], + ) -> KeyedJaggedTensor: + return KeyedJaggedTensor.concat(inputs) + + +class KJTGetItem(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor, key: str) -> JaggedTensor: + return kjt[key] + + +class KJTDistSplits(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor, key_splits: List[int]) -> List[List[int]]: + return kjt.dist_splits(key_splits) + + +class KJTDistInit(torch.nn.Module): + def forward( + self, + keys: List[str], + tensors: List[torch.Tensor], + variable_stride_per_key: bool, + num_workers: int, + recat: Optional[torch.Tensor], + stride_per_rank: Optional[List[int]], + ) -> KeyedJaggedTensor: + return KeyedJaggedTensor.dist_init( + keys, tensors, variable_stride_per_key, num_workers, recat, stride_per_rank + ) + + +# pyre-ignore +def dynamo_compile( + method_name: str, + kjt_module: torch.nn.Module, + backend: str, + fullgraph: bool, + fn_kwargs: Dict[str, Any], +) -> Callable[..., Any]: + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + compiled_mod = torch.compile(kjt_module, backend=backend, fullgraph=fullgraph) + return compiled_mod + + +def bench( + num_repeat: int, + num_warmup: int, + num_features: int, + batch_size: int, + mean_pooling_factor: int, + num_workers: int, + test_pt2: bool, + kjt_funcs: str, + run_modes: str, +) -> List[BenchmarkResult]: + # TODO: support CUDA benchmark + device: torch.device = torch.device("cpu") + + tables: List[EmbeddingBagConfig] = [ + EmbeddingBagConfig( + num_embeddings=20, # determines indices range + embedding_dim=10, # doesn't matter for benchmark + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(num_features) + ] + + kjt_funcs_to_benchmark: List[str] = kjt_funcs.split(",") + run_modes_to_benchmark: List[str] = run_modes.split(",") + + kjt = build_kjts( + tables, + batch_size, + mean_pooling_factor, + device, + ) + + splits = get_k_splits(n=num_features, k=8) + + permute_indices = random.sample(range(num_features), k=num_features) + + key = f"feature_{random.randint(0, num_features - 1)}" + + kjt_lengths, kjt_values, strides_per_rank, recat = gen_dist_split_input( + tables, batch_size, num_workers, num_features, mean_pooling_factor, device + ) + + benchmarked_methods: List[Tuple[str, Dict[str, Any], torch.nn.Module]] = [ + ("permute", {"kjt": kjt, "indices": permute_indices}, KJTPermute()), + ("to_dict", {"kjt": kjt}, KJTToDict()), + ("split", {"kjt": kjt, "segments": splits}, KJTSplit()), + ("concat", {"inputs": [kjt, kjt]}, KJTConcat()), + ("__getitem__", {"kjt": kjt, "key": key}, KJTGetItem()), + ("dist_splits", {"kjt": kjt, "key_splits": splits}, KJTDistSplits()), + ( + "dist_init", + { + "keys": kjt.keys(), + "tensors": [ + # lengths from each rank, should add up to num_features x batch_size in total + kjt_lengths, + # values from each rank + kjt_values, + ], + "variable_stride_per_key": False, + "num_workers": num_workers, + "recat": recat, + "stride_per_rank": strides_per_rank, + }, + KJTDistInit(), + ), + ] + + bench_formatter = BenchmarkFormatter( + batch_size, + num_features, + mean_pooling_factor, + ) + bench_formatter.print_headers() 
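+
+    # Each selected KJT method is benchmarked under the requested run modes:
+    # DEFAULT runs the wrapper module eagerly and is used as the baseline,
+    # JIT_SCRIPT torch.jit.scripts the same module, VBE feeds it a variable-batch
+    # KJT built from stride_per_key_per_rank, and AOT_EAGER torch.compiles the
+    # module with the "aot_eager" backend on the VBE input.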
+ + all_results: List[BenchmarkResult] = [] + + filtered_benchmarked_methods = [ + row for row in benchmarked_methods if row[0] in kjt_funcs_to_benchmark + ] + + for method_name, fn_kwargs, kjt_module in filtered_benchmarked_methods: + if "DEFAULT" in run_modes_to_benchmark: + # Test Eager + result = benchmark_kjt( + test_name=method_name, + kjt=kjt, + test_module=kjt_module, + num_repeat=num_repeat, + num_warmup=num_warmup, + fn_kwargs=fn_kwargs, + transform_type=TransformType.DEFAULT, + bench_formatter=bench_formatter, + is_baseline=True, + ) + all_results.append(result) + + if "JIT_SCRIPT" in run_modes_to_benchmark: + # Test JIT script + result = benchmark_kjt( + test_name=method_name, + kjt=kjt, + test_module=torch.jit.script(kjt_module), + num_repeat=num_repeat, + num_warmup=num_warmup, + fn_kwargs=fn_kwargs, + transform_type=TransformType.JIT_SCRIPT, + bench_formatter=bench_formatter, + ) + + all_results.append(result) + + if "VBE" in run_modes_to_benchmark: + # Test Eager VBE + vbe_kjt = KeyedJaggedTensor( + keys=kjt.keys(), + values=kjt._values, + lengths=kjt._lengths, + stride_per_key_per_rank=kjt._stride_per_key_per_rank, + ) + vbe_fn_kwargs = fn_kwargs.copy() + if "kjt" in fn_kwargs: + vbe_fn_kwargs["kjt"] = vbe_kjt + + result = benchmark_kjt( + test_name=method_name, + kjt=vbe_kjt, + test_module=kjt_module, + num_repeat=num_repeat, + num_warmup=num_warmup, + fn_kwargs=vbe_fn_kwargs, + transform_type=TransformType.VBE, + is_vb=True, + bench_formatter=bench_formatter, + ) + all_results.append(result) + + # PT2 (Eager Inductor) + if test_pt2 or "AOT_EAGER" in run_modes_to_benchmark: + vbe_kjt = KeyedJaggedTensor( + keys=kjt.keys(), + values=kjt._values, + lengths=kjt._lengths, + stride_per_key_per_rank=kjt._stride_per_key_per_rank, + ) + vbe_fn_kwargs = fn_kwargs.copy() + if "kjt" in fn_kwargs: + vbe_fn_kwargs["kjt"] = vbe_kjt + dynamo_compiled_mod = dynamo_compile( + method_name, + kjt_module, + backend="aot_eager", + fullgraph=True, + fn_kwargs=vbe_fn_kwargs, + ) + + result = benchmark_kjt( + test_name=method_name, + kjt=vbe_kjt, + test_module=dynamo_compiled_mod, + num_repeat=num_repeat, + num_warmup=num_warmup, + # simulate VBE, otherwise torch.compile currently fails + fn_kwargs=vbe_fn_kwargs, + transform_type=TransformType.PT2_AOT_EAGER, + is_vb=True, + bench_formatter=bench_formatter, + ) + + all_results.append(result) + # Leave a gap between methods + print("") + return all_results diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index 7d5c72876..1f15cbeaf 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -5,24 +5,125 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import unittest -from typing import List, Tuple import torch +import torch.utils._pytree as pytree from torch.testing import FileCheck from torchrec.fx import symbolic_trace from torchrec.sparse.jagged_tensor import ( - ComputeKJTToJTDict, + ComputeJTDictToKJT, JaggedTensor, + jt_is_equal, KeyedJaggedTensor, - KeyedTensor, ) torch.fx.wrap("len") class TestJaggedTensor(unittest.TestCase): + def test_equality(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + lengths = torch.IntTensor([1, 0, 2, 3]) + weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]) + """ + JaggedTensor representation from above values + # [[1.0], [], [2.0, 3.0], [4.0, 5.0, 6.0]] + """ + # JT equality, from different construction methods + jt = JaggedTensor(values=values, lengths=lengths) + dense_values = torch.Tensor( + [[1.0, 11.0, 12.0], [9.0, 23.0, 11.0], [2.0, 3.0, 55.0], [4.0, 5.0, 6.0]] + ) + jt_1 = JaggedTensor.from_dense_lengths( + values=dense_values, lengths=torch.IntTensor([1, 0, 2, 3]) + ) + self.assertTrue(jt_is_equal(jt, jt_1)) + + # Different values + jt = JaggedTensor( + values=torch.Tensor([2.0, 10.0, 11.0, 42.0, 3.0, 99.0]), lengths=lengths + ) + self.assertFalse(jt_is_equal(jt, jt_1)) + + # Different lengths + jt = JaggedTensor(values=values, lengths=torch.IntTensor([1, 1, 0, 4])) + self.assertFalse(jt_is_equal(jt, jt_1)) + + # Including weights + """ + # values: [[1.0], [], [2.0, 3.0], [4.0, 5.0, 6.0]] + # weights: [[0.1], [], [0.2, 0.3], [0.4, 0.5 ,0.6]] + """ + jt = JaggedTensor(values=values, lengths=lengths, weights=weights) + + dense_weights = torch.Tensor( + [[0.1, 1.1, 1.2], [0.9, 2.3, 1.1], [0.2, 0.3, 5.5], [0.4, 0.5, 0.6]] + ) + jt_1 = JaggedTensor.from_dense_lengths( + values=dense_values, + lengths=torch.IntTensor([1, 0, 2, 3]), + weights=dense_weights, + ) + + self.assertTrue(jt_is_equal(jt, jt_1)) + + # Different weights + jt = JaggedTensor( + values=values, + lengths=lengths, + weights=torch.Tensor([1.4, 0.2, 3.2, 0.4, 42.0, 0.6]), + ) + self.assertFalse(jt_is_equal(jt, jt_1)) + + # from dense, equal lengths + values_for_dense = [ + torch.Tensor([1.0]), + torch.Tensor(), + torch.Tensor([2.0, 3.0]), + torch.Tensor([4.0, 5.0, 6.0]), + ] + weights_for_dense = [ + torch.Tensor([0.1]), + torch.Tensor(), + torch.Tensor([0.2, 0.3]), + torch.Tensor([0.4, 0.5, 0.6]), + ] + + jt = JaggedTensor.from_dense( + values=values_for_dense, + weights=weights_for_dense, + ) + + self.assertTrue(jt_is_equal(jt, jt_1)) + + # from dense, unequal lengths + values_for_dense = [ + torch.Tensor([1.0]), + torch.Tensor([3.0, 10.0, 42.0]), + torch.Tensor([2.0, 3.0]), + torch.Tensor([4.0, 5.0, 6.0]), + ] + weights_for_dense = [ + torch.Tensor([0.1]), + torch.Tensor([0.3, 1.1, 4.2]), + torch.Tensor([0.2, 0.3]), + torch.Tensor([0.4, 0.5, 0.6]), + ] + + jt = JaggedTensor.from_dense( + values=values_for_dense, + weights=weights_for_dense, + ) + self.assertFalse(jt_is_equal(jt, jt_1)) + + # wrong type + jt = "not a jagged tensor" + self.assertFalse(jt_is_equal(jt, jt_1)) + def test_str(self) -> None: values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) j_1d = JaggedTensor( @@ -154,6 +255,26 @@ def test_from_dense(self) -> None: torch.equal(j1.weights(), torch.Tensor([1.0, 7.0, 8.0, 10.0, 11.0, 12.0])) ) + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "CUDA is not available", + ) + def test_from_dense_device(self) -> None: + device = torch.device("cuda", index=0) + values = [ + torch.tensor([1.0], device=device), + torch.tensor([7.0, 8.0], device=device), 
+ torch.tensor([10.0, 11.0, 12.0], device=device), + ] + + j0 = JaggedTensor.from_dense( + values=values, + ) + self.assertEqual(j0.values().device, device) + self.assertEqual(j0.lengths().device, device) + self.assertEqual(j0.offsets().device, device) + def test_to_dense(self) -> None: values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) @@ -173,9 +294,37 @@ def test_to_dense(self) -> None: for t0, expected_t0 in zip(torch_list, expected_list): self.assertTrue(torch.equal(t0, expected_t0)) + def test_to_dense_weights(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]) + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + jt = JaggedTensor( + values=values, + weights=weights, + offsets=offsets, + ) + weights_list = jt.to_dense_weights() + expected_weights_list = [ + torch.tensor([0.1, 0.2]), + torch.tensor([]), + torch.tensor([0.3]), + torch.tensor([0.4]), + torch.tensor([0.5]), + torch.tensor([0.6, 0.7, 0.8]), + ] + for t0, expected_t0 in zip(weights_list, expected_weights_list): + self.assertTrue(torch.equal(t0, expected_t0)) + + jt = JaggedTensor( + values=values, + offsets=offsets, + ) + weights_list = jt.to_dense_weights() + self.assertIsNone(weights_list) + def test_to_padded_dense(self) -> None: values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).type( - torch.float64 + torch.float32 ) offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) jt = JaggedTensor( @@ -183,7 +332,7 @@ def test_to_padded_dense(self) -> None: offsets=offsets, ) t0 = jt.to_padded_dense() - self.assertEqual(t0.dtype, torch.float64) + self.assertEqual(t0.dtype, torch.float32) t0_value = [ [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], @@ -192,11 +341,11 @@ def test_to_padded_dense(self) -> None: [5.0, 0.0, 0.0], [6.0, 7.0, 8.0], ] - expected_t0 = torch.tensor(t0_value).type(torch.float64) + expected_t0 = torch.tensor(t0_value).type(torch.float32) self.assertTrue(torch.equal(t0, expected_t0)) t1 = jt.to_padded_dense(desired_length=2, padding_value=10.0) - self.assertEqual(t1.dtype, torch.float64) + self.assertEqual(t1.dtype, torch.float32) t1_value = [ [1.0, 2.0], [10.0, 10.0], @@ -205,7 +354,7 @@ def test_to_padded_dense(self) -> None: [5.0, 10.0], [6.0, 7.0], ] - expected_t1 = torch.tensor(t1_value).type(torch.float64) + expected_t1 = torch.tensor(t1_value).type(torch.float32) self.assertTrue(torch.equal(t1, expected_t1)) values = torch.Tensor( @@ -247,6 +396,96 @@ def test_to_padded_dense(self) -> None: expected_t2 = torch.tensor(t2_value).type(torch.int64) self.assertTrue(torch.equal(t2, expected_t2)) + def test_to_padded_dense_weights(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).type( + torch.float64 + ) + weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]) + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + jt = JaggedTensor( + values=values, + weights=weights, + offsets=offsets, + ) + t0_weights = jt.to_padded_dense_weights() + expected_t0_weights = [ + [0.1, 0.2, 0.0], + [0.0, 0.0, 0.0], + [0.3, 0.0, 0.0], + [0.4, 0.0, 0.0], + [0.5, 0.0, 0.0], + [0.6, 0.7, 0.8], + ] + + expected_t0_weights = torch.tensor(expected_t0_weights) + self.assertTrue(torch.equal(t0_weights, expected_t0_weights)) + + t1_weights = jt.to_padded_dense_weights(desired_length=2, padding_value=1.0) + expected_t1_weights = [ + [0.1, 0.2], + [1.0, 1.0], + [0.3, 1.0], + [0.4, 1.0], + [0.5, 1.0], + [0.6, 0.7], + ] + expected_t1_weights = 
torch.tensor(expected_t1_weights) + self.assertTrue(torch.equal(t1_weights, expected_t1_weights)) + + values = torch.Tensor( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + [10.0, 11.0, 12.0], + ] + ).type(torch.int64) + weights = torch.Tensor( + [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9], + [1.0, 1.1, 1.2], + ] + ) + jt = JaggedTensor( + values=values, + weights=weights, + lengths=torch.IntTensor([1, 0, 2, 0]), + ) + t2_weights = jt.to_padded_dense_weights(desired_length=3) + expected_t2_weights = [ + [ + [0.1, 0.2, 0.3], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + [ + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9], + [0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + ] + expected_t2_weights = torch.tensor(expected_t2_weights) + self.assertTrue(torch.equal(t2_weights, expected_t2_weights)) + + jt = JaggedTensor( + values=values, + lengths=torch.IntTensor([1, 0, 2, 0]), + ) + t3_weights = jt.to_padded_dense_weights(desired_length=3) + self.assertIsNone(t3_weights) + def test_key_lookup(self) -> None: values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) keys = ["index_0", "index_1"] @@ -312,11 +551,31 @@ def test_length_vs_offset(self) -> None: # TODO: T88149179 self.assertTrue(torch.equal(j_offset.offsets(), j_lens.offsets().int())) + stride_per_key_per_rank = [[3], [5]] + j_offset = KeyedJaggedTensor.from_offsets_sync( + values=values, + keys=keys, + offsets=offsets, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + + j_lens = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + self.assertTrue(torch.equal(j_offset.lengths(), j_lens.lengths())) + self.assertTrue(torch.equal(j_offset.offsets(), j_lens.offsets().int())) + def test_empty(self) -> None: - jt = JaggedTensor.empty() + jt = JaggedTensor.empty(values_dtype=torch.int64) - self.assertTrue(torch.equal(jt.values(), torch.tensor([]))) - self.assertTrue(torch.equal(jt.offsets(), torch.tensor([]))) + self.assertTrue(torch.equal(jt.values(), torch.tensor([], dtype=torch.int64))) + self.assertTrue(torch.equal(jt.offsets(), torch.tensor([], dtype=torch.int32))) + + jt_from_script = torch.jit.script(JaggedTensor.empty)() + self.assertEqual(jt_from_script.to_dense(), []) def test_2d(self) -> None: values = torch.Tensor([[i * 0.5, i * 1.0, i * 1.5] for i in range(1, 4)]) @@ -387,11 +646,7 @@ def test_string_basic(self) -> None: self.assertEqual( str(jag_tensor), - """\ -JaggedTensor({ - [[1.0]] -}) -""", + """JaggedTensor({\n [[1.0]]\n})\n""", ) def test_string_values(self) -> None: @@ -405,11 +660,8 @@ def test_string_values(self) -> None: self.assertEqual( str(jag_tensor), - """\ -JaggedTensor({ - [[1.0, 2.0], [], [3.0], [4.0], [5.0], [6.0, 7.0, 8.0]] -}) -""", + "JaggedTensor({\n [[1.0, 2.0], [], [3.0], " + "[4.0], [5.0], [6.0, 7.0, 8.0]]\n})\n", ) def test_string_weights(self) -> None: @@ -425,12 +677,139 @@ def test_string_weights(self) -> None: self.assertEqual( str(jag_tensor), - """\ -JaggedTensor({ - "values": [[1.0, 2.0], [], [3.0], [4.0], [5.0], [6.0, 7.0, 8.0]], - "weights": [[1.0, 0.5], [], [1.5], [1.0], [0.5], [1.0, 1.0, 1.5]] -}) -""", + 'JaggedTensor({\n "values": [[1.0, 2.0], [], [3.0], ' + '[4.0], [5.0], [6.0, 7.0, 8.0]],\n "weights": ' + "[[1.0, 0.5], [], [1.5], [1.0], [0.5], [1.0, 1.0, 1.5]]\n})\n", + ) + + def test_pytree(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + 
j0 = JaggedTensor( + values=values, + lengths=torch.IntTensor([1, 0, 2, 3]), + ) + elems, spec = pytree.tree_flatten(j0) + j1 = pytree.tree_unflatten(elems, spec) + + self.assertTrue(torch.equal(j0.lengths(), j1.lengths())) + self.assertIsNone(j0.weights_or_none()) + self.assertIsNone(j1.weights_or_none()) + self.assertTrue(torch.equal(j0.values(), j1.values())) + + values = [ + torch.Tensor([1.0]), + torch.Tensor(), + torch.Tensor([7.0, 8.0]), + torch.Tensor([10.0, 11.0, 12.0]), + ] + weights = [ + torch.Tensor([1.0]), + torch.Tensor(), + torch.Tensor([7.0, 8.0]), + torch.Tensor([10.0, 11.0, 12.0]), + ] + j0 = JaggedTensor.from_dense( + values=values, + weights=weights, + ) + elems, spec = pytree.tree_flatten(j0) + j1 = pytree.tree_unflatten(elems, spec) + + self.assertTrue(torch.equal(j0.lengths(), j1.lengths())) + self.assertTrue(torch.equal(j0.weights(), j1.weights())) + self.assertTrue(torch.equal(j0.values(), j1.values())) + + def test_compute_jt_dict_to_kjt_module(self) -> None: + compute_jt_dict_to_kjt = ComputeJTDictToKJT() + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + jag_tensor_dict = jag_tensor.to_dict() + kjt = compute_jt_dict_to_kjt(jag_tensor_dict) + j0 = kjt["index_0"] + j1 = kjt["index_1"] + + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1]))) + self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5, 1.5]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0]))) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3]))) + self.assertTrue( + torch.equal(j1.weights(), torch.Tensor([1.0, 0.5, 1.0, 1.0, 1.5])) + ) + self.assertTrue( + torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0])) + ) + + def test_from_jt_dict(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + jag_tensor_dict = jag_tensor.to_dict() + kjt = KeyedJaggedTensor.from_jt_dict(jag_tensor_dict) + j0 = kjt["index_0"] + j1 = kjt["index_1"] + + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1]))) + self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5, 1.5]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0]))) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3]))) + self.assertTrue( + torch.equal(j1.weights(), torch.Tensor([1.0, 0.5, 1.0, 1.0, 1.5])) + ) + self.assertTrue( + torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0])) + ) + + def test_from_jt_dict_vb(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + stride_per_key_per_rank = [[2], [4]] + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + 
weights=weights, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + jag_tensor_dict = jag_tensor.to_dict() + kjt = KeyedJaggedTensor.from_jt_dict(jag_tensor_dict) + j0 = kjt["index_0"] + j1 = kjt["index_1"] + + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0]))) + self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0]))) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 1, 3]))) + self.assertTrue( + torch.equal(j1.weights(), torch.Tensor([1.5, 1.0, 0.5, 1.0, 1.0, 1.5])) + ) + self.assertTrue( + torch.equal(j1.values(), torch.Tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0])) ) @@ -542,949 +921,3 @@ def forward( ref_out = m(8) traced_out = gm(8) self.assertEqual(ref_out, traced_out) - - -class TestKeyedJaggedTensor(unittest.TestCase): - def test_key_lookup(self) -> None: - values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) - keys = ["index_0", "index_1"] - offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) - - jag_tensor = KeyedJaggedTensor( - values=values, - keys=keys, - offsets=offsets, - weights=weights, - ) - j0 = jag_tensor["index_0"] - j1 = jag_tensor["index_1"] - - self.assertTrue(isinstance(j0, JaggedTensor)) - self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1]))) - self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5, 1.5]))) - self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0]))) - self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3]))) - self.assertTrue( - torch.equal(j1.weights(), torch.Tensor([1.0, 0.5, 1.0, 1.0, 1.5])) - ) - self.assertTrue( - torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0])) - ) - - def test_to_dict(self) -> None: - values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) - keys = ["index_0", "index_1"] - offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) - - jag_tensor = KeyedJaggedTensor( - values=values, - keys=keys, - offsets=offsets, - weights=weights, - ) - jag_tensor_dict = jag_tensor.to_dict() - j0 = jag_tensor_dict["index_0"] - j1 = jag_tensor_dict["index_1"] - - self.assertTrue(isinstance(j0, JaggedTensor)) - self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1]))) - self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5, 1.5]))) - self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0]))) - self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3]))) - self.assertTrue( - torch.equal(j1.weights(), torch.Tensor([1.0, 0.5, 1.0, 1.0, 1.5])) - ) - self.assertTrue( - torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0])) - ) - - def test_empty(self) -> None: - keys = ["index_0"] - values = torch.tensor([]) - lengths = torch.tensor([]) - offsets = torch.tensor([]) - - kjt_0 = KeyedJaggedTensor(keys=keys, values=values, lengths=lengths) - j0 = kjt_0["index_0"] - self.assertTrue(isinstance(j0, JaggedTensor)) - self.assertTrue(torch.equal(j0.lengths(), torch.Tensor([]))) - self.assertTrue(torch.equal(j0.values(), torch.Tensor([]))) - - keys = ["index_1"] - kjt_1 = KeyedJaggedTensor(keys=keys, values=values, offsets=offsets) - j1 = kjt_1["index_1"] - - self.assertTrue(isinstance(j1, JaggedTensor)) - self.assertTrue(torch.equal(j1.lengths(), torch.Tensor([]))) - 
self.assertTrue(torch.equal(j1.values(), torch.Tensor([]))) - - combined_kjt = KeyedJaggedTensor.concat([kjt_0, kjt_1]) - j0 = combined_kjt["index_0"] - j1 = combined_kjt["index_1"] - - self.assertTrue(isinstance(j0, JaggedTensor)) - self.assertTrue(torch.equal(j0.lengths(), torch.Tensor([]))) - self.assertTrue(torch.equal(j0.values(), torch.Tensor([]))) - self.assertTrue(isinstance(j1, JaggedTensor)) - self.assertTrue(torch.equal(j1.lengths(), torch.Tensor([]))) - self.assertTrue(torch.equal(j1.values(), torch.Tensor([]))) - - kjt_2 = KeyedJaggedTensor.empty() - self.assertEqual(kjt_2.to_dict(), {}) - - def test_empty_to_dict(self) -> None: - keys = ["index_0", "index_1"] - values = torch.tensor([]) - lengths = torch.tensor([[], []]) - length_per_key = [0, 0] - - jag_tensor = KeyedJaggedTensor( - keys=keys, values=values, lengths=lengths, length_per_key=length_per_key - ) - jag_tensor_dict = jag_tensor.to_dict() - j0 = jag_tensor_dict["index_0"] - j1 = jag_tensor_dict["index_1"] - - self.assertTrue(isinstance(j0, JaggedTensor)) - self.assertTrue(torch.equal(j0.lengths(), torch.Tensor([]))) - self.assertTrue(torch.equal(j0.offsets(), torch.Tensor([]))) - self.assertTrue(torch.equal(j0.values(), torch.Tensor([]))) - self.assertTrue(isinstance(j1, JaggedTensor)) - self.assertTrue(torch.equal(j1.lengths(), torch.Tensor([]))) - self.assertTrue(torch.equal(j1.offsets(), torch.Tensor([]))) - self.assertTrue(torch.equal(j1.values(), torch.Tensor([]))) - - jag_tensor = KeyedJaggedTensor.from_lengths_sync( - keys=keys, values=values, lengths=lengths - ) - jag_tensor_dict = jag_tensor.to_dict() - j0 = jag_tensor_dict["index_0"] - j1 = jag_tensor_dict["index_1"] - - self.assertTrue(isinstance(j0, JaggedTensor)) - self.assertTrue(torch.equal(j0.lengths(), torch.Tensor([]))) - self.assertTrue(torch.equal(j0.offsets(), torch.Tensor([]))) - self.assertTrue(torch.equal(j0.values(), torch.Tensor([]))) - self.assertTrue(isinstance(j1, JaggedTensor)) - self.assertTrue(torch.equal(j1.lengths(), torch.Tensor([]))) - self.assertTrue(torch.equal(j1.offsets(), torch.Tensor([]))) - self.assertTrue(torch.equal(j1.values(), torch.Tensor([]))) - - def test_split(self) -> None: - values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) - keys = ["index_0", "index_1"] - offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) - - jag_tensor = KeyedJaggedTensor( - values=values, - keys=keys, - offsets=offsets, - weights=weights, - ) - j0, j1 = jag_tensor.split([1, 1]) - - self.assertTrue(isinstance(j0, KeyedJaggedTensor)) - self.assertEqual(j0.keys(), ["index_0"]) - self.assertEqual(j1.keys(), ["index_1"]) - self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1]))) - self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5, 1.5]))) - self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0]))) - self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3]))) - self.assertTrue( - torch.equal(j1.weights(), torch.Tensor([1.0, 0.5, 1.0, 1.0, 1.5])) - ) - self.assertTrue( - torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0])) - ) - - def test_zero_split(self) -> None: - values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) - keys = ["index_0", "index_1"] - offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) - - jag_tensor = KeyedJaggedTensor( - values=values, - keys=keys, - offsets=offsets, - weights=weights, - ) - j0, j1 = 
jag_tensor.split([0, 2]) - - self.assertTrue(isinstance(j0, KeyedJaggedTensor)) - self.assertEqual(j0.keys(), []) - self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([]))) - self.assertTrue(torch.equal(j0.weights(), torch.Tensor([]))) - self.assertTrue(torch.equal(j0.values(), torch.Tensor([]))) - self.assertEqual(j0.stride(), 3) - - self.assertEqual(j1.keys(), ["index_0", "index_1"]) - self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([2, 0, 1, 1, 1, 3]))) - self.assertTrue(torch.equal(j1.weights(), weights)) - self.assertTrue(torch.equal(j1.values(), values)) - self.assertEqual(j0.stride(), 3) - - def test_permute_w_weights(self) -> None: - values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) - lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0]) - keys = ["index_0", "index_1", "index_2"] - - jag_tensor = KeyedJaggedTensor.from_lengths_sync( - values=values, - keys=keys, - lengths=lengths, - weights=weights, - ) - - indices = [1, 0, 2] - permuted_jag_tensor = jag_tensor.permute(indices) - self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) - self.assertEqual( - permuted_jag_tensor.offset_per_key(), - [0, 3, 5, 8], - ) - self.assertTrue( - torch.equal( - permuted_jag_tensor.values(), - torch.Tensor([3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0]), - ) - ) - self.assertTrue( - torch.equal( - permuted_jag_tensor.lengths(), - torch.IntTensor([1, 1, 1, 0, 2, 0, 0, 3, 0]), - ) - ) - self.assertTrue( - torch.equal( - permuted_jag_tensor.weights(), - torch.Tensor([1.5, 1.0, 0.5, 1.0, 0.5, 1.0, 1.0, 1.5]), - ), - ) - - def test_permute(self) -> None: - values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0]) - keys = ["index_0", "index_1", "index_2"] - - jag_tensor = KeyedJaggedTensor.from_lengths_sync( - values=values, - keys=keys, - lengths=lengths, - ) - - indices = [1, 0, 2] - permuted_jag_tensor = jag_tensor.permute(indices) - - self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) - self.assertEqual( - permuted_jag_tensor.offset_per_key(), - [0, 3, 5, 8], - ) - self.assertTrue( - torch.equal( - permuted_jag_tensor.values(), - torch.Tensor([3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0]), - ) - ) - self.assertTrue( - torch.equal( - permuted_jag_tensor.lengths(), - torch.IntTensor([1, 1, 1, 0, 2, 0, 0, 3, 0]), - ) - ) - self.assertEqual(permuted_jag_tensor.weights_or_none(), None) - - def test_permute_duplicates(self) -> None: - values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0]) - keys = ["index_0", "index_1", "index_2"] - - jag_tensor = KeyedJaggedTensor.from_lengths_sync( - values=values, - keys=keys, - lengths=lengths, - ) - - indices = [1, 0, 2, 1, 1] - permuted_jag_tensor = jag_tensor.permute(indices) - - self.assertEqual( - permuted_jag_tensor.keys(), - ["index_1", "index_0", "index_2", "index_1", "index_1"], - ) - self.assertEqual( - permuted_jag_tensor.offset_per_key(), - [0, 3, 5, 8, 11, 14], - ) - self.assertTrue( - torch.equal( - permuted_jag_tensor.values(), - torch.Tensor( - [ - 3.0, - 4.0, - 5.0, - 1.0, - 2.0, - 6.0, - 7.0, - 8.0, - 3.0, - 4.0, - 5.0, - 3.0, - 4.0, - 5.0, - ] - ), - ) - ) - self.assertTrue( - torch.equal( - permuted_jag_tensor.lengths(), - torch.IntTensor([1, 1, 1, 0, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 1]), - ) - ) - self.assertEqual(permuted_jag_tensor.weights_or_none(), None) - - def 
test_concat(self) -> None: - values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]) - keys = ["index_0", "index_1", "index_2"] - lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0, 0, 1, 0]) - - kjt_expected = KeyedJaggedTensor.from_lengths_sync( - values=values, - keys=keys, - lengths=lengths, - ) - kjt_actual = KeyedJaggedTensor.concat( - [ - KeyedJaggedTensor.from_lengths_sync( - values=values[:4], - keys=keys[:1], - lengths=lengths[:4], - ), - KeyedJaggedTensor.from_lengths_sync( - values=values[4:], - keys=keys[1:], - lengths=lengths[4:], - ), - ], - ) - self.assertTrue(torch.equal(kjt_expected.lengths(), kjt_actual.lengths())) - self.assertTrue(torch.equal(kjt_expected.offsets(), kjt_actual.offsets())) - self.assertTrue(torch.equal(kjt_expected.values(), kjt_actual.values())) - # pyre-ignore[6] - self.assertListEqual(kjt_expected._length_per_key, kjt_actual._length_per_key) - - def test_length_vs_offset(self) -> None: - values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) - keys = ["index_0", "index_1"] - offsets = torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]) - lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3]) - - j_offset = KeyedJaggedTensor.from_offsets_sync( - values=values, - keys=keys, - offsets=offsets, - weights=weights, - ) - - j_lens = KeyedJaggedTensor.from_lengths_sync( - values=values, - keys=keys, - lengths=lengths, - weights=weights, - ) - - self.assertTrue(torch.equal(j_offset.lengths(), j_lens.lengths())) - # TO DO: T88149179 - self.assertTrue(torch.equal(j_offset.offsets(), j_lens.offsets().int())) - - def test_2d(self) -> None: - values = torch.Tensor([[i * 0.5, i * 1.0, i * 1.5] for i in range(1, 9)]) - weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) - keys = ["index_0", "index_1"] - offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) - - j = KeyedJaggedTensor.from_offsets_sync( - values=values, - weights=weights, - keys=keys, - offsets=offsets, - ) - j_0 = j["index_0"] - - self.assertTrue(torch.equal(j_0.lengths(), torch.IntTensor([2, 0, 1]))) - self.assertTrue( - torch.equal( - j_0.values(), - torch.Tensor( - [ - [0.5, 1.0, 1.5], - [1.0, 2.0, 3.0], - [1.5, 3.0, 4.5], - ], - ), - ) - ) - - def test_float_lengths_offsets_throws(self) -> None: - values = torch.rand((7, 3)) - keys = ["f1", "f2"] - # torch.Tensor([3, 4]) also fails - # pyre-fixme[6]: Expected `Optional[typing.Type[torch._dtype]]` for 2nd - # param but got `Type[float]`. - lengths = torch.tensor([3, 4], dtype=float) - # pyre-fixme[6]: Expected `Optional[typing.Type[torch._dtype]]` for 2nd - # param but got `Type[float]`. 
- offsets = torch.tensor([0, 3, 7], dtype=float) - - with self.assertRaises(AssertionError): - KeyedJaggedTensor.from_lengths_sync( - keys=keys, values=values, lengths=lengths - ) - with self.assertRaises(AssertionError): - KeyedJaggedTensor.from_offsets_sync( - keys=keys, values=values, offsets=offsets - ) - - def test_scriptable(self) -> None: - class MyModule(torch.nn.Module): - def forward(self, input: KeyedJaggedTensor) -> torch.Tensor: - values = input["any"].values() - return values - - m = MyModule() - torch.jit.script(m) - - def test_to(self) -> None: - j = KeyedJaggedTensor.from_offsets_sync( - offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), - values=torch.arange(8), - weights=torch.arange(8 * 10), - keys=["index_0", "index_1"], - ) - j2 = j.to(device=torch.device("cpu")) - self.assertTrue(torch.equal(j.offsets(), j2.offsets())) - self.assertTrue(torch.equal(j.lengths(), j2.lengths())) - self.assertTrue(torch.equal(j.values(), j2.values())) - self.assertTrue(torch.equal(j.weights(), j2.weights())) - - def test_string_none(self) -> None: - jag_tensor = KeyedJaggedTensor( - # pyre-fixme[6]: For 1st param expected `List[str]` but got `Tensor`. - torch.Tensor(), - # pyre-fixme[6]: For 2nd param expected `Tensor` but got - # `List[Variable[_T]]`. - [], - ) - - self.assertEqual( - str(jag_tensor), - """\ -KeyedJaggedTensor() -""", - ) - - def test_string_basic(self) -> None: - values = torch.Tensor([1.0]) - keys = ["key"] - offsets = torch.IntTensor([0, 1]) - - jag_tensor = KeyedJaggedTensor( - values=values, - keys=keys, - offsets=offsets, - ) - - self.assertEqual( - str(jag_tensor), - """\ -KeyedJaggedTensor({ - "key": [[1.0]] -}) -""", - ) - - def test_string_values(self) -> None: - values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - keys = ["index_0", "index_1"] - offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) - - jag_tensor = KeyedJaggedTensor( - values=values, - keys=keys, - offsets=offsets, - ) - - self.assertEqual( - str(jag_tensor), - """\ -KeyedJaggedTensor({ - "index_0": [[1.0, 2.0], [], [3.0]], - "index_1": [[4.0], [5.0], [6.0, 7.0, 8.0]] -}) -""", - ) - - def test_string_weights(self) -> None: - values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) - keys = ["index_0", "index_1"] - offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) - - jag_tensor = KeyedJaggedTensor( - values=values, - keys=keys, - offsets=offsets, - weights=weights, - ) - - self.assertEqual( - str(jag_tensor), - """\ -KeyedJaggedTensor({ - "index_0": { - "values": [[1.0, 2.0], [], [3.0]], - "weights": [[1.0, 0.5], [], [1.5]] - }, - "index_1": { - "values": [[4.0], [5.0], [6.0, 7.0, 8.0]], - "weights": [[1.0], [0.5], [1.0, 1.0, 1.5]] - } -}) -""", - ) - - # pyre-ignore[56] - @unittest.skipIf( - torch.cuda.device_count() <= 0, - "CUDA is not available", - ) - def test_record_stream(self) -> None: - j = KeyedJaggedTensor.from_offsets_sync( - offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), - values=torch.arange(8), - weights=torch.arange(8 * 10), - keys=["index_0", "index_1"], - ).to(torch.device("cuda")) - j.record_stream(torch.cuda.current_stream()) - - -class TestKeyedJaggedTensorScripting(unittest.TestCase): - def test_scriptable_forward(self) -> None: - class MyModule(torch.nn.Module): - def forward(self, input: KeyedJaggedTensor) -> torch.Tensor: - values = input["any"].values() - return values - - m = MyModule() - torch.jit.script(m) - - def test_scriptable_init(self) -> None: - def create_kjt() -> KeyedJaggedTensor: - 
return KeyedJaggedTensor.from_offsets_sync( - values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), - weights=torch.tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), - keys=["index_0", "index_1"], - offsets=torch.tensor([0, 0, 2, 2, 3, 4, 5, 5, 8], dtype=torch.int32), - ) - - # assert that we can script KJT creation - torch.jit.script(create_kjt) - - -class TestKeyedJaggedTensorTracingScripting(unittest.TestCase): - def test_jit_tracable(self) -> None: - # This module will simply go through the constructor of the - # KeyedJaggedTensor to construct it with multiple different batch sizes - class MyModule(torch.nn.Module): - def forward( - self, offsets: torch.Tensor, values: torch.Tensor, weights: torch.Tensor - ) -> torch.Tensor: - j = KeyedJaggedTensor.from_offsets_sync( - offsets=offsets, - values=values, - weights=weights, - keys=["index_0", "index_1"], - ) - return j["index_0"].offsets() - - sample_2 = ( - torch.tensor([0, 2, 2]), - torch.arange(2), - torch.arange(2 * 10), - ) - sample_6 = ( - torch.tensor([0, 2, 2, 3, 4, 6, 8]), - torch.arange(8), - torch.arange(8 * 10), - ) - m = MyModule() - model_eager_traced: torch.jit.ScriptModule = torch.jit.trace( - m, sample_2, strict=False - ) - self.assertTrue( - torch.equal(model_eager_traced(*sample_2), torch.tensor([0, 2])) - ) - self.assertTrue( - torch.equal(model_eager_traced(*sample_6), torch.tensor([0, 2, 2, 3])) - ) - - def test_create_and_access_keyed_jagged_tensor(self) -> None: - class ModuleCreateAndAccessKeyedJaggedTensor(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, input: int) -> int: - features = KeyedJaggedTensor.from_offsets_sync( - values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), - weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), - keys=["index_0", "index_1"], - offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]), - ) - return ( - len(features.keys()) - + features.values().numel() - + features.weights().numel() - + features.lengths().numel() - + features.offsets().numel() - ) - - # Case 4: KeyedJaggedTensor is only used within the root module and not as part of - # the root module's input/output interface. - m = ModuleCreateAndAccessKeyedJaggedTensor() - gm = symbolic_trace(m) - FileCheck().check("return 35").check_not("KeyedJaggedTensor").run(gm.code) - ref_out = m(8) - traced_out = gm(8) - self.assertEqual(ref_out, traced_out) - torch.jit.script(gm) - - def test_use_keyed_jagged_tensor_as_input_and_output(self) -> None: - class ModuleUseKeyedJaggedTensorAsInputAndOutput(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward( - self, input: KeyedJaggedTensor - ) -> Tuple[KeyedJaggedTensor, int]: - output = KeyedJaggedTensor( - input.keys(), - input.values(), - input.weights(), - lengths=input.lengths(), - offsets=input.offsets(), - ) - return output, output._stride - - # Case 3: KeyedJaggedTensor is used as both an input and an output of the root module. 
- m = ModuleUseKeyedJaggedTensorAsInputAndOutput() - gm = symbolic_trace(m) - FileCheck().check("KeyedJaggedTensor").check("keys()").check("values()").check( - "._stride" - ).run(gm.code) - input = KeyedJaggedTensor.from_offsets_sync( - values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), - weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), - keys=["index_0", "index_1"], - offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]), - ) - ref_out = m(input) - traced_out = gm(input) - self.assertEqual(ref_out[1], traced_out[1]) - torch.jit.script(gm) - - def test_use_keyed_jagged_tensor_as_input(self) -> None: - class ModuleUseKeyedJaggedTensorAsInput(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, input: KeyedJaggedTensor) -> int: - return ( - len(input.keys()) - + input.values().numel() - + input.weights().numel() - + input.lengths().numel() - + input.offsets().numel() - ) - - # Case 2: KeyedJaggedTensor is only used as an input of the root module. - m = ModuleUseKeyedJaggedTensorAsInput() - gm = symbolic_trace(m) - FileCheck().check("KeyedJaggedTensor").check("keys()").check("len").check( - "values()" - ).check("numel()").run(gm.code) - - input = KeyedJaggedTensor.from_offsets_sync( - values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), - weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), - keys=["index_0", "index_1"], - offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]), - ) - ref_out = m(input) - traced_out = gm(input) - self.assertEqual(ref_out, traced_out) - torch.jit.script(gm) - - def test_use_keyed_jagged_tensor_as_output(self) -> None: - class ModuleUseKeyedJaggedTensorAsOutput(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward( - self, - keys: List[str], - values: torch.Tensor, - weights: torch.Tensor, - lengths: torch.Tensor, - ) -> Tuple[KeyedJaggedTensor, int]: - output = KeyedJaggedTensor(keys, values, weights, lengths) - return output, output._stride - - # Case 1: KeyedJaggedTensor is only used as an output of the root module. 
- m = ModuleUseKeyedJaggedTensorAsOutput() - gm = symbolic_trace(m) - FileCheck().check("KeyedJaggedTensor").check( - "return (keyed_jagged_tensor," - ).run(gm.code) - - values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) - keys = ["index_0", "index_1"] - lengths = torch.IntTensor([2, 0, 1, 1, 1, 3]) - - ref_out = m(keys, values, weights, lengths) - traced_out = gm(keys, values, weights, lengths) - - self.assertEqual(ref_out[1], traced_out[1]) - self.assertTrue(torch.equal(traced_out[0].offsets(), ref_out[0].offsets())) - torch.jit.script(gm) - - -class TestKeyedTensor(unittest.TestCase): - def test_key_lookup(self) -> None: - tensor_list = [ - torch.Tensor([[1.0, 1.0]]), - torch.Tensor([[2.0, 2.0], [3.0, 3.0]]), - ] - keys = ["dense_0", "dense_1"] - kt = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=0, key_dim=0) - self.assertEqual(kt.key_dim(), 0) - - self.assertTrue(torch.equal(kt["dense_0"], tensor_list[0])) - self.assertTrue(torch.equal(kt["dense_1"], tensor_list[1])) - - def test_key_lookup_dim_1(self) -> None: - tensor_list = [ - torch.tensor([[1.0, 1.0]]).T, - torch.tensor([[2.0, 2.0], [3.0, 3.0]]).T, - ] - keys = ["dense_0", "dense_1"] - kt = KeyedTensor.from_tensor_list(keys, tensor_list, key_dim=1) - self.assertEqual(kt.key_dim(), 1) - self.assertTrue(torch.equal(kt["dense_0"], tensor_list[0])) - self.assertTrue(torch.equal(kt["dense_1"], tensor_list[1])) - - def test_to_dict(self) -> None: - tensor_list = [ - torch.Tensor([[1.0, 1.0]]), - torch.Tensor([[2.0, 2.0], [3.0, 3.0]]), - ] - keys = ["dense_0", "dense_1"] - kt = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=0, key_dim=0) - self.assertEqual(kt.key_dim(), 0) - - d = kt.to_dict() - for key in keys: - self.assertTrue(torch.equal(kt[key], d[key])) - - def test_to_dict_dim_1(self) -> None: - tensor_list = [ - torch.tensor([[1.0, 1.0]]).T, - torch.tensor([[2.0, 2.0], [3.0, 3.0]]).T, - ] - keys = ["dense_0", "dense_1"] - kt = KeyedTensor.from_tensor_list(keys, tensor_list, key_dim=1) - self.assertEqual(kt.key_dim(), 1) - - d = kt.to_dict() - for key in keys: - self.assertTrue(torch.equal(kt[key], d[key])) - - def test_regroup_single_kt(self) -> None: - tensor_list = [torch.randn(2, 3) for i in range(5)] - key_dim = 1 - keys = ["dense_0", "dense_1", "dense_2", "dense_3", "dense_4"] - kt = KeyedTensor.from_tensor_list(keys, tensor_list, key_dim) - grouped_tensors = KeyedTensor.regroup( - [kt], [["dense_0", "dense_4"], ["dense_1", "dense_3"], ["dense_2"]] - ) - self.assertTrue( - torch.equal( - grouped_tensors[0], torch.cat([tensor_list[0], tensor_list[4]], key_dim) - ) - ) - self.assertTrue( - torch.equal( - grouped_tensors[1], torch.cat([tensor_list[1], tensor_list[3]], key_dim) - ) - ) - self.assertTrue(torch.equal(grouped_tensors[2], tensor_list[2])) - - def test_regroup_multiple_kt(self) -> None: - key_dim = 1 - tensor_list_1 = [torch.randn(2, 3) for i in range(3)] - keys_1 = ["dense_0", "dense_1", "dense_2"] - kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim) - tensor_list_2 = [torch.randn(2, 3) for i in range(2)] - keys_2 = ["sparse_0", "sparse_1"] - kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim) - grouped_tensors = KeyedTensor.regroup( - [kt_1, kt_2], [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]] - ) - self.assertTrue( - torch.equal( - grouped_tensors[0], - torch.cat( - [tensor_list_1[0], tensor_list_2[1], tensor_list_1[2]], key_dim - ), - ) - ) - self.assertTrue( - 
torch.equal( - grouped_tensors[1], - torch.cat([tensor_list_1[1], tensor_list_2[0]], key_dim), - ) - ) - - def test_regroup_scriptable(self) -> None: - class MyModule(torch.nn.Module): - def forward( - self, inputs: List[KeyedTensor], groups: List[List[str]] - ) -> List[torch.Tensor]: - return KeyedTensor.regroup(inputs, groups) - - m = MyModule() - torch.jit.script(m) - - def test_regroup_fxable(self) -> None: - class MyModule(torch.nn.Module): - def forward( - self, inputs: List[KeyedTensor], groups: List[List[str]] - ) -> List[torch.Tensor]: - return KeyedTensor.regroup(inputs, groups) - - m = MyModule() - - # input - key_dim = 1 - tensor_list_1 = [torch.randn(2, 3) for i in range(3)] - keys_1 = ["dense_0", "dense_1", "dense_2"] - kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim) - tensor_list_2 = [torch.randn(2, 3) for i in range(2)] - keys_2 = ["sparse_0", "sparse_1"] - kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim) - inputs = [kt_1, kt_2] - groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]] - - # ensure that symbolic tracing works - gm = torch.fx.symbolic_trace(m) - results = m(inputs, groups) - traced_results = gm(inputs, groups) - self.assertEqual(len(results), len(traced_results)) - for result, traced_result in zip(results, traced_results): - self.assertTrue(torch.equal(result, traced_result)) - - def test_scriptable(self) -> None: - class MyModule(torch.nn.Module): - def forward(self, input: KeyedTensor) -> torch.Tensor: - values = input["any"].values() - return values - - m = MyModule() - torch.jit.script(m) - - def test_string_none(self) -> None: - jag_tensor = KeyedTensor( - [], - [], - torch.Tensor(), - ) - - self.assertEqual( - str(jag_tensor), - """\ -KeyedTensor() -""", - ) - - def test_string_basic(self) -> None: - tensor_list = [ - torch.tensor([[1.0]]), - ] - keys = ["key"] - kt = KeyedTensor.from_tensor_list(keys, tensor_list, key_dim=0) - - self.assertEqual( - str(kt), - """\ -KeyedTensor({ - "key": [[1.0]] -}) -""", - ) - - def test_string_values(self) -> None: - tensor_list = [ - torch.tensor([[1.0, 1.0]]).T, - torch.tensor([[2.0, 2.0], [3.0, 3.0]]).T, - ] - keys = ["dense_0", "dense_1"] - kt = KeyedTensor.from_tensor_list(keys, tensor_list) - - self.assertEqual( - str(kt), - """\ -KeyedTensor({ - "dense_0": [[1.0], [1.0]], - "dense_1": [[2.0, 3.0], [2.0, 3.0]] -}) -""", - ) - - -class TestComputeKJTToJTDict(unittest.TestCase): - def test_key_lookup(self) -> None: - m = ComputeKJTToJTDict() - input = KeyedJaggedTensor.from_offsets_sync( - values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), - weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), - keys=["index_0", "index_1"], - offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]), - ) - - out = m(input) - - i0 = out["index_0"] - self.assertTrue(torch.equal(i0._values, torch.tensor([1.0, 2.0, 3.0]))) - self.assertTrue(torch.equal(i0._weights, torch.tensor([1.0, 0.5, 1.5]))) - self.assertTrue(torch.equal(i0._lengths, torch.tensor([0, 2, 0, 1]))) - self.assertTrue(torch.equal(i0._offsets, torch.tensor([0, 0, 2, 2, 3]))) - - i1 = out["index_1"] - self.assertTrue( - torch.equal(i1._values, torch.tensor([4.0, 5.0, 6.0, 7.0, 8.0])) - ) - self.assertTrue( - torch.equal(i1._weights, torch.tensor([1.0, 0.5, 1.0, 1.0, 1.5])) - ) - self.assertTrue(torch.equal(i1._lengths, torch.tensor([1, 1, 0, 3]))) - self.assertTrue(torch.equal(i1._offsets, torch.tensor([0, 1, 2, 2, 5]))) diff --git a/torchrec/sparse/tests/test_keyed_jagged_tensor.py 
b/torchrec/sparse/tests/test_keyed_jagged_tensor.py new file mode 100644 index 000000000..1636a06bd --- /dev/null +++ b/torchrec/sparse/tests/test_keyed_jagged_tensor.py @@ -0,0 +1,1524 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +import unittest +from typing import List, Tuple + +import torch +import torch.utils._pytree as pytree +from torch.testing import FileCheck +from torchrec.fx import symbolic_trace +from torchrec.sparse.jagged_tensor import ( + ComputeKJTToJTDict, + JaggedTensor, + KeyedJaggedTensor, + kjt_is_equal, +) +from torchrec.test_utils import skip_if_asan_class + +torch.fx.wrap("len") + + +class TestKeyedJaggedTensor(unittest.TestCase): + def test_key_lookup(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + j0 = jag_tensor["index_0"] + j1 = jag_tensor["index_1"] + + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1]))) + self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5, 1.5]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0]))) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3]))) + self.assertTrue( + torch.equal(j1.weights(), torch.Tensor([1.0, 0.5, 1.0, 1.0, 1.5])) + ) + self.assertTrue( + torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0])) + ) + + def test_key_lookup_vb(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + stride_per_key_per_rank = [[2], [4]] + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + j0 = jag_tensor["index_0"] + j1 = jag_tensor["index_1"] + + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(isinstance(j1, JaggedTensor)) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0]))) + self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0]))) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 1, 3]))) + self.assertTrue( + torch.equal(j1.weights(), torch.Tensor([1.5, 1.0, 0.5, 1.0, 1.0, 1.5])) + ) + self.assertTrue( + torch.equal(j1.values(), torch.Tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0])) + ) + + def test_to_dict(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + jag_tensor_dict = jag_tensor.to_dict() + j0 = jag_tensor_dict["index_0"] + j1 = jag_tensor_dict["index_1"] + + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1]))) +
self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5, 1.5]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0]))) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3]))) + self.assertTrue( + torch.equal(j1.weights(), torch.Tensor([1.0, 0.5, 1.0, 1.0, 1.5])) + ) + self.assertTrue( + torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0])) + ) + + def test_pytree(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + j0 = JaggedTensor( + values=values, + lengths=torch.IntTensor([1, 0, 2, 3]), + ) + elems, spec = pytree.tree_flatten(j0) + j1 = pytree.tree_unflatten(elems, spec) + + self.assertTrue(torch.equal(j0.lengths(), j1.lengths())) + self.assertIsNone(j0.weights_or_none()) + self.assertIsNone(j1.weights_or_none()) + self.assertTrue(torch.equal(j0.values(), j1.values())) + + values = [ + torch.Tensor([1.0]), + torch.Tensor(), + torch.Tensor([7.0, 8.0]), + torch.Tensor([10.0, 11.0, 12.0]), + ] + weights = [ + torch.Tensor([1.0]), + torch.Tensor(), + torch.Tensor([7.0, 8.0]), + torch.Tensor([10.0, 11.0, 12.0]), + ] + j0 = JaggedTensor.from_dense( + values=values, + weights=weights, + ) + elems, spec = pytree.tree_flatten(j0) + j1 = pytree.tree_unflatten(elems, spec) + + self.assertTrue(torch.equal(j0.lengths(), j1.lengths())) + self.assertTrue(torch.equal(j0.weights(), j1.weights())) + self.assertTrue(torch.equal(j0.values(), j1.values())) + + def test_to_dict_vb(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + stride_per_key_per_rank = [[2], [4]] + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + jag_tensor_dict = jag_tensor.to_dict() + j0 = jag_tensor_dict["index_0"] + j1 = jag_tensor_dict["index_1"] + + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0]))) + self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0]))) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 1, 3]))) + self.assertTrue( + torch.equal(j1.weights(), torch.Tensor([1.5, 1.0, 0.5, 1.0, 1.0, 1.5])) + ) + self.assertTrue( + torch.equal(j1.values(), torch.Tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0])) + ) + + def test_empty(self) -> None: + keys = ["index_0"] + values = torch.tensor([]) + lengths = torch.tensor([]) + offsets = torch.tensor([]) + + kjt_0 = KeyedJaggedTensor(keys=keys, values=values, lengths=lengths) + j0 = kjt_0["index_0"] + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(torch.equal(j0.lengths(), torch.Tensor([]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([]))) + + keys = ["index_1"] + kjt_1 = KeyedJaggedTensor(keys=keys, values=values, offsets=offsets) + j1 = kjt_1["index_1"] + + self.assertTrue(isinstance(j1, JaggedTensor)) + self.assertTrue(torch.equal(j1.lengths(), torch.Tensor([]))) + self.assertTrue(torch.equal(j1.values(), torch.Tensor([]))) + + combined_kjt = KeyedJaggedTensor.concat([kjt_0, kjt_1]) + j0 = combined_kjt["index_0"] + j1 = combined_kjt["index_1"] + + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(torch.equal(j0.lengths(), torch.Tensor([]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([]))) + 
self.assertTrue(isinstance(j1, JaggedTensor)) + self.assertTrue(torch.equal(j1.lengths(), torch.Tensor([]))) + self.assertTrue(torch.equal(j1.values(), torch.Tensor([]))) + + kjt_2 = KeyedJaggedTensor.empty() + self.assertEqual(kjt_2.to_dict(), {}) + + kjt_from_script = torch.jit.script(KeyedJaggedTensor.empty)() + kjt_like = torch.jit.script(KeyedJaggedTensor.empty_like)(kjt_from_script) + self.assertEqual(kjt_from_script.to_dict(), {}) + self.assertEqual(kjt_like.to_dict(), {}) + + def test_empty_to_dict(self) -> None: + keys = ["index_0", "index_1"] + values = torch.tensor([]) + lengths = torch.tensor([[], []]) + length_per_key = [0, 0] + + jag_tensor = KeyedJaggedTensor( + keys=keys, values=values, lengths=lengths, length_per_key=length_per_key + ) + jag_tensor_dict = jag_tensor.to_dict() + j0 = jag_tensor_dict["index_0"] + j1 = jag_tensor_dict["index_1"] + + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(torch.equal(j0.lengths(), torch.Tensor([]))) + self.assertTrue(torch.equal(j0.offsets(), torch.Tensor([]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([]))) + self.assertTrue(isinstance(j1, JaggedTensor)) + self.assertTrue(torch.equal(j1.lengths(), torch.Tensor([]))) + self.assertTrue(torch.equal(j1.offsets(), torch.Tensor([]))) + self.assertTrue(torch.equal(j1.values(), torch.Tensor([]))) + + jag_tensor = KeyedJaggedTensor.from_lengths_sync( + keys=keys, values=values, lengths=lengths + ) + jag_tensor_dict = jag_tensor.to_dict() + j0 = jag_tensor_dict["index_0"] + j1 = jag_tensor_dict["index_1"] + + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(torch.equal(j0.lengths(), torch.Tensor([]))) + self.assertTrue(torch.equal(j0.offsets(), torch.Tensor([]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([]))) + self.assertTrue(isinstance(j1, JaggedTensor)) + self.assertTrue(torch.equal(j1.lengths(), torch.Tensor([]))) + self.assertTrue(torch.equal(j1.offsets(), torch.Tensor([]))) + self.assertTrue(torch.equal(j1.values(), torch.Tensor([]))) + + def test_split(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + j0, j1 = jag_tensor.split([1, 1]) + + self.assertTrue(isinstance(j0, KeyedJaggedTensor)) + self.assertEqual(j0.keys(), ["index_0"]) + self.assertEqual(j1.keys(), ["index_1"]) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1]))) + self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5, 1.5]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0]))) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3]))) + self.assertTrue( + torch.equal(j1.weights(), torch.Tensor([1.0, 0.5, 1.0, 1.0, 1.5])) + ) + self.assertTrue( + torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0])) + ) + + def test_empty_vb(self) -> None: + keys = ["index_0"] + values = torch.tensor([]) + lengths = torch.tensor([]) + stride_per_key_per_rank = [[]] + + kjt_0 = KeyedJaggedTensor( + keys=keys, + values=values, + lengths=lengths, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + self.assertTrue(torch.equal(kjt_0.lengths(), torch.Tensor([]))) + self.assertTrue(torch.equal(kjt_0.values(), torch.Tensor([]))) + self.assertEqual(kjt_0.stride(), 0) + + def test_split_vb(self) -> None: + values = 
torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]) + keys = ["index_0", "index_1", "index_2", "index_3"] + lengths = torch.IntTensor([2, 0, 1, 1, 1, 3, 0, 2]) + stride_per_key_per_rank = [[3], [0], [1], [4]] + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + lengths=lengths, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + j0, j1, j2 = jag_tensor.split([1, 1, 2]) + + self.assertTrue(isinstance(j0, KeyedJaggedTensor)) + self.assertEqual(j0.keys(), ["index_0"]) + self.assertEqual(j1.keys(), ["index_1"]) + self.assertEqual(j2.keys(), ["index_2", "index_3"]) + self.assertEqual(j0.stride(), 4) + self.assertEqual(j1.stride(), 4) + self.assertEqual(j2.stride(), 4) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0]))) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([]))) + self.assertTrue(torch.equal(j1.values(), torch.Tensor([]))) + self.assertTrue(torch.equal(j2.lengths(), torch.IntTensor([1, 1, 3, 0, 2]))) + self.assertTrue( + torch.equal(j2.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])) + ) + + j0, j1, j2, j3 = jag_tensor.split([0, 3, 0, 1]) + self.assertTrue(isinstance(j0, KeyedJaggedTensor)) + self.assertEqual(j0.keys(), []) + self.assertEqual(j1.keys(), ["index_0", "index_1", "index_2"]) + self.assertEqual(j2.keys(), []) + self.assertEqual(j3.keys(), ["index_3"]) + self.assertEqual(j0.stride(), 4) + self.assertEqual(j1.stride(), 4) + self.assertEqual(j2.stride(), 4) + self.assertEqual(j3.stride(), 4) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([]))) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([2, 0, 1, 1]))) + self.assertTrue(torch.equal(j1.values(), torch.Tensor([1.0, 2.0, 3.0, 4.0]))) + self.assertTrue(torch.equal(j2.lengths(), torch.IntTensor([]))) + self.assertTrue(torch.equal(j2.values(), torch.Tensor([]))) + self.assertTrue(torch.equal(j3.lengths(), torch.IntTensor([1, 3, 0, 2]))) + self.assertTrue( + torch.equal(j3.values(), torch.Tensor([5.0, 6.0, 7.0, 8.0, 9.0, 10.0])) + ) + + def test_zero_split(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + j0, j1 = jag_tensor.split([0, 2]) + + self.assertTrue(isinstance(j0, KeyedJaggedTensor)) + self.assertEqual(j0.keys(), []) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([]))) + self.assertTrue(torch.equal(j0.weights(), torch.Tensor([]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([]))) + self.assertEqual(j0.stride(), 3) + + self.assertEqual(j1.keys(), ["index_0", "index_1"]) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([2, 0, 1, 1, 1, 3]))) + self.assertTrue(torch.equal(j1.weights(), weights)) + self.assertTrue(torch.equal(j1.values(), values)) + self.assertEqual(j1.stride(), 3) + + def test_permute_w_weights(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0]) + keys = ["index_0", "index_1", "index_2"] + + jag_tensor = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + 
weights=weights, + ) + + indices = [1, 0, 2] + permuted_jag_tensor = jag_tensor.permute(indices) + self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) + self.assertEqual( + permuted_jag_tensor.offset_per_key(), + [0, 3, 5, 8], + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.values(), + torch.Tensor([3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0]), + ) + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.lengths(), + torch.IntTensor([1, 1, 1, 0, 2, 0, 0, 3, 0]), + ) + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.weights(), + torch.Tensor([1.5, 1.0, 0.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + ), + ) + + def test_permute(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0]) + keys = ["index_0", "index_1", "index_2"] + + jag_tensor = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + ) + + indices = [1, 0, 2] + permuted_jag_tensor = jag_tensor.permute(indices) + + self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) + self.assertEqual( + permuted_jag_tensor.offset_per_key(), + [0, 3, 5, 8], + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.values(), + torch.Tensor([3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0]), + ) + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.lengths(), + torch.IntTensor([1, 1, 1, 0, 2, 0, 0, 3, 0]), + ) + ) + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) + + def test_permute_vb(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + lengths = torch.IntTensor([1, 0, 1, 3, 0, 1, 0, 2, 0]) + keys = ["index_0", "index_1", "index_2"] + stride_per_key_per_rank = [[2], [4], [3]] + + jag_tensor = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + + indices = [1, 0, 2] + permuted_jag_tensor = jag_tensor.permute(indices) + + self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) + self.assertEqual( + permuted_jag_tensor.offset_per_key(), + [0, 5, 6, 8], + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.values(), + torch.Tensor([2.0, 3.0, 4.0, 5.0, 6.0, 1.0, 7.0, 8.0]), + ) + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.lengths(), + torch.IntTensor([1, 3, 0, 1, 1, 0, 0, 2, 0]), + ) + ) + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) + + def test_permute_vb_duplicate(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + lengths = torch.IntTensor([1, 0, 1, 3, 0, 1, 0, 2, 0]) + keys = ["index_0", "index_1", "index_2"] + stride_per_key_per_rank = [[2], [4], [3]] + + jag_tensor = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + + indices = [1, 1, 0, 0, 2, 2] + permuted_jag_tensor = jag_tensor.permute(indices) + + self.assertEqual( + permuted_jag_tensor.keys(), + ["index_1", "index_1", "index_0", "index_0", "index_2", "index_2"], + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.values(), + torch.Tensor( + [ + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 1.0, + 1.0, + 7.0, + 8.0, + 7.0, + 8.0, + ] + ), + ) + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.lengths(), + torch.IntTensor([1, 3, 0, 1, 1, 3, 0, 1, 1, 0, 1, 0, 0, 2, 0, 0, 2, 0]), + ) + ) + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) + + def 
test_permute_duplicates(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0]) + keys = ["index_0", "index_1", "index_2"] + + jag_tensor = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + ) + + indices = [1, 0, 2, 1, 1] + permuted_jag_tensor = jag_tensor.permute(indices) + + self.assertEqual( + permuted_jag_tensor.keys(), + ["index_1", "index_0", "index_2", "index_1", "index_1"], + ) + self.assertEqual( + permuted_jag_tensor.offset_per_key(), + [0, 3, 5, 8, 11, 14], + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.values(), + torch.Tensor( + [ + 3.0, + 4.0, + 5.0, + 1.0, + 2.0, + 6.0, + 7.0, + 8.0, + 3.0, + 4.0, + 5.0, + 3.0, + 4.0, + 5.0, + ] + ), + ) + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.lengths(), + torch.IntTensor([1, 1, 1, 0, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 1]), + ) + ) + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) + + def test_concat(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]) + keys = ["index_0", "index_1", "index_2"] + lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0, 0, 1, 0]) + + kjt_expected = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + ) + kjt_actual = KeyedJaggedTensor.concat( + [ + KeyedJaggedTensor.from_lengths_sync( + values=values[:4], + keys=keys[:1], + lengths=lengths[:4], + ), + KeyedJaggedTensor.from_lengths_sync( + values=values[4:], + keys=keys[1:], + lengths=lengths[4:], + ), + ], + ) + self.assertTrue(torch.equal(kjt_expected.lengths(), kjt_actual.lengths())) + self.assertTrue(torch.equal(kjt_expected.offsets(), kjt_actual.offsets())) + self.assertTrue(torch.equal(kjt_expected.values(), kjt_actual.values())) + # pyre-ignore[6] + self.assertListEqual(kjt_expected._length_per_key, kjt_actual._length_per_key) + + def test_concat_fxable(self) -> None: + class MyModule(torch.nn.Module): + def forward(self, inputs: List[KeyedJaggedTensor]) -> KeyedJaggedTensor: + return KeyedJaggedTensor.concat(inputs) + + m = MyModule() + + # input + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]) + keys = ["index_0", "index_1", "index_2"] + lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0, 0, 1, 0]) + kjt_1 = KeyedJaggedTensor.from_lengths_sync( + values=values[:4], + keys=keys[:1], + lengths=lengths[:4], + ) + kjt_2 = KeyedJaggedTensor.from_lengths_sync( + values=values[4:], + keys=keys[1:], + lengths=lengths[4:], + ) + inputs = [kjt_1, kjt_2] + + # ensure that symbolic tracing works + gm = torch.fx.symbolic_trace(m) + kjt_expected = m(inputs) + kjt_actual = gm(inputs) + + self.assertTrue(torch.equal(kjt_expected.lengths(), kjt_actual.lengths())) + self.assertTrue(torch.equal(kjt_expected.offsets(), kjt_actual.offsets())) + self.assertTrue(torch.equal(kjt_expected.values(), kjt_actual.values())) + self.assertListEqual(kjt_expected._length_per_key, kjt_actual._length_per_key) + + def test_length_vs_offset(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]) + lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3]) + + j_offset = KeyedJaggedTensor.from_offsets_sync( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + + j_lens = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + 
lengths=lengths, + weights=weights, + ) + + self.assertTrue(torch.equal(j_offset.lengths(), j_lens.lengths())) + # TO DO: T88149179 + self.assertTrue(torch.equal(j_offset.offsets(), j_lens.offsets().int())) + + def test_2d(self) -> None: + values = torch.Tensor([[i * 0.5, i * 1.0, i * 1.5] for i in range(1, 9)]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + j = KeyedJaggedTensor.from_offsets_sync( + values=values, + weights=weights, + keys=keys, + offsets=offsets, + ) + j_0 = j["index_0"] + + self.assertTrue(torch.equal(j_0.lengths(), torch.IntTensor([2, 0, 1]))) + self.assertTrue( + torch.equal( + j_0.values(), + torch.Tensor( + [ + [0.5, 1.0, 1.5], + [1.0, 2.0, 3.0], + [1.5, 3.0, 4.5], + ], + ), + ) + ) + + def test_float_lengths_offsets_throws(self) -> None: + values = torch.rand((7, 3)) + keys = ["f1", "f2"] + # torch.Tensor([3, 4]) also fails + # pyre-fixme[6]: Expected `Optional[typing.Type[torch._dtype]]` for 2nd + # param but got `Type[float]`. + lengths = torch.tensor([3, 4], dtype=float) + # pyre-fixme[6]: Expected `Optional[typing.Type[torch._dtype]]` for 2nd + # param but got `Type[float]`. + offsets = torch.tensor([0, 3, 7], dtype=float) + + with self.assertRaises(AssertionError): + KeyedJaggedTensor.from_lengths_sync( + keys=keys, values=values, lengths=lengths + ) + with self.assertRaises(AssertionError): + KeyedJaggedTensor.from_offsets_sync( + keys=keys, values=values, offsets=offsets + ) + + def test_scriptable(self) -> None: + class MyModule(torch.nn.Module): + def forward(self, input: KeyedJaggedTensor) -> torch.Tensor: + values = input["any"].values() + return values + + m = MyModule() + torch.jit.script(m) + + def test_to(self) -> None: + j = KeyedJaggedTensor.from_offsets_sync( + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + values=torch.arange(8), + weights=torch.arange(8 * 10), + keys=["index_0", "index_1"], + ) + j2 = j.to(device=torch.device("cpu")) + self.assertTrue(torch.equal(j.offsets(), j2.offsets())) + self.assertTrue(torch.equal(j.lengths(), j2.lengths())) + self.assertTrue(torch.equal(j.values(), j2.values())) + self.assertTrue(torch.equal(j.weights(), j2.weights())) + + def test_string_none(self) -> None: + jag_tensor = KeyedJaggedTensor( + [], + torch.Tensor(), + ) + + self.assertEqual( + str(jag_tensor), + """KeyedJaggedTensor()\n""", + ) + + def test_string_basic(self) -> None: + values = torch.Tensor([1.0]) + keys = ["key"] + offsets = torch.IntTensor([0, 1]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + ) + + self.assertEqual( + str(jag_tensor), + """KeyedJaggedTensor({\n "key": [[1.0]]\n})\n""", + ) + + def test_string_values(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + ) + + self.assertEqual( + str(jag_tensor), + 'KeyedJaggedTensor({\n "index_0": [[1.0, 2.0], [], [3.0]],\n' + ' "index_1": [[4.0], [5.0], [6.0, 7.0, 8.0]]\n})\n', + ) + + def test_string_weights(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + + 
self.assertEqual( + str(jag_tensor), + 'KeyedJaggedTensor({\n "index_0": {\n' + ' "values": [[1.0, 2.0], [], [3.0]],\n' + ' "weights": [[1.0, 0.5], [], [1.5]]\n' + ' },\n "index_1": {\n' + ' "values": [[4.0], [5.0], [6.0, 7.0, 8.0]],\n' + ' "weights": [[1.0], [0.5], [1.0, 1.0, 1.5]]\n }\n})\n', + ) + + def test_string_vb(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + stride_per_key_per_rank = [[1, 1], [1, 3]] + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + + self.assertEqual( + str(jag_tensor), + 'KeyedJaggedTensor({\n "index_0": {\n ' + '"values": [[1.0, 2.0], []],\n ' + '"weights": [[1.0, 0.5], []]\n },\n ' + '"index_1": {\n ' + '"values": [[3.0], [4.0], [5.0], [6.0, 7.0, 8.0]],\n ' + '"weights": [[1.5], [1.0], [0.5], [1.0, 1.0, 1.5]]\n }\n})\n', + ) + + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "CUDA is not available", + ) + def test_record_stream(self) -> None: + j = KeyedJaggedTensor.from_offsets_sync( + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + values=torch.arange(8), + weights=torch.arange(8 * 10), + keys=["index_0", "index_1"], + ).to(torch.device("cuda")) + j.record_stream(torch.cuda.current_stream()) + + def test_equality(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]) + lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3]) + """ + KJT looks like, represented from the inputs above + # 0 1 2 3 <-- dim_1 + # "index_0" None [1.0, 2.0] None [3.0] + # "index_1" [4.0] [5.0] None [6.0, 7.0, 8.0] + # ^ + # dim_0 + """ + kt = KeyedJaggedTensor.from_offsets_sync( + values=values, + keys=keys, + offsets=offsets, + ) + + kt_2 = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + ) + + kt_3 = KeyedJaggedTensor( + values=values, + keys=["index_1", "index_0"], + offsets=offsets, + ) + + kt_4 = KeyedJaggedTensor( + values=torch.Tensor([10.0, 4.0, 2.0, 5.0, 2.0, 6.0, 9.0, 8.0]), + keys=keys, + lengths=lengths, + ) + + kt_5 = KeyedJaggedTensor( + values=values, + keys=["index_0"], + offsets=offsets, + ) + + weighted_kt = KeyedJaggedTensor.from_offsets_sync( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + + self.assertTrue(kjt_is_equal(kt, kt_2)) # base check + self.assertFalse(kjt_is_equal(kt, kt_3)) # different order of keys + self.assertFalse(kjt_is_equal(kt, kt_4)) # different values + self.assertFalse(kjt_is_equal(kt, kt_5)) # different keys + self.assertFalse(kjt_is_equal(kt, weighted_kt)) # different weights + + # Different lengths + lengths = torch.IntTensor([1, 2, 3, 4, 5, 6, 7, 8]) + lengths_2 = torch.IntTensor([8, 7, 6, 5, 4, 3, 2, 1]) + kt_length_1 = KeyedJaggedTensor.from_lengths_sync( + values=values, keys=keys, lengths=lengths + ) + kt_length_2 = KeyedJaggedTensor.from_lengths_sync( + values=values, keys=keys, lengths=lengths_2 + ) + self.assertFalse(kjt_is_equal(kt_length_1, kt_length_2)) + + # Different offsets + offsets_2 = torch.IntTensor([8, 4, 1, 5, 0, 1, 2, 1, 2]) + kt_offset_1 = KeyedJaggedTensor.from_offsets_sync( + values=values, keys=keys, offsets=offsets + ) + kt_offset_2 =
KeyedJaggedTensor.from_offsets_sync( + values=values, keys=keys, offsets=offsets_2 + ) + self.assertFalse(kjt_is_equal(kt_offset_1, kt_offset_2)) + + # Different length_per_key and offset_per_key + length_per_key_1 = [4, 4] + length_per_key_2 = [3, 5] + offset_per_key_1 = [0, 4] + offset_per_key_2 = [0, 3] + kt_lpk_opk_1 = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + length_per_key=length_per_key_1, + offset_per_key=offset_per_key_1, + ) + kt_lpk_opk_2 = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + length_per_key=length_per_key_2, + offset_per_key=offset_per_key_2, + ) + self.assertFalse(kjt_is_equal(kt_lpk_opk_1, kt_lpk_opk_2)) + + # None values in optional fields + kt_none_fields = KeyedJaggedTensor(values=values, keys=keys, offsets=offsets) + kt_some_fields = KeyedJaggedTensor( + values=values, keys=keys, offsets=offsets, lengths=lengths, weights=weights + ) + self.assertFalse(kjt_is_equal(kt_none_fields, kt_some_fields)) + + # Empty KeyedJaggedTensor + kt_empty = KeyedJaggedTensor( + values=torch.Tensor([]), keys=[], offsets=torch.IntTensor([]) + ) + self.assertTrue(kjt_is_equal(kt_empty, kt_empty)) + self.assertFalse(kjt_is_equal(kt, kt_empty)) + + # Non-KeyedJaggedTensor input + non_kjt_input = "not a KeyedJaggedTensor instance" + self.assertFalse(kjt_is_equal(kt, non_kjt_input)) + + def test_meta_device_compatibility(self) -> None: + keys = ["index_0", "index_1", "index_2", "index_3"] + lengths = torch.tensor( + [2, 0, 1, 1, 1, 3, 0, 2], + device=torch.device("meta"), + ) + offsets = torch.tensor( + [0, 2, 2, 3, 4, 5, 8, 8, 10], + device=torch.device("meta"), + ) + values = torch.tensor( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + device=torch.device("meta"), + ) + weights = torch.tensor( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + device=torch.device("meta"), + ) + kjt = KeyedJaggedTensor( + keys=keys, + values=values, + weights=weights, + lengths=lengths, + ) + + kjt.sync() + kjt.unsync() + + jt_dict = kjt.to_dict() + kjt = KeyedJaggedTensor.from_jt_dict(jt_dict) + + kjt = KeyedJaggedTensor.from_lengths_sync( + keys=keys, values=values, weights=weights, lengths=lengths + ) + + kjt = KeyedJaggedTensor.from_offsets_sync( + keys=keys, values=values, weights=weights, offsets=offsets + ) + + # test empty keys case + kjt = KeyedJaggedTensor.from_lengths_sync( + keys=[], + values=torch.tensor([], device=torch.device("meta")), + lengths=torch.tensor([], device=torch.device("meta")), + ) + + +class TestKeyedJaggedTensorScripting(unittest.TestCase): + def test_scriptable_forward(self) -> None: + class MyModule(torch.nn.Module): + def forward(self, input: KeyedJaggedTensor) -> KeyedJaggedTensor: + input["any"].values() + input.dist_labels() + input.dist_splits([1, 2]) + return KeyedJaggedTensor.dist_init( + keys=input.keys(), + tensors=input.dist_tensors(), + variable_stride_per_key=False, + num_workers=2, + recat=torch.tensor([]), + stride_per_rank=[2, 3], + ) + + m = MyModule() + torch.jit.script(m) + + def test_scriptable_split(self) -> None: + class MyModule(torch.nn.Module): + def forward(self, input: KeyedJaggedTensor) -> List[KeyedJaggedTensor]: + return input.split([1, 0, 1]) + + m = MyModule() + torch.jit.script(m) + + def test_scriptable_init(self) -> None: + def create_kjt() -> KeyedJaggedTensor: + return KeyedJaggedTensor.from_offsets_sync( + values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + weights=torch.tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + keys=["index_0", "index_1"], + 
offsets=torch.tensor([0, 0, 2, 2, 3, 4, 5, 5, 8], dtype=torch.int32), + ) + + def create_vb_kjt() -> KeyedJaggedTensor: + return KeyedJaggedTensor.from_offsets_sync( + values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + weights=torch.tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + keys=["index_0", "index_1"], + offsets=torch.tensor([0, 0, 2, 2, 3, 4, 5, 5, 8], dtype=torch.int32), + stride_per_key_per_rank=[[2], [4]], + ) + + # assert that we can script KJT creation + torch.jit.script(create_kjt) + torch.jit.script(create_vb_kjt) + + def test_scriptable_empty(self) -> None: + def create_empty() -> KeyedJaggedTensor: + return KeyedJaggedTensor.empty() + + def create_empty_weighted() -> KeyedJaggedTensor: + return KeyedJaggedTensor.empty(is_weighted=True) + + # assert that we can script KJT creation + torch.jit.script(create_empty) + torch.jit.script(create_empty_weighted) + + +class TestKeyedJaggedTensorTracingScripting(unittest.TestCase): + def test_jit_tracable(self) -> None: + # This module will simply go through the constructor of the + # KeyedJaggedTensor to construct it with multiple different batch sizes + class MyModule(torch.nn.Module): + def forward( + self, offsets: torch.Tensor, values: torch.Tensor, weights: torch.Tensor + ) -> torch.Tensor: + j = KeyedJaggedTensor.from_offsets_sync( + offsets=offsets, + values=values, + weights=weights, + keys=["index_0", "index_1"], + ) + return j["index_0"].offsets() + + sample_2 = ( + torch.tensor([0, 2, 2]), + torch.arange(2), + torch.arange(2 * 10), + ) + sample_6 = ( + torch.tensor([0, 2, 2, 3, 4, 6, 8]), + torch.arange(8), + torch.arange(8 * 10), + ) + m = MyModule() + model_eager_traced: torch.jit.ScriptModule = torch.jit.trace( + m, sample_2, strict=False + ) + self.assertTrue( + torch.equal(model_eager_traced(*sample_2), torch.tensor([0, 2])) + ) + self.assertTrue( + torch.equal(model_eager_traced(*sample_6), torch.tensor([0, 2, 2, 3])) + ) + + def test_create_and_access_keyed_jagged_tensor(self) -> None: + class ModuleCreateAndAccessKeyedJaggedTensor(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input: int) -> int: + features = KeyedJaggedTensor.from_offsets_sync( + values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + keys=["index_0", "index_1"], + offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]), + ) + return ( + len(features.keys()) + + features.values().numel() + + features.weights().numel() + + features.lengths().numel() + + features.offsets().numel() + ) + + # Case 4: KeyedJaggedTensor is only used within the root module and not as part of + # the root module's input/output interface. + m = ModuleCreateAndAccessKeyedJaggedTensor() + gm = symbolic_trace(m) + FileCheck().check("return 35").check_not("KeyedJaggedTensor").run(gm.code) + ref_out = m(8) + traced_out = gm(8) + self.assertEqual(ref_out, traced_out) + torch.jit.script(gm) + + def test_create_and_access_empty_keyed_jagged_tensor(self) -> None: + class ModuleCreateAndAccessEmptyKeyedJaggedTensor(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input: int) -> int: + features = KeyedJaggedTensor.empty(is_weighted=True) + return ( + len(features.keys()) + + features.values().numel() + + features.weights().numel() + + features.lengths().numel() + + features.offsets().numel() + ) + + # Case 4: KeyedJaggedTensor is only used within the root module and not as part of + # the root module's input/output interface. 
+ m = ModuleCreateAndAccessEmptyKeyedJaggedTensor() + gm = symbolic_trace(m) + FileCheck().check("return 1").check_not("KeyedJaggedTensor").run(gm.code) + ref_out = m(8) + traced_out = gm(8) + self.assertEqual(ref_out, traced_out) + torch.jit.script(gm) + + def test_traceable_empty_like(self) -> None: + class ModuleCreateAndAccessEmptyLikeKeyedJaggedTensor(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, kjt: KeyedJaggedTensor) -> int: + features = KeyedJaggedTensor.empty_like(kjt) + return ( + len(features.keys()) + + features.values().numel() + + features.weights().numel() + + features.lengths().numel() + + features.offsets().numel() + ) + + # Case 4: KeyedJaggedTensor is only used within the root module and not as part of + # the root module's input/output interface. + m = ModuleCreateAndAccessEmptyLikeKeyedJaggedTensor() + kjt = KeyedJaggedTensor.from_offsets_sync( + values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + keys=["index_0", "index_1"], + offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]), + ) + gm = symbolic_trace(m) + ref_out = m(kjt) + traced_out = gm(kjt) + self.assertEqual(ref_out, traced_out) + torch.jit.script(gm) + + def test_use_keyed_jagged_tensor_as_input_and_output(self) -> None: + class ModuleUseKeyedJaggedTensorAsInputAndOutput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, input: KeyedJaggedTensor + ) -> Tuple[KeyedJaggedTensor, int]: + output = KeyedJaggedTensor( + input.keys(), + input.values(), + input.weights(), + lengths=input.lengths(), + offsets=input.offsets(), + ) + return output, output.stride() + + # Case 3: KeyedJaggedTensor is used as both an input and an output of the root module. + m = ModuleUseKeyedJaggedTensorAsInputAndOutput() + gm = symbolic_trace(m) + FileCheck().check("KeyedJaggedTensor").check("keys()").check("values()").check( + "stride" + ).run(gm.code) + input = KeyedJaggedTensor.from_offsets_sync( + values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + keys=["index_0", "index_1"], + offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]), + ) + ref_out = m(input) + traced_out = gm(input) + self.assertEqual(ref_out[1], traced_out[1]) + torch.jit.script(gm) + + def test_use_keyed_jagged_tensor_as_input(self) -> None: + class ModuleUseKeyedJaggedTensorAsInput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input: KeyedJaggedTensor) -> int: + return ( + len(input.keys()) + + input.values().numel() + + input.weights().numel() + + input.lengths().numel() + + input.offsets().numel() + ) + + # Case 2: KeyedJaggedTensor is only used as an input of the root module. 
+ m = ModuleUseKeyedJaggedTensorAsInput() + gm = symbolic_trace(m) + FileCheck().check("KeyedJaggedTensor").check("keys()").check("len").check( + "values()" + ).check("numel()").run(gm.code) + + input = KeyedJaggedTensor.from_offsets_sync( + values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + keys=["index_0", "index_1"], + offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]), + ) + ref_out = m(input) + traced_out = gm(input) + self.assertEqual(ref_out, traced_out) + torch.jit.script(gm) + + def test_use_keyed_jagged_tensor_as_output(self) -> None: + class ModuleUseKeyedJaggedTensorAsOutput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + keys: List[str], + values: torch.Tensor, + weights: torch.Tensor, + lengths: torch.Tensor, + ) -> Tuple[KeyedJaggedTensor, int]: + output = KeyedJaggedTensor(keys, values, weights, lengths) + return output, output.stride() + + # Case 1: KeyedJaggedTensor is only used as an output of the root module. + m = ModuleUseKeyedJaggedTensorAsOutput() + gm = symbolic_trace(m) + FileCheck().check("KeyedJaggedTensor").check( + "return (keyed_jagged_tensor," + ).run(gm.code) + + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + lengths = torch.IntTensor([2, 0, 1, 1, 1, 3]) + + ref_out = m(keys, values, weights, lengths) + traced_out = gm(keys, values, weights, lengths) + + self.assertEqual(ref_out[1], traced_out[1]) + self.assertTrue(torch.equal(traced_out[0].offsets(), ref_out[0].offsets())) + torch.jit.script(gm) + + +class TestComputeKJTToJTDict(unittest.TestCase): + def test_key_lookup(self) -> None: + m = ComputeKJTToJTDict() + input = KeyedJaggedTensor.from_offsets_sync( + values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + keys=["index_0", "index_1"], + offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]), + stride_per_key_per_rank=[[0, 2], [3, 3]], + ) + + out = m(input) + + i0 = out["index_0"] + self.assertTrue(torch.equal(i0._values, torch.tensor([1.0, 2.0]))) + self.assertTrue(torch.equal(i0._weights, torch.tensor([1.0, 0.5]))) + self.assertTrue(torch.equal(i0._lengths, torch.tensor([0, 2]))) + self.assertTrue(torch.equal(i0._offsets, torch.tensor([0, 0, 2]))) + + i1 = out["index_1"] + self.assertTrue( + torch.equal(i1._values, torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0])) + ) + self.assertTrue( + torch.equal(i1._weights, torch.tensor([1.5, 1.0, 0.5, 1.0, 1.0, 1.5])) + ) + self.assertTrue(torch.equal(i1._lengths, torch.tensor([0, 1, 1, 1, 0, 3]))) + self.assertTrue(torch.equal(i1._offsets, torch.tensor([0, 0, 1, 2, 3, 3, 6]))) + + +@skip_if_asan_class +class TestKeyedJaggedTensorGPU(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + self.device = torch.cuda.current_device() + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "Not enough GPUs, this test requires at least one GPUs", + ) + def test_permute(self) -> None: + values = torch.tensor( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device + ) + lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device) + keys = ["index_0", "index_1", "index_2"] + + jag_tensor = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + ) + indices = [1, 0, 2] + permuted_jag_tensor = jag_tensor.permute(indices) + + 
self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) + self.assertEqual( + permuted_jag_tensor.offset_per_key(), + [0, 3, 5, 8], + ) + self.assertEqual( + permuted_jag_tensor.values().tolist(), + [3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0], + ) + self.assertEqual( + permuted_jag_tensor.lengths().tolist(), [1, 1, 1, 0, 2, 0, 0, 3, 0] + ) + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "Not enough GPUs, this test requires at least one GPUs", + ) + def test_permute_vb(self) -> None: + values = torch.tensor( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device + ) + lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device) + keys = ["index_0", "index_1", "index_2"] + stride_per_key_per_rank = [[2], [4], [3]] + + jag_tensor = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + + indices = [1, 0, 2] + permuted_jag_tensor = jag_tensor.permute(indices) + + self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) + self.assertEqual( + permuted_jag_tensor.offset_per_key(), + [0, 5, 6, 8], + ) + self.assertEqual( + permuted_jag_tensor.values().tolist(), + [2.0, 3.0, 4.0, 5.0, 6.0, 1.0, 7.0, 8.0], + ) + self.assertEqual( + permuted_jag_tensor.lengths().tolist(), [1, 3, 0, 1, 1, 0, 0, 2, 0] + ) + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "Not enough GPUs, this test requires at least one GPUs", + ) + def test_permute_vb_duplicate(self) -> None: + values = torch.tensor( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device + ) + lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device) + keys = ["index_0", "index_1", "index_2"] + stride_per_key_per_rank = [[2], [4], [3]] + + jag_tensor = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + + indices = [1, 1, 0, 0, 2, 2] + permuted_jag_tensor = jag_tensor.permute(indices) + + self.assertEqual( + permuted_jag_tensor.keys(), + ["index_1", "index_1", "index_0", "index_0", "index_2", "index_2"], + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.values().cpu(), + torch.Tensor( + [ + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 1.0, + 1.0, + 7.0, + 8.0, + 7.0, + 8.0, + ] + ), + ) + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.lengths().cpu(), + torch.IntTensor([1, 3, 0, 1, 1, 3, 0, 1, 1, 0, 1, 0, 0, 2, 0, 0, 2, 0]), + ) + ) + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "Not enough GPUs, this test requires at least one GPUs", + ) + def test_permute_duplicates(self) -> None: + values = torch.tensor( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device + ) + lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device) + keys = ["index_0", "index_1", "index_2"] + + jag_tensor = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + ) + + indices = [1, 0, 2, 1, 1] + permuted_jag_tensor = jag_tensor.permute(indices) + + self.assertEqual( + permuted_jag_tensor.keys(), + ["index_1", "index_0", "index_2", "index_1", "index_1"], + ) + self.assertEqual( + permuted_jag_tensor.offset_per_key(), + [0, 3, 5, 8, 11, 14], + ) + 
self.assertEqual( + permuted_jag_tensor.values().tolist(), + [ + 3.0, + 4.0, + 5.0, + 1.0, + 2.0, + 6.0, + 7.0, + 8.0, + 3.0, + 4.0, + 5.0, + 3.0, + 4.0, + 5.0, + ], + ) + self.assertEqual( + permuted_jag_tensor.lengths().tolist(), + [1, 1, 1, 0, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 1], + ) + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) diff --git a/torchrec/sparse/tests/test_keyed_tensor.py b/torchrec/sparse/tests/test_keyed_tensor.py new file mode 100644 index 000000000..027f2634c --- /dev/null +++ b/torchrec/sparse/tests/test_keyed_tensor.py @@ -0,0 +1,1047 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +import unittest +from typing import Callable, Dict, List, Tuple + +import torch +import torch.utils._pytree as pytree +from hypothesis import assume, given, settings, strategies as st, Verbosity +from torch.fx._pytree import tree_flatten_spec +from torchrec.sparse.jagged_tensor import ( + _fbgemm_permute_pooled_embs, + _kt_regroup_arguments, + _regroup_keyed_tensors, + KeyedTensor, + permute_multi_embedding, + regroup_kts, +) +from torchrec.sparse.tests.utils import build_groups, build_kts +from torchrec.test_utils import skip_if_asan_class + +torch.fx.wrap("len") + + +class TestKeyedTensor(unittest.TestCase): + def test_key_lookup(self) -> None: + tensor_list = [ + torch.Tensor([[1.0, 1.0]]), + torch.Tensor([[2.0, 2.0], [3.0, 3.0]]), + ] + keys = ["dense_0", "dense_1"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=0, key_dim=0) + self.assertEqual(kt.key_dim(), 0) + + self.assertTrue(torch.equal(kt["dense_0"], tensor_list[0])) + self.assertTrue(torch.equal(kt["dense_1"], tensor_list[1])) + + def test_key_lookup_dim_1(self) -> None: + tensor_list = [ + torch.tensor([[1.0, 1.0]]).T, + torch.tensor([[2.0, 2.0], [3.0, 3.0]]).T, + ] + keys = ["dense_0", "dense_1"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, key_dim=1) + self.assertEqual(kt.key_dim(), 1) + self.assertTrue(torch.equal(kt["dense_0"], tensor_list[0])) + self.assertTrue(torch.equal(kt["dense_1"], tensor_list[1])) + + def test_to_dict(self) -> None: + tensor_list = [ + torch.Tensor([[1.0, 1.0]]), + torch.Tensor([[2.0, 2.0], [3.0, 3.0]]), + ] + keys = ["dense_0", "dense_1"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=0, key_dim=0) + self.assertEqual(kt.key_dim(), 0) + + d = kt.to_dict() + for key in keys: + self.assertTrue(torch.equal(kt[key], d[key])) + + def test_to_dict_dim_1(self) -> None: + tensor_list = [ + torch.tensor([[1.0, 1.0]]).T, + torch.tensor([[2.0, 2.0], [3.0, 3.0]]).T, + ] + keys = ["dense_0", "dense_1"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, key_dim=1) + self.assertEqual(kt.key_dim(), 1) + + d = kt.to_dict() + for key in keys: + self.assertTrue(torch.equal(kt[key], d[key])) + + def test_regroup_single_kt(self) -> None: + tensor_list = [torch.randn(2, 3) for i in range(5)] + key_dim = 1 + keys = ["dense_0", "dense_1", "dense_2", "dense_3", "dense_4"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, key_dim) + grouped_tensors = KeyedTensor.regroup( + [kt], [["dense_0", "dense_4"], ["dense_1", "dense_3"], ["dense_2"]] + ) + self.assertTrue( + torch.equal( + grouped_tensors[0], torch.cat([tensor_list[0], tensor_list[4]], key_dim) + ) + ) + self.assertTrue( + torch.equal( + grouped_tensors[1], torch.cat([tensor_list[1], 
tensor_list[3]], key_dim) + ) + ) + self.assertTrue(torch.equal(grouped_tensors[2], tensor_list[2])) + + def test_regroup_multiple_kt(self) -> None: + key_dim = 1 + tensor_list_1 = [torch.randn(2, 4), torch.randn(2, 8), torch.randn(2, 2)] + keys_1 = ["dense_0", "dense_1", "dense_2"] + kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim) + tensor_list_2 = [torch.randn(2, 3), torch.randn(2, 10)] + keys_2 = ["sparse_0", "sparse_1"] + kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim) + grouped_tensors = KeyedTensor.regroup( + [kt_1, kt_2], [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]] + ) + self.assertTrue( + torch.equal( + grouped_tensors[0], + torch.cat( + [tensor_list_1[0], tensor_list_2[1], tensor_list_1[2]], key_dim + ), + ) + ) + self.assertTrue( + torch.equal( + grouped_tensors[1], + torch.cat([tensor_list_1[1], tensor_list_2[0]], key_dim), + ) + ) + + # pyre-ignore[56] + @given( + device_str=st.sampled_from(["cpu", "meta", "cuda"]), + regroup_func=st.sampled_from( + [ + KeyedTensor.regroup, + regroup_kts, + permute_multi_embedding, + _fbgemm_permute_pooled_embs, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=15, deadline=None) + def test_regroup_kts( + self, regroup_func: Callable[..., List[torch.Tensor]], device_str: str + ) -> None: + # assumption only fails when using cuda but device == 0. + assume(device_str != "cuda" or torch.cuda.device_count() > 0) + device = torch.device(device_str) + + kts = build_kts( + dense_features=20, + sparse_features=20, + dim_dense=64, + dim_sparse=128, + batch_size=128, + device=device, + run_backward=False, + ) + groups = build_groups(kts=kts, num_groups=2) + refs = _regroup_keyed_tensors(kts, groups) + outputs = regroup_func(kts, groups) + for ref, output in zip(refs, outputs): + self.assertEqual(ref.device, output.device) + if device_str == "meta": + self.assertEqual(ref.shape, output.shape) + else: + torch.testing.assert_close(ref, output) + + # pyre-ignore[56] + @given( + device_str=st.sampled_from(["cpu", "meta", "cuda"]), + regroup_func=st.sampled_from( + [ + KeyedTensor.regroup, + regroup_kts, + permute_multi_embedding, + _fbgemm_permute_pooled_embs, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_regroup_kts_inference( + self, regroup_func: Callable[..., List[torch.Tensor]], device_str: str + ) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + with torch.inference_mode(): + kts = build_kts( + dense_features=20, + sparse_features=20, + dim_dense=64, + dim_sparse=128, + batch_size=128, + device=device, + run_backward=False, + ) + groups = build_groups(kts=kts, num_groups=2) + refs = _regroup_keyed_tensors(kts, groups) + outputs = regroup_func(kts, groups) + for ref, output in zip(refs, outputs): + self.assertEqual(ref.device, output.device) + if device_str == "meta": + self.assertEqual(ref.shape, output.shape) + else: + torch.testing.assert_close(ref, output) + + def test_regroup_backward_skips_and_duplicates(self) -> None: + kts = build_kts( + dense_features=20, + sparse_features=20, + dim_dense=64, + dim_sparse=128, + batch_size=128, + device=torch.device("cpu"), + run_backward=True, + ) + groups = build_groups(kts=kts, num_groups=2, skips=True, duplicates=True) + labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float() + + tensor_groups = KeyedTensor.regroup(kts, groups) + pred0 = 
tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred0, labels).sum() + actual_kt_0_grad = torch.autograd.grad( + loss, kts[0].values(), retain_graph=True + )[0] + actual_kt_1_grad = torch.autograd.grad( + loss, kts[1].values(), retain_graph=True + )[0] + + # clear grads are return + kts[0].values().grad = None + kts[1].values().grad = None + + tensor_groups = _regroup_keyed_tensors(kts, groups) + pred1 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred1, labels).sum() + expected_kt_0_grad = torch.autograd.grad( + loss, kts[0].values(), retain_graph=True + )[0] + expected_kt_1_grad = torch.autograd.grad( + loss, kts[1].values(), retain_graph=True + )[0] + + torch.allclose(actual_kt_0_grad, expected_kt_0_grad) + torch.allclose(actual_kt_1_grad, expected_kt_1_grad) + + # pyre-ignore[56] + @given( + device_str=st.sampled_from(["cpu", "cuda"]), + regroup_func=st.sampled_from( + [ + KeyedTensor.regroup, + regroup_kts, + permute_multi_embedding, + _fbgemm_permute_pooled_embs, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_regroup_backward( + self, regroup_func: Callable[..., List[torch.Tensor]], device_str: str + ) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + kts = build_kts( + dense_features=20, + sparse_features=20, + dim_dense=64, + dim_sparse=128, + batch_size=128, + device=device, + run_backward=True, + ) + groups = build_groups(kts=kts, num_groups=2, skips=False, duplicates=False) + labels = torch.randint(0, 1, (128,), device=device).float() + + tensor_groups = KeyedTensor.regroup(kts, groups) + pred0 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred0, labels).sum() + actual_kt_0_grad = torch.autograd.grad( + loss, kts[0].values(), retain_graph=True + )[0] + actual_kt_1_grad = torch.autograd.grad( + loss, kts[1].values(), retain_graph=True + )[0] + + # clear grads are return + kts[0].values().grad = None + kts[1].values().grad = None + + tensor_groups = regroup_func(kts, groups) + pred1 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred1, labels).sum() + expected_kt_0_grad = torch.autograd.grad( + loss, kts[0].values(), retain_graph=True + )[0] + expected_kt_1_grad = torch.autograd.grad( + loss, kts[1].values(), retain_graph=True + )[0] + + torch.allclose(actual_kt_0_grad, expected_kt_0_grad) + torch.allclose(actual_kt_1_grad, expected_kt_1_grad) + + def test_regroup_multiple_kt_duplicate_keys(self) -> None: + key_dim = 1 + tensor_list_1 = [torch.randn(2, 4) for i in range(2)] + keys_1 = ["dense_0", "dense_1"] + kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim) + tensor_list_2 = [torch.randn(2, 3) for i in range(3)] + keys_2 = ["sparse_0", "sparse_1", "dense_2"] + kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim) + grouped_tensors = KeyedTensor.regroup( + [kt_1, kt_2], [["dense_0", "sparse_1"], ["dense_1", "sparse_0", "dense_0"]] + ) + self.assertTrue( + torch.equal( + grouped_tensors[0], + torch.cat([tensor_list_1[0], tensor_list_2[1]], key_dim), + ) + ) + self.assertTrue( + torch.equal( + grouped_tensors[1], + torch.cat( + [tensor_list_1[1], tensor_list_2[0], tensor_list_1[0]], key_dim + ), + ) + ) + + # pyre-ignore[56] + @given( + device_str=st.sampled_from(["cpu", "meta", "cuda"]), + regroup_func=st.sampled_from( + 
[ + KeyedTensor.regroup, + regroup_kts, + permute_multi_embedding, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_regroup_scriptable( + self, regroup_func: Callable[..., List[torch.Tensor]], device_str: str + ) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + + class MyModule(torch.nn.Module): + def forward(self, inputs: List[KeyedTensor]) -> List[torch.Tensor]: + # user provided, not model input + groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]] + return regroup_func(inputs, groups) + + m = MyModule() + script_model = torch.jit.script(m) + # input + key_dim = 1 + tensor_list_1 = [torch.randn(2, 3, device=device) for i in range(3)] + keys_1 = ["dense_0", "dense_1", "dense_2"] + kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim) + tensor_list_2 = [torch.randn(2, 3, device=device) for i in range(2)] + keys_2 = ["sparse_0", "sparse_1"] + kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim) + inputs = [kt_1, kt_2] + outputs = script_model(inputs) # pyre-ignore[29] + refs = _regroup_keyed_tensors( + inputs, [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]] + ) + for ref, output in zip(refs, outputs): + self.assertEqual(ref.device, output.device) + if device_str == "meta": + self.assertEqual(ref.shape, output.shape) + else: + torch.testing.assert_close(ref, output) + + # pyre-ignore[56] + @given( + device_str=st.sampled_from(["cpu", "meta", "cuda"]), + regroup_func=st.sampled_from( + [ + KeyedTensor.regroup, + regroup_kts, + permute_multi_embedding, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_regroup_scriptable_inference( + self, regroup_func: Callable[..., List[torch.Tensor]], device_str: str + ) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + + class MyModule(torch.nn.Module): + def forward(self, inputs: List[KeyedTensor]) -> List[torch.Tensor]: + # user provided, not model input + groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]] + return regroup_func(inputs, groups) + + m = MyModule() + script_model = torch.jit.script(m) + with torch.inference_mode(): + # input + key_dim = 1 + tensor_list_1 = [torch.randn(2, 3, device=device) for i in range(3)] + keys_1 = ["dense_0", "dense_1", "dense_2"] + kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim) + tensor_list_2 = [torch.randn(2, 3, device=device) for i in range(2)] + keys_2 = ["sparse_0", "sparse_1"] + kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim) + inputs = [kt_1, kt_2] + outputs = script_model(inputs) # pyre-ignore[29 + refs = _regroup_keyed_tensors( + inputs, [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]] + ) + for ref, output in zip(refs, outputs): + self.assertEqual(ref.device, output.device) + if device_str == "meta": + self.assertEqual(ref.shape, output.shape) + else: + torch.testing.assert_close(ref, output) + + def test_regroup_fxable(self) -> None: + class MyModule(torch.nn.Module): + def forward( + self, inputs: List[KeyedTensor], groups: List[List[str]] + ) -> List[torch.Tensor]: + return KeyedTensor.regroup(inputs, groups) + + m = MyModule() + + # input + key_dim = 1 + tensor_list_1 = [torch.randn(2, 3) for i in range(3)] + keys_1 = ["dense_0", "dense_1", "dense_2"] + kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim) + 
tensor_list_2 = [torch.randn(2, 3) for i in range(2)] + keys_2 = ["sparse_0", "sparse_1"] + kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim) + inputs = [kt_1, kt_2] + groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]] + + # ensure that symbolic tracing works + gm = torch.fx.symbolic_trace(m) + results = m(inputs, groups) + traced_results = gm(inputs, groups) + self.assertEqual(len(results), len(traced_results)) + for result, traced_result in zip(results, traced_results): + self.assertTrue(torch.equal(result, traced_result)) + + def test_regroup_as_dict_scriptable(self) -> None: + class MyModule(torch.nn.Module): + def forward(self, inputs: List[KeyedTensor]) -> Dict[str, torch.Tensor]: + groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]] + keys = ["group_0", "group_1"] + return KeyedTensor.regroup_as_dict(inputs, groups, keys) + + m = MyModule() + torch.jit.script(m) + + def test_regroup_as_dict_fxable(self) -> None: + class MyModule(torch.nn.Module): + def forward(self, inputs: List[KeyedTensor]) -> Dict[str, torch.Tensor]: + groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]] + keys = ["group_0", "group_1"] + return KeyedTensor.regroup_as_dict(inputs, groups, keys) + + m = MyModule() + + # input + key_dim = 1 + tensor_list_1 = [torch.randn(2, 3) for i in range(3)] + keys_1 = ["dense_0", "dense_1", "dense_2"] + kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim) + tensor_list_2 = [torch.randn(2, 3) for i in range(2)] + keys_2 = ["sparse_0", "sparse_1"] + kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim) + inputs = [kt_1, kt_2] + + # ensure that symbolic tracing works + gm = torch.fx.symbolic_trace(m) + results = m(inputs) + traced_results = gm(inputs) + self.assertEqual(len(results), len(traced_results)) + for result, traced_result in zip(results.values(), traced_results.values()): + self.assertTrue(torch.equal(result, traced_result)) + + def test_scriptable(self) -> None: + class MyModule(torch.nn.Module): + def forward(self, input: KeyedTensor) -> torch.Tensor: + values = input["any"].values() + return values + + m = MyModule() + torch.jit.script(m) + + def test_string_none(self) -> None: + jag_tensor = KeyedTensor( + [], + [], + torch.Tensor(), + ) + + self.assertEqual( + str(jag_tensor), + "KeyedTensor()\n", + ) + + def test_string_basic(self) -> None: + tensor_list = [ + torch.tensor([[1.0]]), + ] + keys = ["key"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, key_dim=0) + + self.assertEqual( + str(kt), + 'KeyedTensor({\n "key": [[1.0]]\n})\n', + ) + + def test_string_values(self) -> None: + tensor_list = [ + torch.tensor([[1.0, 1.0]]).T, + torch.tensor([[2.0, 2.0], [3.0, 3.0]]).T, + ] + keys = ["dense_0", "dense_1"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list) + + self.assertEqual( + str(kt), + 'KeyedTensor({\n "dense_0": [[1.0], [1.0]],\n "dense_1": [[2.0, 3.0], [2.0, 3.0]]\n})\n', + ) + + def test_pytree(self) -> None: + tensor_list = [ + torch.Tensor([[1.0, 1.0]]).T, + torch.Tensor([[2.0, 2.0], [3.0, 3.0]]).T, + ] + keys = ["dense_0", "dense_1"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=1, key_dim=1) + # generate the out_spec in the torch.export run + flattened, out_spec = pytree.tree_flatten(kt) + + # first element of flattened list should be the kt._values + self.assertTrue(torch.equal(flattened[0], kt.values())) + # re-construct the unflattened kt from the flattened list plus the out_spec + unflattened = 
pytree.tree_unflatten(flattened, out_spec) + + self.assertTrue(isinstance(unflattened, KeyedTensor)) + self.assertListEqual(unflattened.keys(), keys) + self.assertListEqual(unflattened._length_per_key, kt._length_per_key) + + # for ir export, key order in KT could change + tensor_list = [ + torch.Tensor([[2.0, 2.0], [3.0, 3.0]]).T, + torch.Tensor([[1.0, 1.0]]).T, + ] + keys = ["dense_1", "dense_0"] + kt2 = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=1, key_dim=1) + + # flatten the kt2 based on previously generated out_spec + # this is to mimic the exported_program module run + # the kt2 could have different key order but out_spec is the same + flattened2 = tree_flatten_spec(kt2, out_spec) + + # re-construct the unflattened kt from the flattened list plus the out_spec + # the rebuilt kt2 should contain the same effective data as kt (ignoring key order) + unflattened2 = pytree.tree_unflatten(flattened2, out_spec) + self.assertTrue(isinstance(unflattened2, KeyedTensor)) + self.assertSetEqual(set(unflattened.keys()), set(unflattened2.keys())) + for key in kt.keys(): + torch.testing.assert_close(unflattened[key], unflattened2[key]) + torch.testing.assert_close(kt[key], unflattened2[key]) + + +class TestKeyedTensorRegroupOp(unittest.TestCase): + # pyre-ignore[56] + @given(device_str=st.sampled_from(["cpu", "meta", "cuda"])) + @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) + def test_kt_regroup_arguments(self, device_str: str) -> None: + # assumption only fails when using cuda but device == 0. + assume(device_str != "cuda" or torch.cuda.device_count() > 0) + device = torch.device(device_str) + + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + torch.empty(0, device=device), keys, lengths, groups + ) + ref_permutes = [ + [0, 0, 0, 0, 3, 4], # f1, jump to 4, as a start + [1, 0, 0, 3, 5, 0], # f3 + [0, 1, 3, 0, 4, 0], # f2 + [1, 2, 5, 0, 6, 0], # f4 + [0, 2, 0, 6, 3, -6], # f1 jump to 6, as in a jump sequence + [2, 2, 0, 9, 8, 0], # f6 + [0, 3, 0, 0, 3, -8], # f1 jump stop, as out of boundary + [1, 3, 11, 3, 7, 0], # f5 + ] + if device_str == "meta": + self.assertEqual(permutes.shape, (len(ref_permutes), len(ref_permutes[0]))) + self.assertEqual(in_shapes.shape, (3,)) + self.assertEqual(out_shapes.shape, (4,)) + else: + self.assertTrue( + torch.equal( + permutes, + torch.tensor(ref_permutes, dtype=torch.int32, device=device), + ) + ) + self.assertEqual(in_shapes.tolist(), [7, 18, 8]) + self.assertEqual(out_shapes.tolist(), [8, 4, 17, 10]) + self.assertEqual(out_lengths, [8, 4, 17, 10]) + + # pyre-ignore[56] + @given( + device_str=st.sampled_from(["cpu", "meta", "cuda"]), + batch_size=st.sampled_from([16, 128, 1024]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_multi_permute_forward(self, device_str: str, batch_size: int) -> None: + # assumption only fails when using cuda but device == 0. 
+ assume(device_str != "cuda" or torch.cuda.device_count() > 0) + device = torch.device(device_str) + + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + with torch.inference_mode(): + values = [torch.randn(batch_size, sum(L), device=device) for L in lengths] + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + values[0], keys, lengths, groups + ) + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_shapes, out_shapes, out_lengths + ) + + if device_str == "meta": + for out, ref in zip(outputs, out_lengths): + self.assertEqual(out.shape, (batch_size, ref)) + else: + refs = [[] for _ in groups] + for i in range(permutes.size(0)): + in_idx, out, in_start, _, length, _ = permutes[i].tolist() + refs[out].append(values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + for out, ref in zip(outputs, refs): + torch.testing.assert_close(out, ref) + + # pyre-ignore[56] + @given( + device_str=st.sampled_from(["meta", "cpu", "cuda"]), + dtype=st.sampled_from( + [ + torch.float, + torch.float32, + torch.float16, + torch.bfloat16, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_multi_permute_dtype(self, device_str: str, dtype: torch.dtype) -> None: + # assumption only fails when using cuda but device == 0. + assume(device_str != "cuda" or torch.cuda.device_count() > 0) + device = torch.device(device_str) + + batch_size = 4 + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(L), device=device, dtype=dtype) for L in lengths + ] + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + values[0], keys, lengths, groups + ) + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_shapes, out_shapes, out_lengths + ) + + if device_str == "meta": + for out, ref in zip(outputs, out_lengths): + self.assertEqual(out.shape, (batch_size, ref)) + else: + refs = [[] for _ in groups] + for i in range(permutes.size(0)): + in_idx, out, in_start, _, length, _ = permutes[i].tolist() + refs[out].append(values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + for out, ref in zip(outputs, refs): + torch.testing.assert_close(out, ref) + self.assertEqual(out.dtype, ref.dtype) + + # pyre-ignore[56] + @given( + zipped_args=st.sampled_from( + [ + ("cpu", 32, [[3, 4], [5, 6, 7], [8]]), + ("cuda", 128, [[96, 256], [512, 128, 768], [1024]]), + ], + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_multi_permute_backward( + self, zipped_args: Tuple[str, int, List[List[int]]] + ) -> None: + device_str, batch_size, lengths = zipped_args + # assumption only fails when using cuda but device == 0. 
+ assume(device_str != "cuda" or torch.cuda.device_count() > 0) + device = torch.device(device_str) + + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(lens), device=device, requires_grad=True) + for lens in lengths + ] + ref_values = [v.detach() for v in values] + for v in ref_values: + v.requires_grad = True + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + values[0], keys, lengths, groups + ) + refs = [[] for _ in groups] + for i in range(permutes.size(0)): + in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() + refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_shapes, out_shapes, out_lengths + ) + for out, ref in zip(outputs, refs): + self.assertTrue(torch.allclose(out, ref)) + + ref_loss, loss = refs[0].sum(), outputs[0].sum() + for i in range(1, len(refs)): + ref_loss += (i + 1.1) * refs[i].sum() + loss += (i + 1.1) * outputs[i].sum() + ref_loss.backward() + loss.backward() + for val, ref in zip(values, ref_values): + val_grad, ref_grad = val.grad, ref.grad + assert isinstance(val_grad, torch.Tensor) + self.assertTrue(torch.allclose(val_grad, ref_grad)) + + # pyre-ignore[56] + @given( + device_str=st.sampled_from(["cpu", "cuda"]), + batch_size=st.sampled_from([16, 1024]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_multi_permute_noncontiguous( + self, device_str: str, batch_size: int + ) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(sum(lens), batch_size, device=device, requires_grad=True) + for lens in lengths + ] + non_contiguous = [v.t() for v in values] + for value in non_contiguous: + self.assertFalse(value.is_contiguous()) + ref_values = [v.detach() for v in values] + for v in ref_values: + v.requires_grad = True + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + non_contiguous[0], keys, lengths, groups + ) + refs = [[] for _ in groups] + for i in range(permutes.size(0)): + in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() + refs[out_idx].append(ref_values[in_idx][in_start : (in_start + length), :]) + refs = [torch.cat(ref).t() for ref in refs] + outputs = torch.ops.fbgemm.permute_multi_embedding( + non_contiguous, permutes, in_shapes, out_shapes, out_lengths + ) + for out, ref in zip(outputs, refs): + self.assertTrue(torch.allclose(out, ref)) + + ref_loss, loss = refs[0].sum(), outputs[0].sum() + for i in range(1, len(refs)): + ref_loss += (i + 1.1) * refs[i].sum() + loss += (i + 1.1) * outputs[i].sum() + ref_loss.backward() + loss.backward() + for val, ref in zip(values, ref_values): + val_grad, ref_grad = val.grad, ref.grad + assert isinstance(val_grad, torch.Tensor) + self.assertTrue(torch.allclose(val_grad, ref_grad)) + + # pyre-ignore[56] + @given( + device_str=st.sampled_from(["cpu", "meta", "cuda"]), + batch_size=st.sampled_from([16, 1024]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_kt_regroup_arguments_op(self, device_str: str, batch_size: int) -> None: + if device_str == "cuda" and not 
torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + device = torch.device(device) + embs = [torch.randn(batch_size, sum(L), device=device) for L in lengths] + permutes, in_shapes, out_shapes, out_lengths = ( + torch.ops.fbgemm.kt_regroup_arguments( + embs[0], + keys, + lengths, + groups, + ) + ) + ref_permutes = [ + [0, 0, 0, 0, 3, 4], # f1, jump to 4, as a start + [1, 0, 0, 3, 5, 0], # f3 + [0, 1, 3, 0, 4, 0], # f2 + [1, 2, 5, 0, 6, 0], # f4 + [0, 2, 0, 6, 3, -6], # f1 jump to 6, as in a jump sequence + [2, 2, 0, 9, 8, 0], # f6 + [0, 3, 0, 0, 3, -8], # f1 jump stop, as out of boundary + [1, 3, 11, 3, 7, 0], # f5 + ] + if device_str == "meta": + self.assertEqual(permutes.shape, (len(ref_permutes), len(ref_permutes[0]))) + self.assertEqual(in_shapes.shape, (3,)) + self.assertEqual(out_shapes.shape, (4,)) + else: + self.assertTrue( + torch.equal( + permutes, + torch.tensor(ref_permutes, dtype=torch.int32, device=device), + ) + ) + self.assertEqual(in_shapes.tolist(), [7, 18, 8]) + self.assertEqual(out_shapes.tolist(), [8, 4, 17, 10]) + self.assertEqual(out_lengths, [8, 4, 17, 10]) + + # pyre-ignore[56] + @given( + device_str=st.sampled_from(["cpu", "meta", "cuda"]), + batch_size=st.sampled_from([16, 1024]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_keyed_tensor_regroup_forward( + self, device_str: str, batch_size: int + ) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + permutes = [ + [0, 0, 0, 0, 3, 4], # f1, jump to 4, as a start + [1, 0, 0, 3, 5, 0], # f3 + [0, 1, 3, 0, 4, 0], # f2 + [1, 2, 5, 0, 6, 0], # f4 + [0, 2, 0, 6, 3, -6], # f1 jump to 6, as in a jump sequence + [2, 2, 0, 9, 8, 0], # f6 + [0, 3, 0, 0, 3, -8], # f1 jump stop, as out of boundary + [1, 3, 11, 3, 7, 0], # f5 + ] + with torch.inference_mode(): + values = [ + torch.randn(batch_size, sum(lens), device=device) for lens in lengths + ] + refs = [[] for _ in groups] + for p in permutes: + in_idx, out_idx, in_start, _, length, _ = p + refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + outputs = torch.ops.fbgemm.regroup_keyed_tensor( + values, + keys, + lengths, + groups, + ) + for out, ref in zip(outputs, refs): + if device_str == "meta": + self.assertEqual(out.shape, ref.shape) + else: + torch.testing.assert_close(out, ref) + + # pyre-ignore[56] + @given( + device_str=st.sampled_from(["cpu", "cuda"]), + batch_size=st.sampled_from([16, 1024]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_keyed_tensor_regroup_backward( + self, device_str: str, batch_size: int + ) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(lens), device=device, requires_grad=True) + for lens in lengths + ] + ref_values = [v.detach() for v in values] + for v in ref_values: + v.requires_grad = True + permutes, 
in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + values[0], keys, lengths, groups + ) + refs = [[] for _ in groups] + for i in range(permutes.size(0)): + in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() + refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + outputs = torch.ops.fbgemm.regroup_keyed_tensor( + values, + keys, + lengths, + groups, + ) + for out, ref in zip(outputs, refs): + self.assertTrue(torch.allclose(out, ref)) + + ref_loss, loss = refs[0].sum(), outputs[0].sum() + for i in range(1, len(refs)): + ref_loss += (i + 1.1) * refs[i].sum() + loss += (i + 1.1) * outputs[i].sum() + ref_loss.backward() + loss.backward() + for val, ref in zip(values, ref_values): + val_grad, ref_grad = val.grad, ref.grad + assert isinstance(val_grad, torch.Tensor) + self.assertTrue(torch.allclose(val_grad, ref_grad)) + + +@skip_if_asan_class +class TestKeyedTensorGPU(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + self.device = torch.cuda.current_device() + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "Not enough GPUs, this test requires at least one GPUs", + ) + def test_regroup_backward_skips_and_duplicates(self) -> None: + kts = build_kts( + dense_features=20, + sparse_features=20, + dim_dense=64, + dim_sparse=128, + batch_size=128, + device=self.device, + run_backward=True, + ) + groups = build_groups(kts=kts, num_groups=2, skips=True, duplicates=True) + labels = torch.randint(0, 1, (128,), device=self.device).float() + + tensor_groups = KeyedTensor.regroup(kts, groups) + pred0 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred0, labels).sum() + actual_kt_0_grad = torch.autograd.grad( + loss, kts[0].values(), retain_graph=True + )[0] + actual_kt_1_grad = torch.autograd.grad( + loss, kts[1].values(), retain_graph=True + )[0] + + # clear grads are return + kts[0].values().grad = None + kts[1].values().grad = None + + tensor_groups = _regroup_keyed_tensors(kts, groups) + pred1 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred1, labels).sum() + expected_kt_0_grad = torch.autograd.grad( + loss, kts[0].values(), retain_graph=True + )[0] + expected_kt_1_grad = torch.autograd.grad( + loss, kts[1].values(), retain_graph=True + )[0] + + torch.allclose(actual_kt_0_grad, expected_kt_0_grad) + torch.allclose(actual_kt_1_grad, expected_kt_1_grad) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "Not enough GPUs, this test requires at least one GPUs", + ) + def test_regroup_backward(self) -> None: + kts = build_kts( + dense_features=20, + sparse_features=20, + dim_dense=64, + dim_sparse=128, + batch_size=128, + device=self.device, + run_backward=True, + ) + groups = build_groups(kts=kts, num_groups=2, skips=False, duplicates=False) + labels = torch.randint(0, 1, (128,), device=self.device).float() + + tensor_groups = KeyedTensor.regroup(kts, groups) + pred0 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred0, labels).sum() + actual_kt_0_grad = torch.autograd.grad( + loss, kts[0].values(), retain_graph=True + )[0] + actual_kt_1_grad = torch.autograd.grad( + loss, kts[1].values(), retain_graph=True + )[0] + + # clear grads are return + kts[0].values().grad = None + kts[1].values().grad = None + + tensor_groups = _regroup_keyed_tensors(kts, groups) + pred1 = 
tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred1, labels).sum() + expected_kt_0_grad = torch.autograd.grad( + loss, kts[0].values(), retain_graph=True + )[0] + expected_kt_1_grad = torch.autograd.grad( + loss, kts[1].values(), retain_graph=True + )[0] + + torch.allclose(actual_kt_0_grad, expected_kt_0_grad) + torch.allclose(actual_kt_1_grad, expected_kt_1_grad) diff --git a/torchrec/sparse/tests/test_tensor_dict.py b/torchrec/sparse/tests/test_tensor_dict.py new file mode 100644 index 000000000..ac6d21da4 --- /dev/null +++ b/torchrec/sparse/tests/test_tensor_dict.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +import unittest + +import torch +from hypothesis import assume, given, settings, strategies as st, Verbosity +from tensordict import TensorDict +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.sparse.tensor_dict import maybe_td_to_kjt + + +class TestTensorDict(unittest.TestCase): + # pyre-ignore[56] + @given(device_str=st.sampled_from(["cpu", "meta", "cuda"])) + @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) + def test_kjt_input(self, device_str: str) -> None: + # assumption only fails when using cuda but device == 0. + assume(device_str != "cuda" or torch.cuda.device_count() > 0) + device = torch.device(device_str) + values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device) + kjt = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + values=values, + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7], device=device), + ) + features = maybe_td_to_kjt(kjt) + self.assertEqual(features, kjt) + + # pyre-ignore[56] + @given(device_str=st.sampled_from(["cpu", "meta", "cuda"])) + @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) + def test_td_kjt(self, device_str: str) -> None: + # assumption only fails when using cuda but device == 0. + assume(device_str != "cuda" or torch.cuda.device_count() > 0) + device = torch.device(device_str) + values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device) + lengths = torch.tensor([2, 0, 1, 1, 1, 2], device=device) + data = { + "f2": torch.nested.nested_tensor_from_jagged( + torch.tensor([2, 3], device=device), + lengths=torch.tensor([1, 1], device=device), + ), + "f1": torch.nested.nested_tensor_from_jagged( + torch.arange(2, device=device), + offsets=torch.tensor([0, 2, 2], device=device), + ), + "f3": torch.nested.nested_tensor_from_jagged( + torch.tensor([2, 3, 4], device=device), + lengths=torch.tensor([1, 2], device=device), + ), + } + td = TensorDict( + data, # type: ignore[arg-type] + device=device, + batch_size=[2], + ) + + features = maybe_td_to_kjt(td, ["f1", "f2", "f3"]) # pyre-ignore[6] + torch.testing.assert_close(features.values(), values) + torch.testing.assert_close(features.lengths(), lengths) diff --git a/torchrec/sparse/tests/utils.py b/torchrec/sparse/tests/utils.py new file mode 100644 index 000000000..2d9b73197 --- /dev/null +++ b/torchrec/sparse/tests/utils.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import functools +import random +import unittest +from typing import Any, Callable, List, Sequence + +import torch +from torchrec.sparse.jagged_tensor import KeyedTensor + + +def build_kts( + dense_features: int, + sparse_features: int, + dim_dense: int, + dim_sparse: int, + batch_size: int, + device: torch.device, + run_backward: bool, +) -> List[KeyedTensor]: + key_dim = 1 + dense_embs = [ + torch.randn(batch_size, dim_dense, device=device, requires_grad=run_backward) + for i in range(dense_features) + ] + dense_keys = [f"dense_{i}" for i in range(dense_features)] + dense_kt = KeyedTensor.from_tensor_list(dense_keys, dense_embs, key_dim) + + sparse_embs = [ + torch.randn(batch_size, dim_sparse, device=device, requires_grad=run_backward) + for i in range(sparse_features) + ] + sparse_keys = [f"sparse_{i}" for i in range(sparse_features)] + sparse_kt = KeyedTensor.from_tensor_list(sparse_keys, sparse_embs, key_dim) + return [dense_kt, sparse_kt] + + +def build_groups( + kts: List[KeyedTensor], + num_groups: int, + skips: bool = False, + duplicates: bool = False, +) -> List[List[str]]: + all_keys = [] + for kt in kts: + all_keys.extend(kt.keys()) + allocation = [random.randint(0, num_groups - 1) for _ in range(len(all_keys))] + groups = [[] for _ in range(num_groups)] + for i, key in enumerate(allocation): + groups[key].append(all_keys[i]) + if skips: + for group in groups: + if len(group) > 1: + group.pop(random.randint(0, len(group) - 1)) + if duplicates: + for group in groups: + group.append(random.choice(all_keys)) + return groups + + +def repeat_test( + *args: List[Any], **kwargs: Sequence[Any] +) -> Callable[..., Callable[..., None]]: + def decorate(f: Callable[..., None]) -> Callable[..., None]: + @functools.wraps(f) + def decorator(self: unittest.TestCase) -> None: + queue = [(arg, {}) for arg in args] if args else [((), {})] + for k, values in kwargs.items(): + new_queue = [] + for a, d in queue: + for v in values: + new_d = d.copy() + new_d[k] = v + new_queue.append((a, new_d)) + queue = new_queue + for a, d in queue: + print(f"running {f.__name__} {a} {d}") + f(self, *a, **d) + + return decorator + + return decorate diff --git a/torchrec/streamable.py b/torchrec/streamable.py index 9551f7bbb..075e5fdc6 100644 --- a/torchrec/streamable.py +++ b/torchrec/streamable.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import abc import torch @@ -18,7 +20,7 @@ class Multistreamable(abc.ABC): """ @abc.abstractmethod - def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + def record_stream(self, stream: torch.Stream) -> None: """ See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html """ @@ -39,8 +41,8 @@ class Pipelineable(Multistreamable): @abc.abstractmethod def to(self, device: torch.device, non_blocking: bool) -> "Pipelineable": """ - Please be aware that accoarding to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, - to might return self or a copy of self. So please remember to use `to` with the assignment operator, + Please be aware that according to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, + `to` might return self or a copy of self. So please remember to use `to` with the assignment operator, for example, `in = in.to(new_device)`. """ ... 
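For context on the `Multistreamable`/`Pipelineable` contract touched in the hunk above, the following is a minimal sketch of a conforming batch container. It assumes `to` and `record_stream` are the only abstract methods (as shown in this diff); the `DenseBatch` name and its field are hypothetical and not part of the change.

import torch
from torchrec.streamable import Pipelineable


class DenseBatch(Pipelineable):
    """Hypothetical example container; illustrative only, not part of this diff."""

    def __init__(self, features: torch.Tensor) -> None:
        self.features = features

    def to(self, device: torch.device, non_blocking: bool = False) -> "DenseBatch":
        # Tensor.to may return self or a copy, so return a fresh container and rely on
        # callers using the assignment pattern recommended in the docstring above.
        return DenseBatch(self.features.to(device=device, non_blocking=non_blocking))

    def record_stream(self, stream: torch.Stream) -> None:
        # Mark the tensor as in use on `stream` so its memory is not reclaimed
        # while an asynchronous copy or kernel is still in flight.
        self.features.record_stream(stream)


batch = DenseBatch(torch.randn(8, 16))
if torch.cuda.is_available():
    # Reassign the result, per the corrected docstring: `to` may return self or a copy.
    batch = batch.to(torch.device("cuda"), non_blocking=True)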
diff --git a/torchrec/tensor_types.py b/torchrec/tensor_types.py new file mode 100644 index 000000000..6b88703d0 --- /dev/null +++ b/torchrec/tensor_types.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +# pyre-ignore-all-errors[2, 3, 4, 6, 13, 14, 20] + +import itertools +from typing import List, Tuple + +import torch +import torch._prims_common as utils + + +def down_size(N: int, size: torch.Size) -> Tuple[int, int]: + assert size[-1] % N == 0, f"{size} last dim not divisible by {N}" + return (*size[:-1], size[-1] // N) + + +def up_size(N: int, size: torch.Size) -> Tuple[int, int]: + return (*size[:-1], size[-1] * N) + + +def fill_defaults(args, n, defaults_tail): + """ + __torch_dispatch__ doesn't guarantee the number of arguments you are + passed (e.g., defaulted arguments are not passed); but usually it is + convenient to pad out the arguments list with defaults. This function + helps you do that. + Args: + args: the list of positional arguments passed to __torch_dispatch__ + n: the number of arguments you are expecting to get + defaults_tail: default values for the arguments, starting from the + end of the list + Example: + >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) + [1, 2, 3, 4, 5] + >>> fill_defaults([1, 2, 3], 5, [None, None, None]) + [1, 2, 3, None, None]] + """ + if n - len(defaults_tail) > len(args): + raise RuntimeError("not enough defaults to fill arguments") + r = list(args) + for i in range(len(args), n): + r.append(defaults_tail[i - n + len(defaults_tail)]) + return r + + +def find_arg_of_type(it, t): + for x in it: + if isinstance(x, t): + return x + return None + + +class UIntXTensor(torch.Tensor): + """ + A Tensor subclass of uint8 dtype, that represents Tensor with X-bit elements. + The last dimension must be divisible by (8 // X). + + __torch_dispatch__ special handling: + .view(dtype=torch.uint8) - returns the underlying uint8 data. + + .slice,.view - works in UIntX units, dimension values must be divisible by (8 // X). + + .detach,.clone - work as an op on underlying uint8 data. 
+ """ + + __torch_function__ = torch._C._disabled_torch_function_impl + + @staticmethod + def __new__(cls, N: int, elem): + assert elem.dtype is torch.uint8 + # pyre-ignore + return torch.Tensor._make_wrapper_subclass( + cls, up_size(N, elem.shape), dtype=torch.uint8 + ) + + def __init__(self, N: int, elem: torch.Tensor) -> None: + self.N: int = N + self.elem: torch.Tensor = elem + + # pyre-ignore + def tolist(self) -> List: + return self.elem.tolist() + + def __repr__(self) -> str: + return f"UInt{8 // self.N}Tensor(shape={self.shape}, elem={self.elem})" + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): # noqa: C901 + if func is torch.ops.aten.detach.default: + # Temp workaround to avoid 'Cannot set version_counter for inference tensor' + with torch.inference_mode(False): + (self,) = args + return cls(func(self.elem)) + elif func is torch.ops.aten.clone.default: + (self,) = args + return cls(func(self.elem)) + elif func is torch.ops.aten.copy_.default: + (self, src) = args + self.elem.copy_(src.elem) + return self + elif func is torch.ops.aten.view.dtype: + # .view(dtype=uint8) is the way to get the underlying uint8 data + self, dtype = args + if dtype == torch.uint8: + return self.elem + elif func is torch.ops.aten._to_copy.default: + (self,) = args + dtype = find_arg_of_type( + itertools.chain(args, kwargs.values()), torch.dtype + ) + device = find_arg_of_type( + itertools.chain(args, kwargs.values()), torch.device + ) + # Handle only to device + if device: + assert dtype is None or dtype == torch.uint8 + return cls(self.elem.to(device)) + elif func is torch.ops.aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == self.dim() - 1: + # hard case + if step != 1: + raise NotImplementedError(f"slice step={step}") + assert start % self.N == 0, start + assert end >= self.shape[dim] or end % self.N == 0, end + return cls( + torch.ops.aten.slice.Tensor( + self.elem, dim, start // self.N, end // self.N, 1 + ), + ) + else: + return cls( + torch.ops.aten.slice.Tensor(self.elem, dim, start, end, step), + ) + if func is torch.ops.aten.view.default: + self, size = args + size = utils.infer_size(size, self.numel()) + assert not kwargs + return cls(self.elem.reshape(down_size(self.N, size))) + elif func is torch.ops.aten.select.int: + self, dim, index = args + if dim != self.dim() - 1: + return cls(torch.ops.aten.select.int(self.elem, dim, index)) + else: + raise NotImplementedError(f"select dim={dim}") + + raise NotImplementedError(f"{func} args:{args} kwargs:{kwargs}") + + +class UInt4Tensor(UIntXTensor): + N: int = 2 + + @staticmethod + def __new__(cls, elem: torch.Tensor): + return UIntXTensor.__new__(cls, cls.N, elem) + + def __init__(self, elem: torch.Tensor) -> None: + super().__init__(UInt4Tensor.N, elem) + + +class UInt2Tensor(UIntXTensor): + N: int = 4 + + @staticmethod + def __new__(cls, elem: torch.Tensor): + return UIntXTensor.__new__(cls, cls.N, elem) + + def __init__(self, elem: torch.Tensor) -> None: + super().__init__(UInt2Tensor.N, elem) diff --git a/torchrec/test_utils/__init__.py b/torchrec/test_utils/__init__.py index ecc799d5b..78c210c1e 100644 --- a/torchrec/test_utils/__init__.py +++ b/torchrec/test_utils/__init__.py @@ -5,35 +5,68 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-strict + import ctypes import os import random import socket import time +import unittest from contextlib import closing from functools import wraps -from typing import Callable, Optional, TypeVar +from typing import Any, Callable, Dict, Optional, TypeVar import numpy as np import torch import torch.distributed as dist from pyre_extensions import ParameterSpecification +from torch import nn TParams = ParameterSpecification("TParams") TReturn = TypeVar("TReturn") def get_free_port() -> int: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("127.0.0.1", 0)) - s.listen(0) - with closing(s): - return s.getsockname()[1] - except Exception as e: - raise Exception( - f"Binding failed with address 127.0.0.1 while getting free port {e}" - ) + # INTERNAL + if os.getenv("SANDCASTLE") == "1" or os.getenv("TW_JOB_USER") == "sandcastle": + if socket.has_ipv6: + family = socket.AF_INET6 + address = "localhost6" + else: + family = socket.AF_INET + address = "localhost4" + with socket.socket(family, socket.SOCK_STREAM) as s: + try: + s.bind((address, 0)) + s.listen(0) + with closing(s): + return s.getsockname()[1] + except socket.gaierror: + if address == "localhost6": + address = "::1" + else: + address = "127.0.0.1" + s.bind((address, 0)) + s.listen(0) + with closing(s): + return s.getsockname()[1] + except Exception as e: + raise Exception( + f"Binding failed with address {address} while getting free port {e}" + ) + # OSS GHA: TODO remove when enable ipv6 on GHA @omkar + else: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + s.listen(0) + with closing(s): + return s.getsockname()[1] + except Exception as e: + raise Exception( + f"Binding failed with address 127.0.0.1 while getting free port {e}" + ) def is_asan() -> bool: @@ -58,8 +91,7 @@ def skip_if_asan( @wraps(func) def wrapper(*args: TParams.args, **kwargs: TParams.kwargs) -> Optional[TReturn]: if is_asan_or_tsan(): - print("Skipping test run since we are in ASAN mode.") - return + raise unittest.SkipTest("Skipping test run since we are in ASAN mode.") return func(*args, **kwargs) return wrapper @@ -67,8 +99,8 @@ def wrapper(*args: TParams.args, **kwargs: TParams.kwargs) -> Optional[TReturn]: def skip_if_asan_class(cls: TReturn) -> Optional[TReturn]: if is_asan_or_tsan(): - print("Skipping test run since we are in ASAN mode.") - return + cls.__unittest_skip__ = True + cls.__unittest_skip_why__ = "Skipping test run since we are in ASAN mode." return cls @@ -77,6 +109,8 @@ def init_distributed_single_host( ) -> dist.ProcessGroup: os.environ["LOCAL_WORLD_SIZE"] = str(local_size if local_size else world_size) os.environ["LOCAL_RANK"] = str(rank % local_size if local_size else rank) + if dist.is_initialized(): + dist.destroy_process_group() dist.init_process_group(rank=rank, world_size=world_size, backend=backend) # pyre-fixme[7]: Expected `ProcessGroup` but got # `Optional[_distributed_c10d.ProcessGroup]`. 
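Stepping back to the torchrec/tensor_types.py file added above: up_size and down_size define the mapping between the packed uint8 storage and the logical sub-byte shape, and fill_defaults normalizes the variable-length argument lists seen in __torch_dispatch__. A small standalone sketch of the intended behavior (values chosen for illustration, not part of this diff):

import torch

from torchrec.tensor_types import UInt4Tensor, down_size, fill_defaults, up_size

# Two 4-bit values are packed into each uint8 byte (N = 2 for UInt4Tensor).
packed = torch.tensor([[0x12, 0x34], [0xAB, 0xCD]], dtype=torch.uint8)  # 2 x 2 bytes
assert up_size(2, packed.shape) == (2, 4)
assert down_size(2, torch.Size([2, 4])) == (2, 2)

t = UInt4Tensor(packed)
assert t.shape == (2, 4)  # logical uint4 shape, last dim doubled
assert torch.equal(t.view(torch.uint8), packed)  # exposes the packed storage

# fill_defaults pads a positional-arg list to the expected arity, mirroring
# how aten.slice.Tensor arguments are normalized in __torch_dispatch__.
assert fill_defaults([1, 2, 3], 5, [3, 4, 5]) == [1, 2, 3, 4, 5]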
@@ -95,3 +129,46 @@ def _wrapper(*args, **kwargs): return wrapped_func(*args, **kwargs) return _wrapper + + +def get_state_buffers_parameters(model: nn.Module) -> Dict[str, Any]: + return { + "state_dict": model.state_dict(), + "named_buffers": dict(model.named_buffers()), + "named_parameters": dict(model.named_parameters()), + } + + +def assert_state_buffers_parameters_equal( + model_1: nn.Module, + model_2: nn.Module, + check_named_buffers: bool = True, + check_named_parameters: bool = True, + check_state_dict: bool = True, +) -> None: + """ + Checks to see if the keys of top level PyTorch API calls are the same + between two modules. + """ + + model_characteristics = {} + model_characteristics["model_1"] = get_state_buffers_parameters(model_1) + model_characteristics["model_2"] = get_state_buffers_parameters(model_2) + + assert ( + not check_named_buffers + or model_characteristics["model_1"]["named_buffers"].keys() + == model_characteristics["model_2"]["named_buffers"].keys() + ), "named buffers keys are not the same" + + assert ( + not check_named_parameters + or model_characteristics["model_1"]["named_parameters"].keys() + == model_characteristics["model_2"]["named_parameters"].keys() + ), f"named parameter keys are not the same {model_characteristics['model_1']['named_parameters'].keys()} vs {model_characteristics['model_2']['named_parameters'].keys()}" + + assert ( + not check_state_dict + or model_characteristics["model_1"]["state_dict"].keys() + == model_characteristics["model_2"]["state_dict"].keys() + ), f"state dict keys are not the same, {model_characteristics['model_1']['state_dict'].keys()} vs {model_characteristics['model_2']['state_dict'].keys()}" diff --git a/torchrec/types.py b/torchrec/types.py index 6788cbe2a..6c14ef636 100644 --- a/torchrec/types.py +++ b/torchrec/types.py @@ -5,16 +5,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from abc import abstractmethod +from enum import Enum, unique import torch from torch import nn +class CacheMixin: + """ + A mixin to allow modules that cache computation to clear the cache. + """ + + @abstractmethod + def clear_cache(self) -> None: ... + + class CopyMixIn: @abstractmethod - def copy(self, device: torch.device) -> nn.Module: - ... + def copy(self, device: torch.device) -> nn.Module: ... class ModuleCopyMixin(CopyMixIn): @@ -35,3 +46,26 @@ class ModuleNoCopyMixin(CopyMixIn): def copy(self, device: torch.device) -> nn.Module: # pyre-ignore [7] return self + + +# moved DataType here to avoid circular import +# TODO: organize types and dependencies +@unique +class DataType(Enum): + """ + Our fusion implementation supports only certain types of data + so it makes sense to restrict them in a non-fused version as well. + """ + + FP32 = "FP32" + FP16 = "FP16" + BF16 = "BF16" + INT64 = "INT64" + INT32 = "INT32" + INT8 = "INT8" + UINT8 = "UINT8" + INT4 = "INT4" + INT2 = "INT2" + + def __str__(self) -> str: + return self.value diff --git a/version.py b/version.py deleted file mode 100644 index e4b8591d8..000000000 --- a/version.py +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree.
- -# Follows PEP-0440 version scheme guidelines -# https://www.python.org/dev/peps/pep-0440/#version-scheme -# -# Examples: -# 0.1.0.devN # Developmental release -# 0.1.0aN # Alpha release -# 0.1.0bN # Beta release -# 0.1.0rcN # Release Candidate -# 0.1.0 # Final release -__version__ = "0.3.01" diff --git a/version.txt b/version.txt new file mode 100644 index 000000000..ae7fb2d41 --- /dev/null +++ b/version.txt @@ -0,0 +1 @@ +1.1.0a0
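Finally, a sketch of how the new test_utils helpers and the relocated DataType enum might be exercised together in a single-process test; the gloo backend, the toy nn.Linear modules, and the explicit local_size=1 are illustrative assumptions rather than anything mandated by this diff:

import os
import unittest

import torch.distributed as dist
from torch import nn

from torchrec.test_utils import (
    assert_state_buffers_parameters_equal,
    get_free_port,
    init_distributed_single_host,
    skip_if_asan,
)
from torchrec.types import DataType


class TestUtilsSmokeTest(unittest.TestCase):
    @skip_if_asan
    def test_helpers(self) -> None:
        # get_free_port picks an unused local port for the process-group rendezvous.
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(get_free_port())
        # Single process: rank 0 of a world of 1, local_size given explicitly.
        pg = init_distributed_single_host(
            rank=0, world_size=1, backend="gloo", local_size=1
        )
        self.assertIsNotNone(pg)

        # Two modules with identical structure share state_dict/buffer/parameter keys,
        # so the assertion helper passes.
        assert_state_buffers_parameters_equal(nn.Linear(4, 4), nn.Linear(4, 4))

        # DataType.__str__ returns the plain value string.
        self.assertEqual(str(DataType.FP16), "FP16")

        dist.destroy_process_group()


if __name__ == "__main__":
    unittest.main()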