diff --git a/.github/scripts/filter.py b/.github/scripts/filter.py
new file mode 100644
index 000000000..bbdba868a
--- /dev/null
+++ b/.github/scripts/filter.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+
+
+def main():
+ """
+    Filters out certain CUDA versions from the build matrix that
+    determines which nightly builds are run. This keeps TorchRec's
+    version compatibility consistent with FBGEMM.
+ """
+
+ full_matrix_string = os.environ["MAT"]
+ full_matrix = json.loads(full_matrix_string)
+
+ new_matrix_entries = []
+
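+    # Currently a pass-through: every generated matrix entry is kept.
+    # CUDA-version filtering (to stay aligned with FBGEMM) would be applied here.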
+ for entry in full_matrix["include"]:
+ new_matrix_entries.append(entry)
+
+ new_matrix = {"include": new_matrix_entries}
+ print(json.dumps(new_matrix))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/.github/scripts/install_libs.sh b/.github/scripts/install_libs.sh
new file mode 100644
index 000000000..27522ff92
--- /dev/null
+++ b/.github/scripts/install_libs.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+echo "CU_VERSION: ${CU_VERSION}"
+echo "CHANNEL: ${CHANNEL}"
+echo "CONDA_ENV: ${CONDA_ENV}"
+
+if [[ $CU_VERSION = cu* ]]; then
+ # Setting LD_LIBRARY_PATH fixes the runtime error with fbgemm_gpu not
+ # being able to locate libnvrtc.so
+ echo "[NOVA] Setting LD_LIBRARY_PATH ..."
+ conda env config vars set -p ${CONDA_ENV} \
+ LD_LIBRARY_PATH="/usr/local/lib:${CUDA_HOME}/lib64:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}"
+else
+ echo "[NOVA] Setting LD_LIBRARY_PATH ..."
+ conda env config vars set -p ${CONDA_ENV} \
+ LD_LIBRARY_PATH="/usr/local/lib:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}"
+fi
+
+if [ "$CHANNEL" = "nightly" ]; then
+ ${CONDA_RUN} pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/"$CU_VERSION"
+elif [ "$CHANNEL" = "test" ]; then
+ ${CONDA_RUN} pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/test/"$CU_VERSION"
+fi
+
+
+${CONDA_RUN} pip install importlib-metadata
diff --git a/.github/scripts/tests_to_skip.txt b/.github/scripts/tests_to_skip.txt
new file mode 100644
index 000000000..7c0e95f00
--- /dev/null
+++ b/.github/scripts/tests_to_skip.txt
@@ -0,0 +1 @@
+_disabled_in_oss_compatibility
diff --git a/.github/scripts/validate_binaries.sh b/.github/scripts/validate_binaries.sh
new file mode 100755
index 000000000..28060c868
--- /dev/null
+++ b/.github/scripts/validate_binaries.sh
@@ -0,0 +1,167 @@
+#!/usr/bin/env bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+export PYTORCH_CUDA_PKG=""
+export CONDA_ENV="build_binary"
+
+if [[ ${MATRIX_PYTHON_VERSION} = '3.13t' ]]; then
+    echo "Conda doesn't support Python 3.13t yet; to verify, try \`conda create -n test python=3.13t\`"
+ exit 0
+fi
+
+conda create -y -n "${CONDA_ENV}" python="${MATRIX_PYTHON_VERSION}"
+
+conda run -n build_binary python --version
+
+# Install pytorch, torchrec and fbgemm as per
+# installation instructions on following page
+# https://github.com/pytorch/torchrec#installations
+
+if [[ ${MATRIX_GPU_ARCH_TYPE} = 'rocm' ]]; then
+    echo "ROCm is not supported"
+ exit 0
+fi
+
+# figure out CUDA VERSION
+if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then
+ if [[ ${MATRIX_GPU_ARCH_VERSION} = '11.8' ]]; then
+ export CUDA_VERSION="cu118"
+ elif [[ ${MATRIX_GPU_ARCH_VERSION} = '12.1' ]]; then
+ export CUDA_VERSION="cu121"
+ elif [[ ${MATRIX_GPU_ARCH_VERSION} = '12.6' ]]; then
+ export CUDA_VERSION="cu126"
+ elif [[ ${MATRIX_GPU_ARCH_VERSION} = '12.8' ]]; then
+ export CUDA_VERSION="cu128"
+ else
+ export CUDA_VERSION="cu126"
+ fi
+else
+ export CUDA_VERSION="cpu"
+fi
+
+# figure out URL
+if [[ ${MATRIX_CHANNEL} = 'nightly' ]]; then
+ export PYTORCH_URL="https://download.pytorch.org/whl/nightly/$%7BCUDA_VERSION%7D"
+elif [[ ${MATRIX_CHANNEL} = 'test' ]]; then
+ export PYTORCH_URL="https://download.pytorch.org/whl/test/$%7BCUDA_VERSION%7D"
+elif [[ ${MATRIX_CHANNEL} = 'release' ]]; then
+ export PYTORCH_URL="https://download.pytorch.org/whl/$%7BCUDA_VERSION%7D"
+fi
+
+
+echo "CU_VERSION: ${CUDA_VERSION}"
+echo "MATRIX_CHANNEL: ${MATRIX_CHANNEL}"
+echo "CONDA_ENV: ${CONDA_ENV}"
+
+# shellcheck disable=SC2155
+export CONDA_PREFIX=$(conda run -n "${CONDA_ENV}" printenv CONDA_PREFIX)
+
+
+# Set LD_LIBRARY_PATH to fix the runtime error with fbgemm_gpu not
+# being able to locate libnvrtc.so
+# NOTE: The order of the entries in LD_LIBRARY_PATH matters
+echo "[NOVA] Setting LD_LIBRARY_PATH ..."
+conda env config vars set -n ${CONDA_ENV} \
+ LD_LIBRARY_PATH="${CONDA_PREFIX}/lib:/usr/local/lib:/usr/lib64:${LD_LIBRARY_PATH}"
+
+
+# install pytorch
+# switch back to conda once torch nightly is fixed
+# if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then
+# export PYTORCH_CUDA_PKG="pytorch-cuda=${MATRIX_GPU_ARCH_VERSION}"
+# fi
+
+conda run -n "${CONDA_ENV}" pip install importlib-metadata
+
+conda run -n "${CONDA_ENV}" pip install torch --index-url "$PYTORCH_URL"
+
+# install fbgemm
+conda run -n "${CONDA_ENV}" pip install fbgemm-gpu --index-url "$PYTORCH_URL"
+
+# install requirements from pypi
+conda run -n "${CONDA_ENV}" pip install torchmetrics==1.0.3
+
+# install tensordict from pypi
+conda run -n "${CONDA_ENV}" pip install tensordict==0.7.1
+
+# install torchrec
+conda run -n "${CONDA_ENV}" pip install torchrec --index-url "$PYTORCH_URL"
+
+# Run small import test
+conda run -n "${CONDA_ENV}" python -c "import torch; import fbgemm_gpu; import torchrec"
+
+# check directory
+ls -R
+
+# check if cuda available
+conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.cuda.is_available())"
+
+# check cuda version
+conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.version.cuda)"
+
+# Finally run smoke test
+# python 3.11 needs torchx-nightly
+conda run -n "${CONDA_ENV}" pip install torchx-nightly iopath
+if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then
+ conda run -n "${CONDA_ENV}" torchx run -s local_cwd dist.ddp -j 1 --gpu 2 --script test_installation.py
+else
+ conda run -n "${CONDA_ENV}" torchx run -s local_cwd dist.ddp -j 1 --script test_installation.py -- --cpu_only
+fi
+
+
+# Redo the validation against PyPI: for release builds, recreate the environment
+# and install torch, fbgemm-gpu, and torchrec from PyPI (no --index-url).
+
+if [[ ${MATRIX_CHANNEL} != 'release' ]]; then
+ exit 0
+else
+ # Check version matches only for release binaries
+ torchrec_version=$(conda run -n "${CONDA_ENV}" pip show torchrec | grep Version | cut -d' ' -f2)
+ fbgemm_version=$(conda run -n "${CONDA_ENV}" pip show fbgemm_gpu | grep Version | cut -d' ' -f2)
+
+ if [ "$torchrec_version" != "$fbgemm_version" ]; then
+ echo "Error: TorchRec package version does not match FBGEMM package version"
+ exit 1
+ fi
+fi
+
+conda create -y -n "${CONDA_ENV}" python="${MATRIX_PYTHON_VERSION}"
+
+conda run -n "${CONDA_ENV}" python --version
+
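+# Only run the PyPI validation on the CUDA 12.4 matrix entry so the
+# check is not repeated for every CUDA version in the matrix.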
+if [[ ${MATRIX_GPU_ARCH_VERSION} != '12.4' ]]; then
+ exit 0
+fi
+
+echo "checking pypi release"
+conda run -n "${CONDA_ENV}" pip install torch
+conda run -n "${CONDA_ENV}" pip install fbgemm-gpu
+conda run -n "${CONDA_ENV}" pip install torchrec
+
+# Check version matching again for PyPI
+torchrec_version=$(conda run -n "${CONDA_ENV}" pip show torchrec | grep Version | cut -d' ' -f2)
+fbgemm_version=$(conda run -n "${CONDA_ENV}" pip show fbgemm_gpu | grep Version | cut -d' ' -f2)
+
+if [ "$torchrec_version" != "$fbgemm_version" ]; then
+ echo "Error: TorchRec package version does not match FBGEMM package version"
+ exit 1
+fi
+
+# check directory
+ls -R
+
+# check if cuda available
+conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.cuda.is_available())"
+
+# check cuda version
+conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.version.cuda)"
+
+# python 3.11 needs torchx-nightly
+conda run -n "${CONDA_ENV}" pip install torchx-nightly iopath
+
+# Finally run smoke test
+conda run -n "${CONDA_ENV}" torchx run -s local_cwd dist.ddp -j 1 --gpu 2 --script test_installation.py
diff --git a/.github/workflows/build-wheels-linux.yml b/.github/workflows/build-wheels-linux.yml
new file mode 100644
index 000000000..b65648ac2
--- /dev/null
+++ b/.github/workflows/build-wheels-linux.yml
@@ -0,0 +1,64 @@
+name: Build Linux Wheels
+
+on:
+ pull_request:
+ push:
+ branches:
+ - nightly
+ - main
+ - release/*
+ tags:
+ # Release candidate tags look like: v1.11.0-rc1
+ - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
+ workflow_dispatch:
+
+permissions:
+ id-token: write
+ contents: read
+
+jobs:
+ generate-matrix:
+ uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
+ with:
+ package-type: wheel
+ os: linux
+ test-infra-repository: pytorch/test-infra
+ test-infra-ref: main
+ with-rocm: false
+ filter-matrix:
+ needs: generate-matrix
+ runs-on: linux.20_04.4x
+ outputs:
+ matrix: ${{ steps.filter.outputs.matrix }}
+ steps:
+ - uses: actions/setup-python@v4
+ - name: Checkout torchrec repository
+ uses: actions/checkout@v4
+ with:
+ repository: pytorch/torchrec
+      - name: Filter Generated Build Matrix
+ id: filter
+ env:
+ MAT: ${{ needs.generate-matrix.outputs.matrix }}
+ run: |
+ set -ex
+ pwd
+ ls
+ MATRIX_BLOB="$(python .github/scripts/filter.py)"
+ echo "${MATRIX_BLOB}"
+ echo "matrix=${MATRIX_BLOB}" >> "${GITHUB_OUTPUT}"
+ build:
+ needs: filter-matrix
+ name: pytorch/torchrec
+ uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
+ with:
+ repository: pytorch/torchrec
+ ref: ""
+ test-infra-repository: pytorch/test-infra
+ test-infra-ref: main
+ build-matrix: ${{ needs.filter-matrix.outputs.matrix }}
+ pre-script: ""
+ post-script: .github/scripts/install_libs.sh
+ package-name: torchrec
+ smoke-test-script: ""
+ trigger-event: ${{ github.event_name }}
diff --git a/.github/workflows/build_dynamic_embedding_wheels.yml b/.github/workflows/build_dynamic_embedding_wheels.yml
index cdf4d5f76..da8174812 100644
--- a/.github/workflows/build_dynamic_embedding_wheels.yml
+++ b/.github/workflows/build_dynamic_embedding_wheels.yml
@@ -20,27 +20,52 @@ jobs:
fail-fast: false
matrix:
os: [ ubuntu-latest ]
- pyver: [ cp37, cp38, cp39, cp310 ]
- cuver: [ "11.6", "11.3"]
+ pyver: [ cp39, cp310, cp311, cp312 ]
+ cuver: [ "12.1", "12.4"]
steps:
- - uses: actions/checkout@v3
+      - name: Check disk space
+        run: df . -h
+
+ - name: Remove unnecessary files
+ run: |
+ sudo rm -rf /usr/share/dotnet
+ sudo rm -rf "$AGENT_TOOLSDIRECTORY"
+
+      - name: Check disk space
+        run: df . -h
+
+ - uses: actions/checkout@v4
with:
submodules: recursive
- - uses: pypa/cibuildwheel@v2.8.0
- with:
+ - uses: pypa/cibuildwheel@v2.20.0
+ with:
package-dir: contrib/dynamic_embedding
env:
CIBW_BEFORE_BUILD: "env CUDA_VERSION=${{ matrix.cuver }} contrib/dynamic_embedding/tools/before_linux_build.sh"
CIBW_BUILD: "${{ matrix.pyver }}-manylinux_x86_64"
CIBW_REPAIR_WHEEL_COMMAND: "env CUDA_VERSION=${{ matrix.cuver }} contrib/dynamic_embedding/tools/repair_wheel.sh {wheel} {dest_dir}"
+ CIBW_MANYLINUX_X86_64_IMAGE: "manylinux_2_28"
- name: Verify clean directory
run: git diff --exit-code
shell: bash
- name: Upload wheels
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
+ name: artifact-${{ matrix.os }}-${{ matrix.pyver }}-cu${{ matrix.cuver }}
path: wheelhouse/*.whl
+
+ merge:
+ runs-on: ubuntu-latest
+ needs: build_wheels
+ steps:
+ - name: Merge Artifacts
+ uses: actions/upload-artifact/merge@v4
+ with:
+ name: artifact
+ pattern: artifact-*
diff --git a/.github/workflows/cpp_unittest_ci_cpu.yml b/.github/workflows/cpp_unittest_ci_cpu.yml
new file mode 100644
index 000000000..ad9e6e13f
--- /dev/null
+++ b/.github/workflows/cpp_unittest_ci_cpu.yml
@@ -0,0 +1,63 @@
+# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
+# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: CPU Unit Test C++ CI
+
+on:
+ push:
+ paths-ignore:
+ - "docs/*"
+ - "third_party/*"
+ - .gitignore
+ - "*.md"
+ pull_request:
+ paths-ignore:
+ - "docs/*"
+ - "third_party/*"
+ - .gitignore
+ - "*.md"
+
+jobs:
+ build_test:
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - os: linux.2xlarge
+ python-version: 3.9
+ python-tag: "py39"
+ - os: linux.2xlarge
+ python-version: '3.10'
+ python-tag: "py310"
+ - os: linux.2xlarge
+ python-version: '3.11'
+ python-tag: "py311"
+ - os: linux.2xlarge
+ python-version: '3.12'
+ python-tag: "py312"
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
+ with:
+ runner: ${{ matrix.os }}
+ timeout: 15
+ script: |
+ ldd --version
+ conda create -y --name build_binary python=${{ matrix.python-version }}
+ conda info
+ python --version
+ echo "Starting C++ Tests"
+ conda install -n build_binary -y gxx_linux-64
+ conda run -n build_binary \
+ x86_64-conda-linux-gnu-g++ --version
+ conda install -n build_binary -c anaconda redis -y
+ conda run -n build_binary redis-server --daemonize yes
+ mkdir cpp-build
+ cd cpp-build
+ conda run -n build_binary cmake \
+ -DBUILD_TEST=ON \
+ -DBUILD_REDIS_IO=ON \
+ -DCMAKE_PREFIX_PATH=/opt/conda/envs/build_binary/lib/python${{ matrix.python-version }}/site-packages/torch/share/cmake ..
+ conda run -n build_binary make -j
+ conda run -n build_binary ctest -V .
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index c69be055e..360668d4b 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -5,24 +5,30 @@ on:
branches:
- main
workflow_dispatch:
+ pull_request:
+
jobs:
build_docs_job:
runs-on: ${{ matrix.os }}
+ permissions:
+ # Grant write permission here so that the doc can be pushed to gh-pages branch
+ contents: write
strategy:
matrix:
include:
- - os: linux.2xlarge
- python-version: 3.7
+ - os: linux.20_04.4x
+ python-version: 3.9
+ python-tag: "py39"
steps:
- name: Check ldd --version
run: ldd --version
- name: Checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@v4
# Update references
- name: Update pip
run: |
- sudo yum update -y
- sudo yum -y install git python3-pip
+ sudo apt-get update
+ sudo apt-get -y install python3-pip
sudo pip3 install --upgrade pip
- name: Setup conda
run: |
@@ -45,17 +51,24 @@ jobs:
- name: Install gcc
shell: bash
run: |
- sudo yum group install -y "Development Tools"
+        sudo apt-get install -y build-essential
- name: setup Path
run: |
echo /usr/local/bin >> $GITHUB_PATH
- name: Install PyTorch
shell: bash
run: |
- conda run -n build_binary python -m pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
+ conda run -n build_binary pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
+ - name: Install fbgemm
+ run: |
+ conda run -n build_binary pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu
+ - name: Install torchmetrics
+ run: |
+ conda run -n build_binary pip install torchmetrics==1.0.3
- name: Install TorchRec
run: |
- conda run -n build_binary python -m pip install --pre torchrec_nightly_cpu -f https://download.pytorch.org/whl/nightly/torchrec_nightly_cpu/index.html
+ conda run -n build_binary pip install -r requirements.txt
+ conda run -n build_binary python setup.py bdist_wheel --python-tag=${{ matrix.python-tag }}
- name: Test fbgemm_gpu and torchrec installation
shell: bash
run: |
@@ -69,11 +82,42 @@ jobs:
cd ./docs
conda run -n build_binary make html
cd ..
+ - name: Upload Built-Docs
+ uses: actions/upload-artifact@v4
+ with:
+ name: Built-Docs
+ path: docs/build/html/
- name: Get output time
run: echo "The time was ${{ steps.build.outputs.time }}"
- name: Deploy
+ if: github.ref == 'refs/heads/main'
uses: JamesIves/github-pages-deploy-action@releases/v3
with:
ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
BRANCH: gh-pages # The branch the action should deploy to.
FOLDER: docs/build/html # The folder the action should deploy.
+
+ doc-preview:
+ runs-on: [linux.2xlarge]
+ needs: build_docs_job
+ if: ${{ github.event_name == 'pull_request' }}
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+ - name: Download artifact
+ uses: actions/download-artifact@v4
+ with:
+ name: Built-Docs
+ path: docs
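+      # Keep search engines from indexing the preview build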
+ - name: Add no-index tag
+ run: |
+          find docs -name "*.html" -print0 | xargs -0 sed -i '/<head>/a \ \ <meta name="robots" content="noindex">';
+ - name: Upload docs preview
+ uses: seemethere/upload-artifact-s3@v5
+ if: ${{ github.event_name == 'pull_request' }}
+ with:
+ retention-days: 14
+ s3-bucket: doc-previews
+ if-no-files-found: error
+ path: docs
+ s3-prefix: pytorch/torchrec/${{ github.event.pull_request.number }}
diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml
deleted file mode 100644
index be58565f1..000000000
--- a/.github/workflows/nightly_build.yml
+++ /dev/null
@@ -1,207 +0,0 @@
-# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
-# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
-
-name: Push Binary Nightly
-
-on:
- workflow_call:
- secrets:
- AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID:
- required: true
- AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY:
- required: true
- PYPI_TOKEN:
- required: false
- # run every day at 11:15am
- schedule:
- - cron: '15 11 * * *'
- # or manually trigger it
- workflow_dispatch:
-
-jobs:
- # build on cpu hosts and upload to GHA
- build_on_cpu:
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- include:
- - os: linux.2xlarge
- python-version: 3.7
- python-tag: "py37"
- cuda-tag: "cu11"
- - os: linux.2xlarge
- python-version: 3.8
- python-tag: "py38"
- cuda-tag: "cu11"
- - os: linux.2xlarge
- python-version: 3.9
- python-tag: "py39"
- cuda-tag: "cu11"
- - os: linux.2xlarge
- python-version: '3.10'
- python-tag: "py310"
- cuda-tag: "cu11"
- steps:
- # Checkout the repository to the GitHub Actions runner
- - name: Check ldd --version
- run: ldd --version
- - name: Checkout
- uses: actions/checkout@v2
- - name: Update pip
- run: |
- sudo yum update -y
- sudo yum -y install git python3-pip
- sudo pip3 install --upgrade pip
- - name: Setup conda
- run: |
- wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
- bash ~/miniconda.sh -b -p $HOME/miniconda -u
- - name: setup Path
- run: |
- echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH
- echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH
- - name: create conda env
- run: |
- conda create --name build_binary python=${{ matrix.python-version }}
- conda info
- - name: check python version no Conda
- run: |
- python --version
- - name: check python version
- run: |
- conda run -n build_binary python --version
- - name: Install C/C++ compilers
- run: |
- sudo yum install -y gcc gcc-c++
- - name: Install PyTorch and CUDA
- shell: bash
- run: |
- conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
- - name: Install Dependencies
- shell: bash
- run: |
- conda run -n build_binary python -m pip install -r requirements.txt
- - name: Test Installation of dependencies
- run: |
- conda run -n build_binary python -c "import torch.distributed"
- echo "torch.distributed succeeded"
- conda run -n build_binary python -c "import skbuild"
- echo "skbuild succeeded"
- conda run -n build_binary python -c "import numpy"
- echo "numpy succeeded"
- # for the conda run with quotes, we have to use "\" and double quotes
- # here is the issue: https://github.com/conda/conda/issues/10972
- - name: Build TorchRec Nightly
- run: |
- rm -r dist || true
- conda run -n build_binary \
- python setup.py bdist_wheel \
- --package_name torchrec-nightly \
- --python-tag=${{ matrix.python-tag }}
- - name: Upload wheel as GHA artifact
- uses: actions/upload-artifact@v2
- with:
- name: torchrec_nightly_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl
- path: dist/torchrec_nightly-*.whl
-
- # download from GHA, test on gpu and push to pypi
- test_on_gpu:
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [linux.4xlarge.nvidia.gpu]
- python-version: [3.7, 3.8, 3.9]
- cuda-tag: ["cu11"]
- needs: build_on_cpu
- # the glibc version should match the version of the one we used to build the binary
- # for this case, it's 2.26
- steps:
- - name: Check ldd --version
- run: ldd --version
- - name: check cpu info
- shell: bash
- run: |
- cat /proc/cpuinfo
- - name: check distribution info
- shell: bash
- run: |
- cat /proc/version
- - name: Display EC2 information
- shell: bash
- run: |
- set -euo pipefail
- function get_ec2_metadata() {
- # Pulled from instance metadata endpoint for EC2
- # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
- category=$1
- curl -fsSL "http://169.254.169.254/latest/meta-data/$%7Bcategory%7D"
- }
- echo "ami-id: $(get_ec2_metadata ami-id)"
- echo "instance-id: $(get_ec2_metadata instance-id)"
- echo "instance-type: $(get_ec2_metadata instance-type)"
- - name: check gpu info
- shell: bash
- run: |
- sudo yum install lshw -y
- sudo lshw -C display
- # Checkout the repository to the GitHub Actions runner
- - name: Checkout
- uses: actions/checkout@v2
- - name: Update pip
- run: |
- sudo yum update -y
- sudo yum -y install git python3-pip
- sudo pip3 install --upgrade pip
- - name: Setup conda
- run: |
- wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
- bash ~/miniconda.sh -b -p $HOME/miniconda
- - name: setup Path
- run: |
- echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH
- echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH
- - name: create conda env
- run: |
- conda create --name build_binary python=${{ matrix.python-version }}
- conda info
- - name: check python version no Conda
- run: |
- python --version
- - name: check python version
- run: |
- conda run -n build_binary python --version
- - name: Install C/C++ compilers
- run: |
- sudo yum install -y gcc gcc-c++
- - name: Install PyTorch and CUDA
- shell: bash
- run: |
- conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
- # download wheel from GHA
- - name: Download wheel
- uses: actions/download-artifact@v2
- with:
- name: torchrec_nightly_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl
- - name: Display structure of downloaded files
- run: ls -R
- - name: Install TorchRec Nightly
- run: |
- rm -r dist || true
- conda run -n build_binary python -m pip install *.whl
- - name: Test torchrec installation
- shell: bash
- run: |
- conda run -n build_binary \
- python -c "import torchrec"
- - name: Push TorchRec Binary to PYPI
- env:
- PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
- run: |
- conda run -n build_binary python -m pip install twine
- conda run -n build_binary \
- python -m twine upload \
- --username __token__ \
- --password "$PYPI_TOKEN" \
- --skip-existing \
- torchrec_nightly-*.whl \
- --verbose
diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
index eff5ce19b..74a5af78c 100644
--- a/.github/workflows/pre-commit.yaml
+++ b/.github/workflows/pre-commit.yaml
@@ -10,15 +10,15 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Setup Python
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v5
with:
- python-version: 3.8
+ python-version: 3.9
architecture: x64
packages: |
- ufmt==1.3.2
- black==22.3.0
- usort==1.0.2
+ ufmt==2.5.1
+ black==24.2.0
+ usort==1.0.8
- name: Checkout Torchrec
- uses: actions/checkout@v2
+ uses: actions/checkout@v4
- name: Run pre-commit
- uses: pre-commit/action@v2.0.3
+ uses: pre-commit/action@v3.0.1
diff --git a/.github/workflows/pyre.yml b/.github/workflows/pyre.yml
new file mode 100644
index 000000000..fd773c787
--- /dev/null
+++ b/.github/workflows/pyre.yml
@@ -0,0 +1,27 @@
+name: Pyre Check
+
+on:
+ push:
+ branches: [main]
+ pull_request:
+
+jobs:
+ pyre-check:
+ runs-on: ubuntu-22.04
+ defaults:
+ run:
+ shell: bash -el {0}
+ steps:
+ - uses: conda-incubator/setup-miniconda@v2
+ with:
+ python-version: 3.9
+ - name: Checkout Torchrec
+ uses: actions/checkout@v4
+ - name: Install dependencies
+ run: >
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu &&
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu &&
+ pip install -r requirements.txt &&
+ pip install pyre-check-nightly==$(cat .pyre_configuration | grep version | awk '{print $2}' | sed 's/\"//g')
+ - name: Pyre check
+ run: pyre check
diff --git a/.github/workflows/release_build.yml b/.github/workflows/release_build.yml
index 58c7c55fb..1ea837d4b 100644
--- a/.github/workflows/release_build.yml
+++ b/.github/workflows/release_build.yml
@@ -6,10 +6,6 @@ name: Push Binary Release
on:
workflow_call:
secrets:
- AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID:
- required: true
- AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY:
- required: true
PYPI_TOKEN:
required: false
workflow_dispatch:
@@ -22,33 +18,32 @@ jobs:
strategy:
matrix:
include:
- - os: linux.2xlarge
- python-version: 3.7
- python-tag: "py37"
- cuda-tag: "cu11"
- - os: linux.2xlarge
- python-version: 3.8
- python-tag: "py38"
- cuda-tag: "cu11"
- os: linux.2xlarge
python-version: 3.9
python-tag: "py39"
- cuda-tag: "cu11"
+ cuda-tag: "cu124"
- os: linux.2xlarge
python-version: '3.10'
python-tag: "py310"
- cuda-tag: "cu11"
+ cuda-tag: "cu124"
+ - os: linux.2xlarge
+ python-version: '3.11'
+ python-tag: "py311"
+ cuda-tag: "cu124"
+ - os: linux.2xlarge
+ python-version: '3.12'
+ python-tag: "py312"
+ cuda-tag: "cu124"
steps:
# Checkout the repository to the GitHub Actions runner
- name: Check ldd --version
run: ldd --version
- name: Checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@v4
- name: Update pip
run: |
sudo yum update -y
sudo yum -y install git python3-pip
- sudo pip3 install --upgrade pip
- name: Setup conda
run: |
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
@@ -73,7 +68,12 @@ jobs:
- name: Install PyTorch and CUDA
shell: bash
run: |
- conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-test -c nvidia
+ conda run -n build_binary pip install torch
+ - name: Install fbgemm
+ shell: bash
+ run: |
+ conda run -n build_binary pip install numpy
+ conda run -n build_binary pip install fbgemm-gpu
- name: Install Dependencies
shell: bash
run: |
@@ -89,26 +89,27 @@ jobs:
# for the conda run with quotes, we have to use "\" and double quotes
# here is the issue: https://github.com/conda/conda/issues/10972
- name: Build TorchRec
+ env:
+ OFFICIAL_RELEASE: 1
run: |
rm -r dist || true
conda run -n build_binary \
python setup.py bdist_wheel \
- --package_name torchrec \
--python-tag=${{ matrix.python-tag }}
- name: Upload wheel as GHA artifact
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v4
with:
name: torchrec_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl
path: dist/torchrec-*.whl
- # download from GHA, test on gpu and push to pypi
- test_on_gpu:
+ # download from GHA, sanity check on gpu and push to pypi
+ sanity_check_on_gpu_and_push:
runs-on: ${{ matrix.os }}
strategy:
matrix:
- os: [linux.4xlarge.nvidia.gpu]
- python-version: [3.7, 3.8, 3.9]
- cuda-tag: ["cu11"]
+ os: [linux.g5.12xlarge.nvidia.gpu]
+ python-version: [3.9, "3.10", "3.11", "3.12"]
+ cuda-tag: ["cu124"]
needs: build_on_cpu
# the glibc version should match the version of the one we used to build the binary
# for this case, it's 2.26
@@ -131,7 +132,7 @@ jobs:
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
- curl -fsSL "http://169.254.169.254/latest/meta-data/$%7Bcategory%7D"
+ curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/$%7Bcategory%7D"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
@@ -143,12 +144,11 @@ jobs:
sudo lshw -C display
# Checkout the repository to the GitHub Actions runner
- name: Checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@v4
- name: Update pip
run: |
sudo yum update -y
sudo yum -y install git python3-pip
- sudo pip3 install --upgrade pip
- name: Setup conda
run: |
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
@@ -173,10 +173,19 @@ jobs:
- name: Install PyTorch and CUDA
shell: bash
run: |
- conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-test -c nvidia
+ conda run -n build_binary pip install torch
# download wheel from GHA
+ - name: Install fbgemm
+ shell: bash
+ run: |
+ conda run -n build_binary pip install numpy
+ conda run -n build_binary pip install fbgemm-gpu
+ - name: Install torchmetrics
+ shell: bash
+ run: |
+ conda run -n build_binary pip install torchmetrics==1.0.3
- name: Download wheel
- uses: actions/download-artifact@v2
+ uses: actions/download-artifact@v4
with:
name: torchrec_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl
- name: Display structure of downloaded files
@@ -189,13 +198,9 @@ jobs:
shell: bash
run: |
conda run -n build_binary \
- python -c "import torchrec"
- - name: Test with pytest
- run: |
- conda run -n build_binary \
- python -m pip install pytest
+ python -c "import fbgemm_gpu"
conda run -n build_binary \
- python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors
+ python -c "import torchrec"
# Push to Pypi
- name: Push TorchRec Binary to PYPI
env:
diff --git a/.github/workflows/unittest_ci.yml b/.github/workflows/unittest_ci.yml
index 82cbf5621..5ea25f9d4 100644
--- a/.github/workflows/unittest_ci.yml
+++ b/.github/workflows/unittest_ci.yml
@@ -4,188 +4,120 @@
name: Unit Test CI
on:
- # TODO: re-enable when GPU unit tests are working
- # push:
- # paths-ignore:
- # - "docs/*"
- # - "third_party/*"
- # - .gitignore
- # - "*.md"
- # pull_request:
- # paths-ignore:
- # - "docs/*"
- # - "third_party/*"
- # - .gitignore
- # - "*.md"
+ push:
+ branches:
+ - nightly
+ - main
workflow_dispatch:
jobs:
- # build on cpu hosts and upload to GHA
- build_on_cpu:
- runs-on: ${{ matrix.os }}
+ build_test:
strategy:
+ fail-fast: false
matrix:
include:
- - os: linux.2xlarge
- # ideally we run on 3.8 and 3.9 as well, however we are limited in resources.
- python-version: 3.7
- python-tag: "py37"
- cuda-tag: "cu11"
- steps:
- # Checkout the repository to the GitHub Actions runner
- - name: Check ldd --version
- run: ldd --version
- - name: Checkout
- uses: actions/checkout@v2
- - name: Update pip
- run: |
- sudo yum update -y
- sudo yum -y install git python3-pip
- sudo pip3 install --upgrade pip
- - name: Setup conda
- run: |
- wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
- bash ~/miniconda.sh -b -p $HOME/miniconda -u
- - name: setup Path
- run: |
- echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH
- echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH
- - name: create conda env
- run: |
- conda create --name build_binary python=${{ matrix.python-version }}
+ - os: linux.g5.12xlarge.nvidia.gpu
+ python-version: 3.9
+ python-tag: "py39"
+ cuda-tag: "cu118"
+ - os: linux.g5.12xlarge.nvidia.gpu
+ python-version: 3.9
+ python-tag: "py39"
+ cuda-tag: "cu121"
+ - os: linux.g5.12xlarge.nvidia.gpu
+ python-version: 3.9
+ python-tag: "py39"
+ cuda-tag: "cu124"
+ - os: linux.g5.12xlarge.nvidia.gpu
+ python-version: '3.10'
+ python-tag: "py310"
+ cuda-tag: "cu118"
+ - os: linux.g5.12xlarge.nvidia.gpu
+ python-version: '3.10'
+ python-tag: "py310"
+ cuda-tag: "cu121"
+ - os: linux.g5.12xlarge.nvidia.gpu
+ python-version: '3.10'
+ python-tag: "py310"
+ cuda-tag: "cu124"
+ - os: linux.g5.12xlarge.nvidia.gpu
+ python-version: '3.11'
+ python-tag: "py311"
+ cuda-tag: "cu118"
+ - os: linux.g5.12xlarge.nvidia.gpu
+ python-version: '3.11'
+ python-tag: "py311"
+ cuda-tag: "cu121"
+ - os: linux.g5.12xlarge.nvidia.gpu
+ python-version: '3.11'
+ python-tag: "py311"
+ cuda-tag: "cu124"
+ - os: linux.g5.12xlarge.nvidia.gpu
+ python-version: '3.12'
+ python-tag: "py312"
+ cuda-tag: "cu118"
+ - os: linux.g5.12xlarge.nvidia.gpu
+ python-version: '3.12'
+ python-tag: "py312"
+ cuda-tag: "cu121"
+ - os: linux.g5.12xlarge.nvidia.gpu
+ python-version: '3.12'
+ python-tag: "py312"
+ cuda-tag: "cu124"
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
+ with:
+ runner: ${{ matrix.os }}
+ timeout: 30
+ script: |
+ ldd --version
+ conda create -y --name build_binary python=${{ matrix.python-version }}
conda info
- - name: check python version no Conda
- run: |
python --version
- - name: check python version
- run: |
conda run -n build_binary python --version
- - name: Install C/C++ compilers
- run: |
- sudo yum install -y gcc gcc-c++
- - name: Install PyTorch and CUDA
- shell: bash
- run: |
- conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
- - name: Install Dependencies
- shell: bash
- run: |
- conda run -n build_binary python -m pip install -r requirements.txt
- - name: Test Installation of dependencies
- run: |
- conda run -n build_binary python -c "import torch.distributed"
- echo "torch.distributed succeeded"
- conda run -n build_binary python -c "import skbuild"
- echo "skbuild succeeded"
- conda run -n build_binary python -c "import numpy"
- echo "numpy succeeded"
- # for the conda run with quotes, we have to use "\" and double quotes
- # here is the issue: https://github.com/conda/conda/issues/10972
- - name: Build TorchRec Binary
- run: |
+ conda run -n build_binary \
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/${{ matrix.cuda-tag }}
+ conda run -n build_binary \
+ python -c "import torch"
+ echo "torch succeeded"
+ conda run -n build_binary \
+ python -c "import torch.distributed"
+ conda run -n build_binary \
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/${{ matrix.cuda-tag }}
+ conda run -n build_binary \
+ python -c "import fbgemm_gpu"
+ echo "fbgemm_gpu succeeded"
+ conda run -n build_binary \
+ pip install -r requirements.txt
conda run -n build_binary \
python setup.py bdist_wheel \
- --package_name torchrec-test \
--python-tag=${{ matrix.python-tag }}
- - name: Upload wheel as GHA artifact
- uses: actions/upload-artifact@v2
- with:
- name: torchrec-test_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl
- path: dist/torchrec-test-*.whl
-
- # download from GHA, test on gpu
- test_on_gpu:
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [linux.4xlarge.nvidia.gpu]
- python-version: [3.7]
- cuda-tag: ["cu11"]
- needs: build_on_cpu
- # the glibc version should match the version of the one we used to build the binary
- # for this case, it's 2.26
- steps:
- - name: Check ldd --version
- # Run unit tests
- run: ldd --version
- - name: check cpu info
- shell: bash
- run: |
- cat /proc/cpuinfo
- - name: check distribution info
- shell: bash
- run: |
- cat /proc/version
- - name: Display EC2 information
- shell: bash
- run: |
- set -euo pipefail
- function get_ec2_metadata() {
- # Pulled from instance metadata endpoint for EC2
- # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
- category=$1
- curl -fsSL "http://169.254.169.254/latest/meta-data/$%7Bcategory%7D"
- }
- echo "ami-id: $(get_ec2_metadata ami-id)"
- echo "instance-id: $(get_ec2_metadata instance-id)"
- echo "instance-type: $(get_ec2_metadata instance-type)"
- - name: check gpu info
- shell: bash
- run: |
- sudo yum install lshw -y
- sudo lshw -C display
- # Checkout the repository to the GitHub Actions runner
- - name: Checkout
- uses: actions/checkout@v2
- - name: Update pip
- run: |
- sudo yum update -y
- sudo yum -y install git python3-pip
- sudo pip3 install --upgrade pip
- - name: Setup conda
- run: |
- wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
- bash ~/miniconda.sh -b -p $HOME/miniconda
- - name: setup Path
- run: |
- echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH
- echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH
- - name: create conda env
- run: |
- conda create --name build_binary python=${{ matrix.python-version }}
- conda info
- - name: check python version no Conda
- run: |
- python --version
- - name: check python version
- run: |
- conda run -n build_binary python --version
- - name: Install C/C++ compilers
- run: |
- sudo yum install -y gcc gcc-c++
- - name: Install PyTorch and CUDA
- shell: bash
- run: |
- conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
- # download wheel from GHA
- - name: Download wheel
- uses: actions/download-artifact@v2
- with:
- name: torchrec-test_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl
- - name: Display structure of downloaded files
- run: ls -R
- - name: Install TorchRec GPU
- run: |
- rm -r dist || true
- conda run -n build_binary python -m pip install dist/*.whl
- - name: Test torchrec installation
- shell: bash
- run: |
conda run -n build_binary \
python -c "import torchrec"
- - name: Test with pytest
- run: |
+          echo "torchrec succeeded"
conda run -n build_binary \
- python -m pip install pytest
+ python -c "import numpy"
+ echo "numpy succeeded"
+ conda install -n build_binary -y pytest
+ # Read the list of tests to skip from a file, ignoring empty lines and comments
+ skip_expression=$(awk '!/^($|#)/ {printf " and not %s", $0}' ./.github/scripts/tests_to_skip.txt)
+ # Check if skip_expression is effectively empty
+ if [ -z "$skip_expression" ]; then
+ skip_expression=""
+ else
+ skip_expression=${skip_expression:5} # Remove the leading " and "
+ fi
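+            # With the current tests_to_skip.txt this resolves to:
+            #   -k "not _disabled_in_oss_compatibility"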
conda run -n build_binary \
- python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors
+ python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors \
+ --ignore=torchrec/distributed/tests/test_comm.py --ignore=torchrec/distributed/tests/test_infer_shardings.py \
+ --ignore=torchrec/distributed/tests/test_keyed_jagged_tensor_pool.py --ignore=torchrec/distributed/tests/test_pt2_multiprocess.py \
+ --ignore=torchrec/distributed/tests/test_pt2.py --ignore=torchrec/distributed/tests/test_quant_model_parallel.py \
+ --ignore=torchrec/distributed/tests/test_quant_pruning.py --ignore=torchrec/distributed/tests/test_quant_sequence_model_parallel.py \
+ --ignore-glob='torchrec/metrics/*' --ignore-glob='torchrec/distributed/tests/test_model_parallel_gloo*' \
+ --ignore-glob='torchrec/inference/inference_legacy/tests*' --ignore-glob='*test_model_parallel_nccl*' \
+ --ignore=torchrec/distributed/tests/test_cache_prefetch.py --ignore=torchrec/distributed/tests/test_fp_embeddingbag_single_rank.py \
+ --ignore=torchrec/distributed/tests/test_infer_utils.py --ignore=torchrec/distributed/tests/test_fx_jit.py --ignore-glob=**/test_utils/ \
+ --ignore-glob='*test_train_pipeline*' --ignore=torchrec/distributed/tests/test_model_parallel_hierarchical.py \
+ -k "$skip_expression"
diff --git a/.github/workflows/unittest_ci_cpu.yml b/.github/workflows/unittest_ci_cpu.yml
index e69c4f72f..0c53bbe7d 100644
--- a/.github/workflows/unittest_ci_cpu.yml
+++ b/.github/workflows/unittest_ci_cpu.yml
@@ -20,70 +20,66 @@ on:
jobs:
build_test:
strategy:
+ fail-fast: false
matrix:
include:
- - os: linux.2xlarge
- python-version: 3.7
- python-tag: "py37"
- - os: linux.2xlarge
- python-version: 3.8
- python-tag: "py38"
- os: linux.2xlarge
python-version: 3.9
python-tag: "py39"
- os: linux.2xlarge
python-version: '3.10'
python-tag: "py310"
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ - os: linux.2xlarge
+ python-version: '3.11'
+ python-tag: "py311"
+ - os: linux.2xlarge
+ python-version: '3.12'
+ python-tag: "py312"
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
with:
runner: ${{ matrix.os }}
- timeout: 30
+ timeout: 15
script: |
ldd --version
conda create -y --name build_binary python=${{ matrix.python-version }}
conda info
python --version
conda run -n build_binary python --version
- conda install -n build_binary \
- --yes \
- -c pytorch-nightly \
- "pytorch-nightly"::pytorch[build="*cpu*"]
conda run -n build_binary \
- pip install -r requirements.txt
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
conda run -n build_binary \
- pip uninstall fbgemm_gpu-nightly -y
- conda run -n build_binary \
- pip install fbgemm-gpu-nightly-cpu
+ python -c "import torch"
+ echo "torch succeeded"
conda run -n build_binary \
python -c "import torch.distributed"
- echo "torch.distributed succeeded"
conda run -n build_binary \
- python -c "import skbuild"
- echo "skbuild succeeded"
- conda run -n build_binary \
- python -c "import numpy"
- echo "numpy succeeded"
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu
conda run -n build_binary \
python -c "import fbgemm_gpu"
echo "fbgemm_gpu succeeded"
-
+ conda run -n build_binary \
+ pip install -r requirements.txt
conda run -n build_binary \
python setup.py bdist_wheel \
- --package_name torchrec-test-cpu \
--python-tag=${{ matrix.python-tag }}
conda run -n build_binary \
python -c "import torchrec"
- conda install -n build_binary -y pytest
+          echo "torchrec succeeded"
conda run -n build_binary \
- python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors -k 'not test_sharding_gloo_cw'
- echo "Starting C++ Tests"
- conda install -n build_binary -y gxx_linux-64
+ python -c "import numpy"
+ echo "numpy succeeded"
+ conda install -n build_binary -y pytest
+ # Read the list of tests to skip from a file, ignoring empty lines and comments
+ skip_expression=$(awk '!/^($|#)/ {printf " and not %s", $0}' ./.github/scripts/tests_to_skip.txt)
+ # Check if skip_expression is effectively empty
+ if [ -z "$skip_expression" ]; then
+ skip_expression=""
+ else
+ skip_expression=${skip_expression:5} # Remove the leading " and "
+ fi
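+          # With the current tests_to_skip.txt this resolves to:
+          #   -k "not _disabled_in_oss_compatibility"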
conda run -n build_binary \
- x86_64-conda-linux-gnu-g++ --version
- mkdir cpp-build
- cd cpp-build
- conda run -n build_binary cmake \
- -DBUILD_TEST=ON \
- -DCMAKE_PREFIX_PATH=/opt/conda/envs/build_binary/lib/python${{ matrix.python-version }}/site-packages/torch/share/cmake ..
- conda run -n build_binary make -j
- conda run -n build_binary ctest -v .
+ python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors \
+ --ignore-glob=**/test_utils/ -k "$skip_expression"
diff --git a/.github/workflows/validate-binaries.yml b/.github/workflows/validate-binaries.yml
new file mode 100644
index 000000000..248857214
--- /dev/null
+++ b/.github/workflows/validate-binaries.yml
@@ -0,0 +1,41 @@
+name: Validate binaries
+
+on:
+ workflow_call:
+ inputs:
+ channel:
+ description: "Channel to use (nightly, release)"
+ required: false
+ type: string
+ default: release
+ ref:
+ description: 'Reference to checkout, defaults to empty'
+ default: ""
+ required: false
+ type: string
+ workflow_dispatch:
+ inputs:
+ channel:
+        description: "Channel to use (nightly, release, test)"
+ required: true
+ type: choice
+ options:
+ - release
+ - nightly
+ - test
+ ref:
+ description: 'Reference to checkout, defaults to empty'
+ default: ""
+ required: false
+ type: string
+
+jobs:
+ validate-binaries:
+ uses: pytorch/test-infra/.github/workflows/validate-domain-library.yml@main
+ with:
+ package_type: "wheel"
+ os: "linux"
+ channel: ${{ inputs.channel }}
+ repository: "pytorch/torchrec"
+ smoke_test: "source ./.github/scripts/validate_binaries.sh"
+ with_cuda: enable
diff --git a/.github/workflows/validate-nightly-binaries.yml b/.github/workflows/validate-nightly-binaries.yml
new file mode 100644
index 000000000..6d6369495
--- /dev/null
+++ b/.github/workflows/validate-nightly-binaries.yml
@@ -0,0 +1,26 @@
+# Scheduled validation of the nightly binaries
+name: validate-nightly-binaries
+
+on:
+ schedule:
+ # At 5:30 pm UTC (7:30 am PDT)
+ - cron: "30 17 * * *"
+ # Have the ability to trigger this job manually through the API
+ workflow_dispatch:
+ push:
+ branches:
+ - main
+ paths:
+ - .github/workflows/validate-nightly-binaries.yml
+ - .github/workflows/validate-binaries.yml
+      - .github/scripts/validate_binaries.sh
+ pull_request:
+ paths:
+ - .github/workflows/validate-nightly-binaries.yml
+ - .github/workflows/validate-binaries.yml
+      - .github/scripts/validate_binaries.sh
+jobs:
+ nightly:
+ uses: ./.github/workflows/validate-binaries.yml
+ with:
+ channel: nightly
diff --git a/.lintrunner.toml b/.lintrunner.toml
index 23a76193a..8fd1bbc1f 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -11,7 +11,7 @@ init_command = [
'python3',
'tools/lint/pip_init.py',
'--dry-run={{DRYRUN}}',
- 'black==22.3.0',
+ 'black==24.2.0',
]
is_formatter = true
@@ -28,6 +28,6 @@ init_command = [
'python3',
'tools/lint/pip_init.py',
'--dry-run={{DRYRUN}}',
- 'usort==1.0.2',
+ 'usort==1.0.8',
]
is_formatter = true
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ea648c6ef..705694ea9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.0.1
+ rev: v4.5.0
hooks:
- id: check-toml
- id: check-yaml
@@ -8,9 +8,9 @@ repos:
- id: end-of-file-fixer
- repo: https://github.com/omnilib/ufmt
- rev: v1.3.2
+ rev: v2.5.1
hooks:
- id: ufmt
additional_dependencies:
- - black == 22.3.0
- - usort == 1.0.2
+ - black == 24.2.0
+ - usort == 1.0.8.post1
diff --git a/.pyre_configuration b/.pyre_configuration
new file mode 100644
index 000000000..4249f9c02
--- /dev/null
+++ b/.pyre_configuration
@@ -0,0 +1,17 @@
+{
+ "exclude": [
+ ".*/pyre-check/stubs/.*",
+ ".*/torchrec/datasets*",
+ ".*/torchrec/models*",
+ ".*/torchrec/inference/client.py"
+ ],
+ "site_package_search_strategy": "all",
+ "source_directories": [
+ {
+ "import_root": ".",
+ "source": "torchrec"
+ }
+ ],
+ "strict": true,
+ "version": "0.0.101729681899"
+}
diff --git a/README.MD b/README.MD
index 69c40d6ca..44fc026f6 100644
--- a/README.MD
+++ b/README.MD
@@ -1,85 +1,98 @@
-# TorchRec (Beta Release)
-[Docs](https://pytorch.org/torchrec/)
+# TorchRec
-TorchRec is a PyTorch domain library built to provide common sparsity & parallelism primitives needed for large-scale recommender systems (RecSys). It allows authors to train models with large embedding tables sharded across many GPUs.
+**TorchRec** is a PyTorch domain library built to provide common sparsity and parallelism primitives needed for large-scale recommender systems (RecSys). TorchRec allows training and inference of models with large embedding tables sharded across many GPUs and **powers many production RecSys models at Meta**.
-## TorchRec contains:
+## External Presence
+TorchRec has been used to accelerate advances in recommendation systems; some examples:
+* [Latest version of Meta's DLRM (Deep Learning Recommendation Model)](https://github.com/facebookresearch/dlrm) is built using TorchRec
+* [Disaggregated Multi-Tower: Topology-aware Modeling Technique for Efficient Large-Scale Recommendation](https://arxiv.org/abs/2403.00877) paper
+* [The Algorithm ML](https://github.com/twitter/the-algorithm-ml) from Twitter
+* [Training Recommendation Models with Databricks](https://docs.databricks.com/en/machine-learning/train-recommender-models.html)
+* [Toward 100TB model with Embedding Offloading Paper](https://dl.acm.org/doi/10.1145/3640457.3688037)
+
+
+## Introduction
+
+To begin learning about TorchRec, check out:
+* Our complete [TorchRec Tutorial](https://pytorch.org/tutorials/intermediate/torchrec_intro_tutorial.html)
+* The [TorchRec documentation](https://pytorch.org/torchrec/) for an overview of TorchRec and API references
+
+
+### TorchRec Features
- Parallelism primitives that enable easy authoring of large, performant multi-device/multi-node models using hybrid data-parallelism/model-parallelism.
-- The TorchRec sharder can shard embedding tables with different sharding strategies including data-parallel, table-wise, row-wise, table-wise-row-wise, and column-wise sharding.
-- The TorchRec planner can automatically generate optimized sharding plans for models.
-- Pipelined training overlaps dataloading device transfer (copy to GPU), inter-device communications (input_dist), and computation (forward, backward) for increased performance.
-- Optimized kernels for RecSys powered by FBGEMM.
-- Quantization support for reduced precision training and inference.
+- Sharders to shard embedding tables with different strategies including data-parallel, table-wise, row-wise, table-wise-row-wise, column-wise, and table-wise-column-wise sharding.
+- Planner that can automatically generate optimized sharding plans for models.
+- Pipelined training overlapping dataloading device transfer (copy to GPU), inter-device communications (input_dist), and computation (forward, backward) for increased performance.
+- Optimized kernels for RecSys powered by [FBGEMM](https://github.com/pytorch/FBGEMM/tree/main).
+- Quantization support for reduced precision training and inference, along with optimizing a TorchRec model for C++ inference.
- Common modules for RecSys.
-- Production-proven model architectures for RecSys.
- RecSys datasets (criteo click logs and movielens)
- Examples of end-to-end training such as the DLRM event prediction model trained on the Criteo click logs dataset.
-# Installation
-Torchrec requires Python >= 3.7 and CUDA >= 11.0 (CUDA is highly recommended for performance but not required). The example below shows how to install with CUDA 11.6. This setup assumes you have conda installed.
+## Installation
-## Binaries
+Check out the [Getting Started](https://pytorch.org/torchrec/setup-torchrec.html) section in the documentation for recommended ways to set up TorchRec.
-Experimental binary on Linux for Python 3.7, 3.8 and 3.9 can be installed via pip wheels
+### From Source
-### Installations
-```
-TO use the library without cuda, use the *-cpu fbgemm installations. However, this will be much slower than the CUDA variant.
+**Generally, there isn't a need to build from source**. For most use cases, follow the section above to set up TorchRec. However, to build from source and to get the latest changes, do the following:
-Nightly
+1. Install pytorch. See [pytorch documentation](https://pytorch.org/get-started/locally/).
+ ```
+ CUDA 12.4
-conda install pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
-pip install torchrec_nightly
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/cu124
-Stable
+ CUDA 12.1
-conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
-pip install torchrec
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/cu121
-If you have no CUDA device:
+ CUDA 11.8
-Nightly
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/cu118
-pip uninstall fbgemm-gpu-nightly -y
-pip install fbgemm-gpu-nightly-cpu
+ CPU
-Stable
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
+ ```
-pip uninstall fbgemm-gpu -y
-pip install fbgemm-gpu-cpu
+2. Clone TorchRec.
+ ```
+ git clone --recursive https://github.com/pytorch/torchrec
+ cd torchrec
+ ```
-```
+3. Install FBGEMM.
+ ```
+ CUDA 12.4
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu124
-### Colab example: introduction + install
-See our colab notebook for an introduction to torchrec which includes runnable installation.
- - [Tutorial Source](https://github.com/pytorch/torchrec/blob/main/Torchrec_Introduction.ipynb)
- - Open in [Google Colab](https://colab.research.google.com/github/pytorch/torchrec/blob/main/Torchrec_Introduction.ipynb)
+ CUDA 12.1
-## From Source
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu121
-We are currently iterating on the setup experience. For now, we provide manual instructions on how to build from source. The example below shows how to install with CUDA 11.3. This setup assumes you have conda installed.
+ CUDA 11.8
-1. Install pytorch. See [pytorch documentation](https://pytorch.org/get-started/locally/)
- ```
- conda install pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu118
+
+ CPU
+
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu
```
-2. Install Requirements
+4. Install other requirements.
```
pip install -r requirements.txt
```
-3. Download and install TorchRec.
+5. Install TorchRec.
```
- git clone --recursive https://github.com/pytorch/torchrec
-
- cd torchrec
python setup.py install develop
```
-4. Test the installation.
+6. Test the installation (use torchx-nightly for 3.11; for 3.12, torchx currently doesn't work).
```
GPU mode
@@ -93,5 +106,32 @@ We are currently iterating on the setup experience. For now, we provide manual i
7. If you want to run a more complex example, please take a look at the torchrec [DLRM example](https://github.com/facebookresearch/dlrm/blob/main/torchrec_dlrm/dlrm_main.py).
+## Contributing
+
+See [CONTRIBUTING.md](https://github.com/pytorch/torchrec/blob/main/CONTRIBUTING.md) for details about contributing to TorchRec!
+
+## Citation
+
+If you're using TorchRec, please use the following BibTeX entry to cite this work:
+```
+@inproceedings{10.1145/3523227.3547387,
+author = {Ivchenko, Dmytro and Van Der Staay, Dennis and Taylor, Colin and Liu, Xing and Feng, Will and Kindi, Rahul and Sudarshan, Anirudh and Sefati, Shahin},
+title = {TorchRec: a PyTorch Domain Library for Recommendation Systems},
+year = {2022},
+isbn = {9781450392785},
+publisher = {Association for Computing Machinery},
+address = {New York, NY, USA},
+url = {https://doi.org/10.1145/3523227.3547387},
+doi = {10.1145/3523227.3547387},
+abstract = {Recommendation Systems (RecSys) comprise a large footprint of production-deployed AI today. The neural network-based recommender systems differ from deep learning models in other domains in using high-cardinality categorical sparse features that require large embedding tables to be trained. In this talk we introduce TorchRec, a PyTorch domain library for Recommendation Systems. This new library provides common sparsity and parallelism primitives, enabling researchers to build state-of-the-art personalization models and deploy them in production. In this talk we cover the building blocks of the TorchRec library including modeling primitives such as embedding bags and jagged tensors, optimized recommender system kernels powered by FBGEMM, a flexible sharder that supports a variety of strategies for partitioning embedding tables, a planner that automatically generates optimized and performant sharding plans, support for GPU inference and common modeling modules for building recommender system models. TorchRec library is currently used to train large-scale recommender models at Meta. We will present how TorchRec helped Meta’s recommender system platform to transition from CPU asynchronous training to accelerator-based full-sync training.},
+booktitle = {Proceedings of the 16th ACM Conference on Recommender Systems},
+pages = {482–483},
+numpages = {2},
+keywords = {information retrieval, recommender systems},
+location = {Seattle, WA, USA},
+series = {RecSys '22}
+}
+```
+
## License
TorchRec is BSD licensed, as found in the [LICENSE](LICENSE) file.
diff --git a/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb b/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb
new file mode 100644
index 000000000..015d216e4
--- /dev/null
+++ b/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb
@@ -0,0 +1,2972 @@
+{
+ "metadata": {
+ "custom": {
+ "cells": [],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "background_execution": "on",
+ "collapsed_sections": [],
+ "machine_shape": "hm",
+ "name": "Torchrec Introduction.ipynb",
+ "provenance": []
+ },
+ "fileHeader": "",
+ "fileUid": "c9a29462-2509-4adb-a539-0318cf56bb00",
+ "interpreter": {
+ "hash": "d4204deb07d30e7517ec64733b2d65f24aff851b061e21418071854b06459363"
+ },
+ "isAdHoc": false,
+ "kernelspec": {
+ "display_name": "Python 3.7.13 ('torchrec': conda)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ },
+ "indentAmount": 2,
+ "last_server_session_id": "e11f329f-b395-4702-9b33-449716ea422e",
+ "last_kernel_id": "b6fe1a08-1d4d-40cd-afe6-8352c4e42d25",
+ "last_base_url": "https://bento.edge.x2p.facebook.net/",
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "last_msg_id": "c02547e3-e4c072dc430f066c4d18479a_594",
+ "captumWidgetMessage": [],
+ "outputWidgetContext": [],
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4"
+ },
+ "accelerator": "GPU",
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hBgIy9eYYx35",
+ "originalKey": "4766a371-bf6e-4342-98fb-16dde5255d73",
+ "outputsInitialized": false,
+ "language": "markdown",
+ "showInput": false
+ },
+ "source": [
+ "## **Open Source Installation** (For Reference)\n",
+ "Requirements:\n",
+ "- python >= 3.9\n",
+ "\n",
+ "We highly recommend CUDA when using TorchRec. If using CUDA:\n",
+ "- cuda >= 11.8\n",
+ "\n",
+ "Installing TorchRec will also install [FBGEMM](https://github.com/pytorch/fbgemm), a collection of CUDA kernels and GPU enabled operations to run"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "sFYvP95xaAER",
+ "originalKey": "27d22c43-9299-46ec-94f2-28a880546fe3",
+ "outputsInitialized": true,
+ "language": "python",
+ "customOutput": null,
+ "executionStartTime": 1726000131275,
+ "executionStopTime": 1726000131459,
+ "serverExecutionDuration": 2.2683702409267,
+ "requestMsgId": "27d22c43-9299-46ec-94f2-28a880546fe3"
+ },
+ "source": [
+ "# Install stable versions for best reliability\n",
+ "\n",
+ "!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U\n",
+ "!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121\n",
+ "!pip3 install torchmetrics==1.0.3\n",
+ "!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121"
+ ],
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-4DFtQNDYao1",
+ "originalKey": "07e2a5ae-9ca2-45d7-af10-84d8e09ce91e",
+ "outputsInitialized": false,
+ "language": "markdown",
+ "showInput": false
+ },
+ "source": [
+ "# Intro to TorchRec\n",
+ "\n",
+ "### Embeddings\n",
+ "When building recommendation systems, categorical features typically have massive cardinalities, posts, users, ads, etc.\n",
+ "\n",
+ "In order to represent these entities and model these relationships, **embeddings** are used. In machine learning, **embeddings are a vectors of real numbers in a high-dimensional space used to represent meaning in complex data like words, images, or users**.\n",
+ "\n",
+ "\n",
+ "### Embeddings in RecSys\n",
+ "\n",
+ "Now you might wonder, how are these embeddings generated in the first place? Well, embeddings are represented as individual rows in an **Embedding Table**, also referred to as embedding weights. The reason for this is that embeddings/embedding table weights are trained just like all of the other weights of the model via gradient descent!\n",
+ "\n",
+ "Embedding tables are simply a large matrix for storing embeddings, with two dimensions (B, N), where\n",
+ "* B is the number of embeddings stored by the table\n",
+ "* N is the number of dimensions per embedding (N-dimensional embedding).\n",
+ "\n",
+ "\n",
+ "The inputs to embedding tables represent embedding lookups to retrieve the embedding for a specific index/row. In recommendation systems, such as those used in Meta, unique IDs are not only used for specific users, but also across entites like posts and ads to serve as lookup indices to respective embedding tables!\n",
+ "\n",
+ "Embeddings are trained in RecSys through the following process:\n",
+ "1. **Input/lookup indices are fed into the model, as unique IDs**. IDs are hashed to the total size of the embedding table to prevent issues when the ID > # of rows\n",
+ "2. Embeddings are then retrieved and **pooled, such as taking the sum or mean of the embeddings**. This is required as there can be a variable # of embeddings per example while the model expects consistent shapes.\n",
+ "3. The **embeddings are used in conjunction with the rest of the model to produce a prediction**, such as [Click-Through Rate (CTR)](https://support.google.com/google-ads/answer/2615875?hl=en) for an Ad.\n",
+ "4. The loss is calculated with the prediction and the label for an example, and **all weights of the model are updated through gradient descent and backpropogation, including the embedding weights** that were associated with the example.\n",
+ "\n",
+ "These embeddings are crucial for representing categorical features, such as users, posts, and ads, in order to capture relationships and make good recommendations. Meta AI's [Deep learning recommendation model](https://arxiv.org/abs/1906.00091) (DLRM) paper talks more about the technical details of using embedding tables in RecSys.\n",
+ "\n",
+ "This tutorial will introduce the concept of embeddings, showcase TorchRec specific modules/datatypes, and depict how distributed training works with TorchRec."
+ ]
+ },
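+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# A minimal sketch of the four training steps above, using plain torch on CPU.\n",
+ "# The table size, IDs, and the tiny linear \"model\" here are illustrative only.\n",
+ "import torch\n",
+ "\n",
+ "table = torch.nn.EmbeddingBag(10, 4, mode=\"mean\")  # a small embedding table: 10 rows, 4 dims\n",
+ "dense = torch.nn.Linear(4, 1)  # stand-in for the rest of the model\n",
+ "opt = torch.optim.SGD(list(table.parameters()) + list(dense.parameters()), lr=0.1)\n",
+ "\n",
+ "raw_ids = torch.tensor([[123, 45]])  # step 1: raw IDs for one example ...\n",
+ "hashed_ids = raw_ids % table.num_embeddings  # ... hashed into the table size\n",
+ "pooled = table(hashed_ids)  # step 2: lookup + mean pooling\n",
+ "pred = torch.sigmoid(dense(pooled))  # step 3: prediction, e.g. CTR\n",
+ "loss = torch.nn.functional.binary_cross_entropy(pred, torch.ones_like(pred))\n",
+ "loss.backward()  # step 4: gradients also update the embedding weights\n",
+ "opt.step()\n",
+ "print(\"loss:\", loss.item())"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },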
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "48b50971-aeab-4754-8cff-986496689f43",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000131464,
+ "executionStopTime": 1726000133971,
+ "serverExecutionDuration": 2349.9959111214,
+ "requestMsgId": "48b50971-aeab-4754-8cff-986496689f43",
+ "customOutput": null,
+ "outputsInitialized": true,
+ "output": {
+ "id": "1534047040582458"
+ },
+ "id": "AbeT4W9xcso9"
+ },
+ "source": [
+ "import torch"
+ ],
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "4b510f99-840d-4986-b635-33c21af48cf4",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "bjuDdEqocso-"
+ },
+ "source": [
+ "## Embeddings in PyTorch\n",
+ "[`torch.nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html): Embedding table where forward pass returns the embeddings themselves as is.\n",
+ "\n",
+ "[`torch.nn.EmbeddingBag`](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html): Embedding table where forward pass returns embeddings that are then pooled, i.e. sum or mean. Otherwise known as **Pooled Embeddings**\n",
+ "\n",
+ "In this section, we will go over a very brief introduction with doing embedding lookups through passing in indices into the table. Check out the links for each for more sophisticated use cases and experiments!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "06ebfce4-bc22-4f5a-97d7-7a8f5d8ac375",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000133982,
+ "executionStopTime": 1726000134201,
+ "serverExecutionDuration": 31.60185739398,
+ "requestMsgId": "06ebfce4-bc22-4f5a-97d7-7a8f5d8ac375",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "933119035309629"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1X5C_Dnccso-",
+ "outputId": "616cc153-67ee-4dd6-b1ab-ee6ff6f44709"
+ },
+ "source": [
+ "num_embeddings, embedding_dim = 10, 4\n",
+ "\n",
+ "# Initialize our embedding table\n",
+ "weights = torch.rand(num_embeddings, embedding_dim)\n",
+ "print(\"Weights:\", weights)"
+ ],
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Weights: tensor([[0.3446, 0.3614, 0.8938, 0.8157],\n",
+ " [0.1830, 0.0326, 0.8241, 0.2995],\n",
+ " [0.7328, 0.0531, 0.9528, 0.0592],\n",
+ " [0.7800, 0.1797, 0.0167, 0.7401],\n",
+ " [0.4837, 0.2052, 0.3360, 0.9656],\n",
+ " [0.7887, 0.3066, 0.0956, 0.3344],\n",
+ " [0.5904, 0.8541, 0.5963, 0.2800],\n",
+ " [0.5751, 0.4341, 0.6218, 0.4101],\n",
+ " [0.6881, 0.5363, 0.4747, 0.2301],\n",
+ " [0.6088, 0.1060, 0.1100, 0.7290]])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "b2f21375-8d36-487f-b0c3-ff8a5df950a4",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000134203,
+ "executionStopTime": 1726000134366,
+ "serverExecutionDuration": 8.956927806139,
+ "requestMsgId": "b2f21375-8d36-487f-b0c3-ff8a5df950a4",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "831419729143778"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "bxszzeGdcso-",
+ "outputId": "88f21deb-c2f1-4894-975e-fdac41436b36"
+ },
+ "source": [
+ "# Pass in pre generated weights just for example, typically weights are randomly initialized\n",
+ "embedding_collection = torch.nn.Embedding(\n",
+ " num_embeddings, embedding_dim, _weight=weights\n",
+ ")\n",
+ "embedding_bag_collection = torch.nn.EmbeddingBag(\n",
+ " num_embeddings, embedding_dim, _weight=weights\n",
+ ")\n",
+ "\n",
+ "# Print out the tables, we should see the same weights as above\n",
+ "print(\"Embedding Collection Table: \", embedding_collection.weight)\n",
+ "print(\"Embedding Bag Collection Table: \", embedding_bag_collection.weight)\n",
+ "\n",
+ "# Lookup rows (ids for embedding ids) from the embedding tables\n",
+ "# 2D tensor with shape (batch_size, ids for each batch)\n",
+ "ids = torch.tensor([[1, 3]])\n",
+ "print(\"Input row IDS: \", ids)"
+ ],
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Embedding Collection Table: Parameter containing:\n",
+ "tensor([[0.3446, 0.3614, 0.8938, 0.8157],\n",
+ " [0.1830, 0.0326, 0.8241, 0.2995],\n",
+ " [0.7328, 0.0531, 0.9528, 0.0592],\n",
+ " [0.7800, 0.1797, 0.0167, 0.7401],\n",
+ " [0.4837, 0.2052, 0.3360, 0.9656],\n",
+ " [0.7887, 0.3066, 0.0956, 0.3344],\n",
+ " [0.5904, 0.8541, 0.5963, 0.2800],\n",
+ " [0.5751, 0.4341, 0.6218, 0.4101],\n",
+ " [0.6881, 0.5363, 0.4747, 0.2301],\n",
+ " [0.6088, 0.1060, 0.1100, 0.7290]], requires_grad=True)\n",
+ "Embedding Bag Collection Table: Parameter containing:\n",
+ "tensor([[0.3446, 0.3614, 0.8938, 0.8157],\n",
+ " [0.1830, 0.0326, 0.8241, 0.2995],\n",
+ " [0.7328, 0.0531, 0.9528, 0.0592],\n",
+ " [0.7800, 0.1797, 0.0167, 0.7401],\n",
+ " [0.4837, 0.2052, 0.3360, 0.9656],\n",
+ " [0.7887, 0.3066, 0.0956, 0.3344],\n",
+ " [0.5904, 0.8541, 0.5963, 0.2800],\n",
+ " [0.5751, 0.4341, 0.6218, 0.4101],\n",
+ " [0.6881, 0.5363, 0.4747, 0.2301],\n",
+ " [0.6088, 0.1060, 0.1100, 0.7290]], requires_grad=True)\n",
+ "Input row IDS: tensor([[1, 3]])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "cb5c5906-e9a6-4315-b860-b263e08989be",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000134369,
+ "executionStopTime": 1726000134545,
+ "serverExecutionDuration": 5.9817284345627,
+ "requestMsgId": "cb5c5906-e9a6-4315-b860-b263e08989be",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "2201664893536578"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "xkedJeTOcso_",
+ "outputId": "46215f2b-03ad-421b-f873-78b2be0df4d4"
+ },
+ "source": [
+ "embeddings = embedding_collection(ids)\n",
+ "\n",
+ "# Print out the embedding lookups\n",
+ "# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above\n",
+ "print(\"Embedding Collection Results: \")\n",
+ "print(embeddings)\n",
+ "print(\"Shape: \", embeddings.shape)"
+ ],
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Embedding Collection Results: \n",
+ "tensor([[[0.1830, 0.0326, 0.8241, 0.2995],\n",
+ " [0.7800, 0.1797, 0.0167, 0.7401]]], grad_fn=)\n",
+ "Shape: torch.Size([1, 2, 4])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "a8e90b32-7c30-41f2-a5b9-bedf2b196e7f",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000134547,
+ "executionStopTime": 1726000134718,
+ "serverExecutionDuration": 7.8675262629986,
+ "requestMsgId": "a8e90b32-7c30-41f2-a5b9-bedf2b196e7f",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "6449977515126116"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "PmtJkxLccso_",
+ "outputId": "c33109b4-205e-492d-e0ca-94d2887ec6e7"
+ },
+ "source": [
+ "# nn.EmbeddingBag default pooling is mean, so should be mean of batch dimension of values above\n",
+ "pooled_embeddings = embedding_bag_collection(ids)\n",
+ "\n",
+ "print(\"Embedding Bag Collection Results: \")\n",
+ "print(pooled_embeddings)\n",
+ "print(\"Shape: \", pooled_embeddings.shape)\n",
+ "\n",
+ "# nn.EmbeddingBag is the same as nn.Embedding but just with pooling (mean, sum, etc.)\n",
+ "# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection\n",
+ "print(\"Mean: \", torch.mean(embedding_collection(ids), dim=1))"
+ ],
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Embedding Bag Collection Results: \n",
+ "tensor([[0.4815, 0.1062, 0.4204, 0.5198]], grad_fn=)\n",
+ "Shape: torch.Size([1, 4])\n",
+ "Mean: tensor([[0.4815, 0.1062, 0.4204, 0.5198]], grad_fn=)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "4643305e-2770-40cf-afc6-e64cd3f51063",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "SuCV1cJ8cso_"
+ },
+ "source": [
+ "Congratulations! Now you have a basic understanding on how to use embedding tables --- one of the foundations of modern recommendation systems! These tables represent entities and their relationships. For example, the relationship between a given user and the pages & posts they have liked."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "7dfcffeb-c7c0-4d74-9dba-569c1d882898",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "QIuAYSZ5cso_"
+ },
+ "source": [
+ "# TorchRec\n",
+ "\n",
+ "Now you know how to use embedding tables, one of the foundations of modern recommendation systems! These tables represent entities and relationships, such as users, pages, posts, etc. Given that these entities are always increasing, a **hash** function is typically applied to make sure the ids are within the bounds of a certain embedding table. However, in order to represent a vast amount of entities and reduce hash collisions, these tables can become quite massive (think about # of ads for example). In fact, these tables can become so massive that they won't be able to fit on 1 GPU, even with 80G of memory!\n",
+ "\n",
+ "In order to train models with massive embedding tables, sharding these tables across GPUs is required, which then introduces a whole new set of problems/opportunities in parallelism and optimization. Luckily, we have the TorchRec library that has encountered, consolidated, and addressed many of these concerns. TorchRec serves as a **library that provides primitives for large scale distributed embeddings**.\n",
+ "\n",
+ "From here on out, we will explore the major features of the TorchRec library. We will start with torch.nn.Embedding and will extend that to custom TorchRec modules, explore distributed training environment with generating a sharding plan for embeddings, look at inherent TorchRec optimizations, and extend the model to be ready for inference in C++. Below is a quick outline of what the journey will consist of - buckle in!\n",
+ "\n",
+ "1. TorchRec Modules and DataTypes\n",
+ "2. Distributed Training, Sharding, and Optimizations\n",
+ "3. Inference\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "8395ed9c-8336-4686-8e73-cb815b808f2a",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000134724,
+ "executionStopTime": 1726000139238,
+ "serverExecutionDuration": 4317.9145939648,
+ "requestMsgId": "8395ed9c-8336-4686-8e73-cb815b808f2a",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "id": "5vzmNV0IcspA"
+ },
+ "source": [
+ "import torchrec"
+ ],
+ "execution_count": 7,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "0c95b385-e07a-43e1-aaeb-31f66deb5b35",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "42PwMZnNcspA"
+ },
+ "source": [
+ "## TorchRec Modules and Datatypes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZdSUWBRxoP8R",
+ "originalKey": "309c4d38-8f19-46d9-a8bb-7d3d1c166e84",
+ "outputsInitialized": false,
+ "language": "markdown",
+ "showInput": false
+ },
+ "source": [
+ "### From EmbeddingBag to EmbeddingBagCollection\n",
+ "We have already explored [`torch.nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) and [`torch.nn.EmbeddingBag`](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html).\n",
+ "\n",
+ "TorchRec extends these modules by creating collections of embeddings, in other words modules that can have multiple embedding tables, with [`EmbeddingCollection`](https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingCollection) and [`EmbeddingBagCollection`](https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection). We will use `EmbeddingBagCollection` to represent a group of EmbeddingBags.\n",
+ "\n",
+ "Here, we create an EmbeddingBagCollection (EBC) with two embedding bags, 1 representing **products** and 1 representing **users**. Each table, `product_table` and `user_table`, is represented by 64 dimension embedding of size 4096."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Iz_GZDp_oQ19",
+ "originalKey": "219c4ee9-c4f1-43ff-9d1c-b15b16a1dc8e",
+ "outputsInitialized": true,
+ "language": "python",
+ "customOutput": null,
+ "executionStartTime": 1726000139247,
+ "executionStopTime": 1726000139433,
+ "serverExecutionDuration": 13.643965125084,
+ "requestMsgId": "219c4ee9-c4f1-43ff-9d1c-b15b16a1dc8e",
+ "output": {
+ "id": "1615870128957785"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "cec92f4f-d9eb-464e-8e22-26fd910bf8c1"
+ },
+ "source": [
+ "ebc = torchrec.EmbeddingBagCollection(\n",
+ " device=\"cpu\",\n",
+ " tables=[\n",
+ " torchrec.EmbeddingBagConfig(\n",
+ " name=\"product_table\",\n",
+ " embedding_dim=64,\n",
+ " num_embeddings=4096,\n",
+ " feature_names=[\"product\"],\n",
+ " pooling=torchrec.PoolingType.SUM,\n",
+ " ),\n",
+ " torchrec.EmbeddingBagConfig(\n",
+ " name=\"user_table\",\n",
+ " embedding_dim=64,\n",
+ " num_embeddings=4096,\n",
+ " feature_names=[\"user\"],\n",
+ " pooling=torchrec.PoolingType.SUM,\n",
+ " )\n",
+ " ]\n",
+ ")\n",
+ "print(ebc.embedding_bags)"
+ ],
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "ModuleDict(\n",
+ " (product_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ " (user_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ ")\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "c587a298-4d38-4a69-89a2-5d5c4a26cc2c",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "xjcA0Di1cspA"
+ },
+ "source": [
+ "Let’s inspect the forward method for EmbeddingBagcollection and the module’s inputs and outputs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "c9d2717b-b753-4e0b-97bd-1596123d081d",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000139437,
+ "executionStopTime": 1726000139616,
+ "serverExecutionDuration": 6.011176854372,
+ "requestMsgId": "c9d2717b-b753-4e0b-97bd-1596123d081d",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "398959426640405"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "UuIrEWupcspA",
+ "outputId": "300c86c2-82c1-4657-fa1a-6a319eb40177"
+ },
+ "source": [
+ "import inspect\n",
+ "\n",
+ "# Let's look at the EmbeddingBagCollection forward method\n",
+ "# What is a KeyedJaggedTensor and KeyedTensor?\n",
+ "print(inspect.getsource(ebc.forward))"
+ ],
+ "execution_count": 9,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ " def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " features (KeyedJaggedTensor): KJT of form [F X B X L].\n",
+ "\n",
+ " Returns:\n",
+ " KeyedTensor\n",
+ " \"\"\"\n",
+ " if is_non_strict_exporting() and not torch.jit.is_scripting():\n",
+ " return self._non_strict_exporting_forward(features)\n",
+ " flat_feature_names: List[str] = []\n",
+ " for names in self._feature_names:\n",
+ " flat_feature_names.extend(names)\n",
+ " inverse_indices = reorder_inverse_indices(\n",
+ " inverse_indices=features.inverse_indices_or_none(),\n",
+ " feature_names=flat_feature_names,\n",
+ " )\n",
+ " pooled_embeddings: List[torch.Tensor] = []\n",
+ " feature_dict = features.to_dict()\n",
+ " for i, embedding_bag in enumerate(self.embedding_bags.values()):\n",
+ " for feature_name in self._feature_names[i]:\n",
+ " f = feature_dict[feature_name]\n",
+ " res = embedding_bag(\n",
+ " input=f.values(),\n",
+ " offsets=f.offsets(),\n",
+ " per_sample_weights=f.weights() if self._is_weighted else None,\n",
+ " ).float()\n",
+ " pooled_embeddings.append(res)\n",
+ " return KeyedTensor(\n",
+ " keys=self._embedding_names,\n",
+ " values=process_pooled_embeddings(\n",
+ " pooled_embeddings=pooled_embeddings,\n",
+ " inverse_indices=inverse_indices,\n",
+ " ),\n",
+ " length_per_key=self._lengths_per_embedding,\n",
+ " )\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "d6b9bfc2-544d-499f-ad61-d7471b819f8a",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "C_UAtHsMcspA"
+ },
+ "source": [
+ "### TorchRec Input/Output Data Types\n",
+ "TorchRec has distinct data types for input and output of its modules: `JaggedTensor`, `KeyedJaggedTensor`, and `KeyedTensor`. Now you might ask, why create new datatypes to represent sparse features? To answer that question, we must understand how sparse features are represented in code.\n",
+ "\n",
+ "Sparse features are otherwise known as `id_list_feature` and `id_score_list_feature`, and are the **IDs** that will be used as indices to an embedding table to retrieve the embedding for that ID. To give a very simple example, imagine a single sparse feature being Ads that a user interacted with. The input itself would be a set of Ad IDs that a user interacted with, and the embeddings retrieved would be a semantic representation of those Ads. The tricky part of representing these features in code is that in each input example, **the number of IDs is variable**. 1 day a user might have interacted with only 1 ad while the next day they interact with 3.\n",
+ "\n",
+ "A simple representation is shown below, where we have a `lengths` tensor denoting how many indices are in an example for a batch and a `values` tensor containing the indices themselves.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "13225ead-a798-4db2-8de6-1c13a758d676",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000139620,
+ "executionStopTime": 1726000139790,
+ "serverExecutionDuration": 3.692839294672,
+ "requestMsgId": "13225ead-a798-4db2-8de6-1c13a758d676",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "id": "RB77aL08cspA"
+ },
+ "source": [
+ "# Batch Size 2\n",
+ "# 1 ID in example 1, 2 IDs in example 2\n",
+ "id_list_feature_lengths = torch.tensor([1, 2])\n",
+ "\n",
+ "# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2\n",
+ "id_list_feature_values = torch.tensor([5, 7, 1])"
+ ],
+ "execution_count": 10,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "65d31fca-7b7f-4c0f-9ca2-56e07243a5c0",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "aKmgGqdNcspA"
+ },
+ "source": [
+ "Let’s look at the offsets as well as what is contained in each Batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "9510cebd-1875-461e-9243-53928632abfa",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000139794,
+ "executionStopTime": 1726000139966,
+ "serverExecutionDuration": 6.6289491951466,
+ "requestMsgId": "9510cebd-1875-461e-9243-53928632abfa",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "869913611744322"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "t5T5S8_mcspB",
+ "outputId": "87e78b11-7497-4387-c0bc-b4d277ba8ab3"
+ },
+ "source": [
+ "# Lengths can be converted to offsets for easy indexing of values\n",
+ "id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)\n",
+ "\n",
+ "print(\"Offsets: \", id_list_feature_offsets)\n",
+ "print(\"First Batch: \", id_list_feature_values[: id_list_feature_offsets[0]])\n",
+ "print(\n",
+ " \"Second Batch: \",\n",
+ " id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],\n",
+ ")"
+ ],
+ "execution_count": 11,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Offsets: tensor([1, 3])\n",
+ "First Batch: tensor([5])\n",
+ "Second Batch: tensor([7, 1])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "4bc3fac5-16b9-4f63-b841-9b26ee0ccfc0",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000139968,
+ "executionStopTime": 1726000140161,
+ "serverExecutionDuration": 7.3191449046135,
+ "requestMsgId": "4bc3fac5-16b9-4f63-b841-9b26ee0ccfc0",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1254783359215069"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "2OOK2BBecspB",
+ "outputId": "b27c1547-fbb7-47c4-efb6-48aff3300d1a"
+ },
+ "source": [
+ "from torchrec import JaggedTensor\n",
+ "\n",
+ "# JaggedTensor is just a wrapper around lengths/offsets and values tensors!\n",
+ "jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)\n",
+ "\n",
+ "# Automatically compute offsets from lengths\n",
+ "print(\"Offsets: \", jt.offsets())\n",
+ "\n",
+ "# Convert to list of values\n",
+ "print(\"List of Values: \", jt.to_dense())\n",
+ "\n",
+ "# __str__ representation\n",
+ "print(jt)"
+ ],
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Offsets: tensor([0, 1, 3])\n",
+ "List of Values: [tensor([5]), tensor([7, 1])]\n",
+ "JaggedTensor({\n",
+ " [[5], [7, 1]]\n",
+ "})\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "ad069058-2329-4ab9-bee8-60775ead4c33",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000140165,
+ "executionStopTime": 1726000140355,
+ "serverExecutionDuration": 10.361641645432,
+ "requestMsgId": "ad069058-2329-4ab9-bee8-60775ead4c33",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "530006499497328"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "fs10Fxu2cspB",
+ "outputId": "3bc754c2-ac30-4c01-b0c8-7f98eedb9c52"
+ },
+ "source": [
+ "from torchrec import KeyedJaggedTensor\n",
+ "\n",
+ "# JaggedTensor represents IDs for 1 feature, but we have multiple features in an EmbeddingBagCollection\n",
+ "# That's where KeyedJaggedTensor comes in! KeyedJaggedTensor is just multiple JaggedTensors for multiple id_list_feature_offsets\n",
+ "# From before, we have our two features \"product\" and \"user\". Let's create JaggedTensors for both!\n",
+ "\n",
+ "product_jt = JaggedTensor(\n",
+ " values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])\n",
+ ")\n",
+ "user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))\n",
+ "\n",
+ "# Q1: How many batches are there, and which values are in the first batch for product_jt and user_jt?\n",
+ "kjt = KeyedJaggedTensor.from_jt_dict({\"product\": product_jt, \"user\": user_jt})\n",
+ "\n",
+ "# Look at our feature keys for the KeyedJaggedTensor\n",
+ "print(\"Keys: \", kjt.keys())\n",
+ "\n",
+ "# Look at the overall lengths for the KeyedJaggedTensor\n",
+ "print(\"Lengths: \", kjt.lengths())\n",
+ "\n",
+ "# Look at all values for KeyedJaggedTensor\n",
+ "print(\"Values: \", kjt.values())\n",
+ "\n",
+ "# Can convert KJT to dictionary representation\n",
+ "print(\"to_dict: \", kjt.to_dict())\n",
+ "\n",
+ "# KeyedJaggedTensor(KJT) string representation\n",
+ "print(kjt)\n",
+ "\n",
+ "# Q2: What are the offsets for the KeyedJaggedTensor?"
+ ],
+ "execution_count": 13,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Keys: ['product', 'user']\n",
+ "Lengths: tensor([3, 1, 2, 2])\n",
+ "Values: tensor([1, 2, 1, 5, 2, 3, 4, 1])\n",
+ "to_dict: {'product': , 'user': }\n",
+ "KeyedJaggedTensor({\n",
+ " \"product\": [[1, 2, 1], [5]],\n",
+ " \"user\": [[2, 3], [4, 1]]\n",
+ "})\n",
+ "\n"
+ ]
+ }
+ ]
+ },
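+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# For illustration only (a minimal sketch, not needed for the rest of the tutorial):\n",
+ "# the same KJT can also be constructed directly from flat keys/values/lengths,\n",
+ "# which is simply the concatenation of the two JaggedTensors above.\n",
+ "kjt_direct = KeyedJaggedTensor(\n",
+ "    keys=[\"product\", \"user\"],\n",
+ "    values=torch.tensor([1, 2, 1, 5, 2, 3, 4, 1]),\n",
+ "    lengths=torch.tensor([3, 1, 2, 2]),\n",
+ ")\n",
+ "print(kjt_direct)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },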
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "b13fdf10-45a7-4e57-b50e-cc18547a715b",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000140357,
+ "executionStopTime": 1726000140549,
+ "serverExecutionDuration": 17.695877701044,
+ "requestMsgId": "b13fdf10-45a7-4e57-b50e-cc18547a715b",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "496557126663787"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "JeLwyCNRcspB",
+ "outputId": "0723906d-0aba-4d48-e9a5-7ac618d711c5"
+ },
+ "source": [
+ "# Now we can run a forward pass on our ebc from before\n",
+ "result = ebc(kjt)\n",
+ "result"
+ ],
+ "execution_count": 14,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "execution_count": 14
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "57a01464-de39-4bfb-8355-83cd97e519c0",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000140552,
+ "executionStopTime": 1726000140732,
+ "serverExecutionDuration": 6.0368701815605,
+ "requestMsgId": "57a01464-de39-4bfb-8355-83cd97e519c0",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1457290878317732"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "R2K4v2vqcspB",
+ "outputId": "c507e47e-32bc-4440-8c8a-2b3ea334467c"
+ },
+ "source": [
+ "# Result is a KeyedTensor, which contains a list of the feature names and the embedding results\n",
+ "print(result.keys())\n",
+ "\n",
+ "# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined\n",
+ "# 128 for dimension of embedding. If you look at where we initialized the EmbeddingBagCollection, we have two tables \"product\" and \"user\" of dimension 64 each\n",
+ "# meaning emebddings for both features are of size 64. 64 + 64 = 128\n",
+ "print(result.values().shape)\n",
+ "\n",
+ "# Nice to_dict method to determine the embeddings that belong to each feature\n",
+ "result_dict = result.to_dict()\n",
+ "for key, embedding in result_dict.items():\n",
+ " print(key, embedding.shape)"
+ ],
+ "execution_count": 15,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "['product', 'user']\n",
+ "torch.Size([2, 128])\n",
+ "product torch.Size([2, 64])\n",
+ "user torch.Size([2, 64])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "d0fc8635-dac3-444b-978b-421b5d77b70c",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "EE-YYDv7cspB"
+ },
+ "source": [
+ "Congrats! Give yourself a pat on the back for making it this far."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "70816a78-7671-411c-814f-d2c98c3a912c",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "djLHn0CIcspB"
+ },
+ "source": [
+ "## Distributed Training and Sharding\n",
+ "Now that we have a grasp on TorchRec modules and data types, it's time to take it to the next level.\n",
+ "\n",
+ "Remember, TorchRec's main purpose is to provide primitives for distributed embeddings. So far, we've only worked with embedding tables on 1 device. This has been possible given how small the embedding tables have been, but in a production setting this isn't generally the case. Embedding tables often get massive, where 1 table can't fit on a single GPU, creating the requirement for multiple devices and a distributed environment\n",
+ "\n",
+ "In this section, we will explore setting up a distributed environment, exactly how actual production training is done, and explore sharding embedding tables, all with Torchrec.\n",
+ "\n",
+ "**This section will also only use 1 gpu, though it will be treated in a distributed fashion. This is only a limitation for training, as training has a process per gpu. Inference does not run into this requirement**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "4-v17rxkopQw",
+ "originalKey": "df0d09f0-5e8e-46bf-a086-dd991c8be0b4",
+ "outputsInitialized": true,
+ "language": "python",
+ "customOutput": null,
+ "executionStartTime": 1726000140740,
+ "executionStopTime": 1726000142256,
+ "serverExecutionDuration": 1350.0418178737,
+ "requestMsgId": "df0d09f0-5e8e-46bf-a086-dd991c8be0b4",
+ "output": {
+ "id": "1195358511578142"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "efa5689f-b794-41a6-8c06-86856e1e698a"
+ },
+ "source": [
+ "# Here we set up our torch distributed environment\n",
+ "# WARNING: You can only call this cell once, calling it again will cause an error\n",
+ "# as you can only initialize the process group once\n",
+ "\n",
+ "import os\n",
+ "\n",
+ "import torch.distributed as dist\n",
+ "\n",
+ "# Set up environment variables for distributed training\n",
+ "# RANK is which GPU we are on, default 0\n",
+ "os.environ[\"RANK\"] = \"0\"\n",
+ "# How many devices in our \"world\", since Bento can only handle 1 process, 1 GPU\n",
+ "os.environ[\"WORLD_SIZE\"] = \"1\"\n",
+ "# Localhost as we are training locally\n",
+ "os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
+ "# Port for distributed training\n",
+ "os.environ[\"MASTER_PORT\"] = \"29500\"\n",
+ "\n",
+ "# Note - you will need a V100 or A100 to run tutorial as!\n",
+ "# nccl backend is for GPUs, gloo is for CPUs\n",
+ "dist.init_process_group(backend=\"gloo\")\n",
+ "\n",
+ "print(f\"Distributed environment initialized: {dist}\")"
+ ],
+ "execution_count": 16,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Distributed environment initialized: \n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "480e69dc-3e9d-4e86-b73c-950e18afb0f5",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "hQOjNci3cspB"
+ },
+ "source": [
+ "### Distributed Embeddings\n",
+ "\n",
+ "We have already worked with the main TorchRec module: `EmbeddingBagCollection`. We have examined how it works along with how data is represented in TorchRec. However, we have not yet explored one of the main parts of TorchRec, which is **distributed embeddings**.\n",
+ "\n",
+ "GPUs are the most popular choice for ML workloads by far today, as they are able to do magnitudes more floating point operations/s ([FLOPs](https://en.wikipedia.org/wiki/FLOPS)) than CPU. However, GPUs come with the limitation of scarce fast memory (HBM which is analgous to RAM for CPU), typically ~10s of GBs.\n",
+ "\n",
+ "A RecSys model can contain embedding tables that far exceed the memory limit for 1 GPU, hence the need for distribution of the embedding tables across multiple GPUs, otherwise known as **model parallel**. On the other hand, **data parallel** is where the entire model is replicated on each GPU, which each GPU taking in a distinct batch of data for training, syncing gradients on the backwards pass.\n",
+ "\n",
+ "Parts of the model that **require less compute but more memory (embeddings) are distributed with model parallel** while parts that **require more compute and less memory (dense layers, MLP, etc.) are distributed with data parallel**.\n",
+ "\n",
+ "\n",
+ "### Sharding\n",
+ "In order to distribute an embedding table, we split up the embedding table into parts and place those parts onto different devices, also known as “sharding”.\n",
+ "\n",
+ "There are many ways to shard embedding tables. The most common ways are:\n",
+ "* Table-Wise: the table is placed entirely onto one device\n",
+ "* Column-Wise: columns of embedding tables are sharded\n",
+ "* Row-Wise: rows of embedding tables are sharded\n",
+ "\n",
+ "\n",
+ "### Sharded Modules\n",
+ "While all of this seems like a lot to deal with and implement, you're in luck. **TorchRec provides all the primitives for easy distributed training/inference**! In fact, TorchRec modules have two corresponding classes for working with any TorchRec module in a distributed environment:\n",
+ "1. The module sharder: This class exposes a `shard` API that handles sharding a TorchRec Module, producing a sharded module.\n",
+ " * For `EmbeddingBagCollection`, the sharder is [`EmbeddingBagCollectionSharder`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder)\n",
+ "2. Sharded module: This class is a sharded variant of a TorchRec module. It has the same input/output as a the regular TorchRec module, but much more optimized and works in a distributed environment.\n",
+ " * For `EmbeddingBagCollection`, the sharded variant is [`ShardedEmbeddingBagCollection`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection)\n",
+ "\n",
+ "Every TorchRec module has an unsharded and sharded variant.\n",
+ "* The unsharded version is meant to be prototyped and experimented with\n",
+ "* The sharded version is meant to be used in a distributed environment for distributed training/inference.\n",
+ "\n",
+ "The sharded versions of TorchRec modules, for example EmbeddingBagCollection, will handle everything that is needed for Model Parallelism, such as communication between GPUs for distributing embeddings to the correct GPUs.\n"
+ ]
+ },
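+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# A back-of-the-envelope sketch (illustrative only, assuming a hypothetical world_size of 2)\n",
+ "# of what each sharding scheme does to one of our (4096, 64) tables:\n",
+ "num_rows, dim, world_size = 4096, 64, 2\n",
+ "print(\"table-wise :\", [(num_rows, dim)], \"-> the whole table lives on one rank\")\n",
+ "print(\"column-wise:\", [(num_rows, dim // world_size)] * world_size, \"-> split along the embedding dim\")\n",
+ "print(\"row-wise   :\", [(num_rows // world_size, dim)] * world_size, \"-> split along the rows\")"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },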
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "eb2a064d-0b67-4cba-a199-c99573c7e6cd",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000142261,
+ "executionStopTime": 1726000142430,
+ "serverExecutionDuration": 8.3460621535778,
+ "requestMsgId": "eb2a064d-0b67-4cba-a199-c99573c7e6cd",
+ "customOutput": null,
+ "outputsInitialized": true,
+ "output": {
+ "id": "791089056311464"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "FX65VcQ6cspB",
+ "outputId": "1aa3bc52-569b-46fb-8a94-cd1873e987ca"
+ },
+ "source": [
+ "# Refresher of our EmbeddingBagCollection module\n",
+ "ebc"
+ ],
+ "execution_count": 17,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "EmbeddingBagCollection(\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ " (user_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 17
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "1442636d-8617-4785-b40c-8544374253b6",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "outputsInitialized": true,
+ "executionStartTime": 1726000142433,
+ "executionStopTime": 1726000142681,
+ "serverExecutionDuration": 4.4135116040707,
+ "requestMsgId": "1442636d-8617-4785-b40c-8544374253b6",
+ "customOutput": null,
+ "output": {
+ "id": "502189589096046"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1hSzTg4pcspC",
+ "outputId": "d7d86592-4fdc-4d0b-f2ba-a40a80af1fcf"
+ },
+ "source": [
+ "from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder\n",
+ "from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology\n",
+ "from torchrec.distributed.types import ShardingEnv\n",
+ "\n",
+ "# Corresponding sharder for EmbeddingBagCollection module\n",
+ "sharder = EmbeddingBagCollectionSharder()\n",
+ "\n",
+ "# ProcessGroup from torch.distributed initialized 2 cells above\n",
+ "pg = dist.GroupMember.WORLD\n",
+ "assert pg is not None, \"Process group is not initialized\"\n",
+ "\n",
+ "print(f\"Process Group: {pg}\")"
+ ],
+ "execution_count": 18,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Process Group: \n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "29cc17eb-9e2f-480b-aed2-60b15024fbf7",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "qU7A980qcspC"
+ },
+ "source": [
+ "### Planner\n",
+ "\n",
+ "Before we can show how sharding works, we must know about the **planner**, which helps us determine the best sharding configuration.\n",
+ "\n",
+ "Given a number of embedding tables and a number of ranks, there are many different sharding configurations that are possible. For example, given 2 embedding tables and 2 GPUs, you can:\n",
+ "* Place 1 table on each GPU\n",
+ "* Place both tables on a single GPU and no tables on the other\n",
+ "* Place certain rows/columns on each GPU\n",
+ "\n",
+ "Given all of these possibilities, we typically want a sharding configuration that is optimal for performance.\n",
+ "\n",
+ "That is where the planner comes in. The planner is able to determine given the # of embedding tables and the # of GPUs, what is the optimal configuration. Turns out, this is incredibly difficult to do manually, with tons of factors that engineers have to consider to ensure an optimal sharding plan. Luckily, TorchRec provides an auto planner when the planner is used. The TorchRec planner:\n",
+ "* assesses memory constraints of hardware,\n",
+ "* estimates compute based on memory fetches as embedding lookups,\n",
+ "* addresses data specific factors,\n",
+ "* considers other hardware specifics like bandwidth to generate an optimal sharding plan.\n",
+ "\n",
+ "In order to take into consideration all these variables, The TorchRec planner can take in [various amounts of data for embedding tables, constraints, hardware information, and topology](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/planner/planners.py#L147-L155) to aid in generating the optimal sharding plan for a model, which is routinely provided across stacks\n",
+ "\n",
+ "\n",
+ "To learn more about sharding, see our [sharding tutorial](https://pytorch.org/tutorials/advanced/sharding.html)."
+ ]
+ },
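+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# A minimal sketch of passing per-table constraints to the planner (for illustration only;\n",
+ "# the next cell runs the planner without any constraints).\n",
+ "from torchrec.distributed.planner.types import ParameterConstraints\n",
+ "from torchrec.distributed.types import ShardingType\n",
+ "\n",
+ "constraints = {\n",
+ "    \"product_table\": ParameterConstraints(sharding_types=[ShardingType.TABLE_WISE.value]),\n",
+ "    \"user_table\": ParameterConstraints(sharding_types=[ShardingType.TABLE_WISE.value]),\n",
+ "}\n",
+ "# Passing constraints=constraints to EmbeddingShardingPlanner would restrict the plan\n",
+ "# to table-wise sharding for these two tables.\n",
+ "print(constraints)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },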
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "64936203-2e59-4bc3-8d76-1b652b7891c2",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000142687,
+ "executionStopTime": 1726000143033,
+ "serverExecutionDuration": 145.92137932777,
+ "requestMsgId": "64936203-2e59-4bc3-8d76-1b652b7891c2",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1247084956198777"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "PQeXnuAGcspC",
+ "outputId": "e6d366b7-f064-4e38-a477-fc2e0d4b9736"
+ },
+ "source": [
+ "# In our case, 1 GPU and compute on CUDA device\n",
+ "planner = EmbeddingShardingPlanner(\n",
+ " topology=Topology(\n",
+ " world_size=1,\n",
+ " compute_device=\"cuda\",\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "# Run planner to get plan for sharding\n",
+ "plan = planner.collective_plan(ebc, [sharder], pg)\n",
+ "\n",
+ "print(f\"Sharding Plan generated: {plan}\")"
+ ],
+ "execution_count": 19,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Sharding Plan generated: module: \n",
+ "\n",
+ " param | sharding type | compute kernel | ranks\n",
+ "------------- | ------------- | -------------- | -----\n",
+ "product_table | table_wise | fused | [0] \n",
+ "user_table | table_wise | fused | [0] \n",
+ "\n",
+ " param | shard offsets | shard sizes | placement \n",
+ "------------- | ------------- | ----------- | -------------\n",
+ "product_table | [0, 0] | [4096, 64] | rank:0/cuda:0\n",
+ "user_table | [0, 0] | [4096, 64] | rank:0/cuda:0\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "bbbbbf60-5691-4357-9943-4d7f8b2b1d5c",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "2TTLj_0PcspC"
+ },
+ "source": [
+ "### Planner Result\n",
+ "As you can see, when running the planner there is quite a bit of output above. We can see a ton of stats being calculated along with where our tables end up being placed.\n",
+ "\n",
+ "The result of running the planner is a static plan, which can be reused for sharding! This allows sharding to be static for production models instead of determining a new sharding plan everytime. Below, we use the sharding plan to finally generate our `ShardedEmbeddingBagCollection.`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "533be12d-a3c5-4c9e-9351-7770251c8fa5",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000143037,
+ "executionStopTime": 1726000143259,
+ "serverExecutionDuration": 5.2368640899658,
+ "requestMsgId": "533be12d-a3c5-4c9e-9351-7770251c8fa5",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "901470115170971"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "JIci5Gz6cspC",
+ "outputId": "32d3cb73-80b6-4646-9c1d-3cff5a498f86"
+ },
+ "source": [
+ "# The static plan that was generated\n",
+ "plan"
+ ],
+ "execution_count": 20,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "ShardingPlan(plan={'': {'product_table': ParameterSharding(sharding_type='table_wise', compute_kernel='fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)]), cache_params=None, enforce_hbm=None, stochastic_rounding=None, bounds_check_mode=None, output_dtype=None), 'user_table': ParameterSharding(sharding_type='table_wise', compute_kernel='fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)]), cache_params=None, enforce_hbm=None, stochastic_rounding=None, bounds_check_mode=None, output_dtype=None)}})"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 20
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "5dcbfda0-0abb-4a51-ba8f-a6a4023f0e2f",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000143262,
+ "executionStopTime": 1726000147680,
+ "serverExecutionDuration": 4229.5375689864,
+ "requestMsgId": "5dcbfda0-0abb-4a51-ba8f-a6a4023f0e2f",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1231077634880712"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "2__Do2tqcspC",
+ "outputId": "7c19830a-9a11-4fbe-ddf3-8c3cb8a6c3b3"
+ },
+ "source": [
+ "env = ShardingEnv.from_process_group(pg)\n",
+ "\n",
+ "# Shard the EmbeddingBagCollection module using the EmbeddingBagCollectionSharder\n",
+ "sharded_ebc = sharder.shard(ebc, plan.plan[\"\"], env, torch.device(\"cuda\"))\n",
+ "\n",
+ "print(f\"Sharded EBC Module: {sharded_ebc}\")"
+ ],
+ "execution_count": 21,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Sharded EBC Module: ShardedEmbeddingBagCollection(\n",
+ " (lookups): \n",
+ " GroupedPooledEmbeddingsLookup(\n",
+ " (_emb_modules): ModuleList(\n",
+ " (0): BatchedFusedEmbeddingBag(\n",
+ " (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (_output_dists): \n",
+ " TwPooledEmbeddingDist()\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): Module()\n",
+ " (user_table): Module()\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [],
+ "metadata": {
+ "id": "ErXXbYzJmVzI"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "3ba44a6d-a6f7-4da2-83a6-e8ac974c64ac",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "QBLpkKYIcspC"
+ },
+ "source": [
+ "#### Awaitable\n",
+ "Remember that TorchRec is a highly optimized library for distributed embeddings. A concept that TorchRec introduces to enable higher performance for training on GPU is a [`LazyAwaitable`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.types.LazyAwaitable). You will see `LazyAwaitable` types as outputs of various sharded TorchRec modules. All a `LazyAwaitable` does is delay calculating some result as long as possible, and it does it by acting like an async type."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "e450dc00-bd30-4bc2-8c71-4c01979b0948",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000147687,
+ "executionStopTime": 1726000147874,
+ "serverExecutionDuration": 9.098757058382,
+ "requestMsgId": "e450dc00-bd30-4bc2-8c71-4c01979b0948",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1236006950908310"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "rwYzKwyNcspC",
+ "outputId": "a8979bb2-6fac-4db6-b997-c31962014e24"
+ },
+ "source": [
+ "from typing import List\n",
+ "\n",
+ "from torchrec.distributed.types import LazyAwaitable\n",
+ "\n",
+ "\n",
+ "# Demonstrate a LazyAwaitable type\n",
+ "class ExampleAwaitable(LazyAwaitable[torch.Tensor]):\n",
+ " def __init__(self, size: List[int]) -> None:\n",
+ " super().__init__()\n",
+ " self._size = size\n",
+ "\n",
+ " def _wait_impl(self) -> torch.Tensor:\n",
+ " return torch.ones(self._size)\n",
+ "\n",
+ "\n",
+ "awaitable = ExampleAwaitable([3, 2])\n",
+ "awaitable.wait()"
+ ],
+ "execution_count": 22,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "tensor([[1., 1.],\n",
+ " [1., 1.],\n",
+ " [1., 1.]])"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 22
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "c958c791-a62c-423a-9a95-1e6ae4e8fbd9",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000147878,
+ "executionStopTime": 1726000154861,
+ "serverExecutionDuration": 6806.3651248813,
+ "requestMsgId": "c958c791-a62c-423a-9a95-1e6ae4e8fbd9",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1255627342282843"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "cs41RfzGcspC",
+ "outputId": "ba7cfdbe-59c4-48c2-a767-d6f5f5cbb915"
+ },
+ "source": [
+ "kjt = kjt.to(\"cuda\")\n",
+ "output = sharded_ebc(kjt)\n",
+ "# The output of our sharded EmbeddingBagCollection module is a an Awaitable?\n",
+ "print(output)"
+ ],
+ "execution_count": 23,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "6f2957f2-2e7e-47e4-9237-f0b6c8b0da94",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000154865,
+ "executionStopTime": 1726000155069,
+ "serverExecutionDuration": 6.0432851314545,
+ "requestMsgId": "6f2957f2-2e7e-47e4-9237-f0b6c8b0da94",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1057638405967561"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "_1sdt75rcspG",
+ "outputId": "365cf6cb-adae-4b41-f68d-1a7acf5fe332"
+ },
+ "source": [
+ "kt = output.wait()\n",
+ "# Now we have out KeyedTensor after calling .wait()\n",
+ "# If you are confused as to why we have a KeyedTensor output,\n",
+ "# give yourself a refresher on the unsharded EmbeddingBagCollection module\n",
+ "print(type(kt))\n",
+ "\n",
+ "print(kt.keys())\n",
+ "\n",
+ "print(kt.values().shape)\n",
+ "\n",
+ "# Same output format as unsharded EmbeddingBagCollection\n",
+ "result_dict = kt.to_dict()\n",
+ "for key, embedding in result_dict.items():\n",
+ " print(key, embedding.shape)"
+ ],
+ "execution_count": 24,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n",
+ "['product', 'user']\n",
+ "torch.Size([2, 128])\n",
+ "product torch.Size([2, 64])\n",
+ "user torch.Size([2, 64])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "4c464de0-20ef-4ef2-89e2-5d58ca224660",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "bEgB987CcspG"
+ },
+ "source": [
+ "### Anatomy of Sharded TorchRec modules\n",
+ "\n",
+ "We have now successfully sharded an EmbeddingBagCollection given a sharding plan that we generated! The sharded module has common APIs from TorchRec which abstract away distributed communication/compute amongst multiple GPUs. In fact, these APIs are highly optimized for performance in training and inference. **Below are the three common APIs for distributed training/inference** that are provided by TorchRec:\n",
+ "\n",
+ "1. **input_dist**: Handles distributing inputs from GPU to GPU\n",
+ "\n",
+ "2. **lookups**: Does the actual embedding lookup in an optimized, batched manner using FBGEMM TBE (more on this later)\n",
+ "\n",
+ "3. **output_dist**: Handles distributing outputs from GPU to GPU\n",
+ "\n",
+ "The distribution of inputs/outputs is done through [NCCL Collectives](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/overview.html), namely [All-to-Alls](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#all-to-all), which is where all GPUs send/receive data to and from one another. TorchRec interfaces with PyTorch distributed for collectives and provides clean abstractions to the end users, removing the concern for the lower level details.\n",
+ "\n",
+ "\n",
+ "The backwards pass does all of these collectives but in the reverse order for distribution of gradients. input_dist, lookup, and output_dist all depend on the sharding scheme. Since we sharded in a table-wise fashion, these APIs are modules that are constructed by [TwPooledEmbeddingSharding](https://pytorch.org/torchrec/torchrec.distributed.sharding.html#torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "03e6e163-af3a-4443-a5a8-3f877fc401d2",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000155075,
+ "executionStopTime": 1726000155253,
+ "serverExecutionDuration": 5.8192722499371,
+ "requestMsgId": "03e6e163-af3a-4443-a5a8-3f877fc401d2",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1042737524520351"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "O2ptES89cspG",
+ "outputId": "2b801648-6501-4463-d743-4887da340974"
+ },
+ "source": [
+ "sharded_ebc"
+ ],
+ "execution_count": 25,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "ShardedEmbeddingBagCollection(\n",
+ " (lookups): \n",
+ " GroupedPooledEmbeddingsLookup(\n",
+ " (_emb_modules): ModuleList(\n",
+ " (0): BatchedFusedEmbeddingBag(\n",
+ " (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (_input_dists): \n",
+ " TwSparseFeaturesDist(\n",
+ " (_dist): KJTAllToAll()\n",
+ " )\n",
+ " (_output_dists): \n",
+ " TwPooledEmbeddingDist(\n",
+ " (_dist): PooledEmbeddingsAllToAll()\n",
+ " )\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): Module()\n",
+ " (user_table): Module()\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 25
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "c2a34340-d5fd-4dc8-9b7e-3a761a0c5f82",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000155256,
+ "executionStopTime": 1726000155442,
+ "serverExecutionDuration": 5.3565315902233,
+ "requestMsgId": "c2a34340-d5fd-4dc8-9b7e-3a761a0c5f82",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1063399165221115"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "PHjJt3BQcspG",
+ "outputId": "aaedadd8-da43-4225-d5f8-a7a43fd0250a"
+ },
+ "source": [
+ "# Distribute input KJTs to all other GPUs and receive KJTs\n",
+ "sharded_ebc._input_dists"
+ ],
+ "execution_count": 26,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[TwSparseFeaturesDist(\n",
+ " (_dist): KJTAllToAll()\n",
+ " )]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 26
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "88abe892-1ed1-4806-84ad-35f43247a772",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000155445,
+ "executionStopTime": 1726000155695,
+ "serverExecutionDuration": 5.3521953523159,
+ "requestMsgId": "88abe892-1ed1-4806-84ad-35f43247a772",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1513800839249249"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "jrEXMc7TcspG",
+ "outputId": "81ab40a5-135b-494c-f2bc-91be16a338cc"
+ },
+ "source": [
+ "# Distribute output embeddings to all other GPUs and receive embeddings\n",
+ "sharded_ebc._output_dists"
+ ],
+ "execution_count": 27,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[TwPooledEmbeddingDist(\n",
+ " (_dist): PooledEmbeddingsAllToAll()\n",
+ " )]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 27
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "2eaf16f1-ac14-4f7a-b443-e707ff85c3f0",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "C2jfo5ilcspH"
+ },
+ "source": [
+ "### Optimizing Embedding Lookups\n",
+ "\n",
+ "In performing lookups for a collection of embedding tables, a trivial solution would be to iterate through all the `nn.EmbeddingBags` and do a lookup per table. This is exactly what TorchRec's standard, unsharded `EmbeddingBagCollection` does. However, while this solution is simple, it is extremely slow.\n",
+ "\n",
+ "[FBGEMM](https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu) is a library that provides highly optimized GPU operators (otherwise known as kernels). One of these operators, known as **Table Batched Embedding** (TBE), provides two major optimizations:\n",
+ "\n",
+ "* Table batching, which allows you to look up multiple embeddings with one kernel call.\n",
+ "* Optimizer Fusion, which allows the module to update itself given the canonical pytorch optimizers and arguments.\n",
+ "\n",
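+ "As a rough illustration of table batching (a minimal sketch using only the required constructor arguments; see the FBGEMM documentation for the full signature, and note that the sharded module constructs its TBE for you), a single TBE module can own several tables at once:\n",
+ "\n",
+ "```python\n",
+ "# Hedged sketch: one FBGEMM TBE holding both of our tables, so a single\n",
+ "# kernel call serves lookups for 'product_table' and 'user_table'.\n",
+ "from fbgemm_gpu.split_table_batched_embeddings_ops_training import (\n",
+ "    ComputeDevice,\n",
+ "    EmbeddingLocation,\n",
+ "    SplitTableBatchedEmbeddingBagsCodegen,\n",
+ ")\n",
+ "\n",
+ "tbe = SplitTableBatchedEmbeddingBagsCodegen(\n",
+ "    embedding_specs=[\n",
+ "        # (num_embeddings, embedding_dim, location, compute device) per table\n",
+ "        (4096, 64, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),  # product_table\n",
+ "        (4096, 64, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),  # user_table\n",
+ "    ],\n",
+ ")\n",
+ "```\n",
+ "\n",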
+ "The `ShardedEmbeddingBagCollection` uses the FBGEMM TBE as the lookup instead of traditional `nn.EmbeddingBags` for optimized embedding lookups."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "801c50b9-e1a2-465a-9fa3-3cd87d676ed4",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000155699,
+ "executionStopTime": 1726000155879,
+ "serverExecutionDuration": 5.0756596028805,
+ "requestMsgId": "801c50b9-e1a2-465a-9fa3-3cd87d676ed4",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "911093750838903"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1GoHWI6OcspH",
+ "outputId": "cd67815b-00bd-4a30-89cf-7b5d9c7051e9"
+ },
+ "source": [
+ "sharded_ebc._lookups"
+ ],
+ "execution_count": 28,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[GroupedPooledEmbeddingsLookup(\n",
+ " (_emb_modules): ModuleList(\n",
+ " (0): BatchedFusedEmbeddingBag(\n",
+ " (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n",
+ " )\n",
+ " )\n",
+ " )]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 28
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "f2b31d78-81a9-426f-b017-ca8404383939",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "1zcbZX1lcspH"
+ },
+ "source": [
+ "### DistributedModelParallel\n",
+ "\n",
+ "We have now explored sharding a single EmbeddingBagCollection! We were able to take the `EmbeddingBagCollectionSharder` and use the unsharded `EmbeddingBagCollection` to generate a `ShardedEmbeddingBagCollection` module. This workflow is fine, but typically when doing model parallel, [`DistributedModelParallel`](https://pytorch.org/torchrec/model-parallel-api-reference.html#model-parallel) (DMP) is used as the standard interface. When wrapping your model (in our case `ebc`) with DMP, the following will occur:\n",
+ "\n",
+ "1. Decide how to shard the model. DMP will collect the available ‘sharders’ and come up with a ‘plan’ of the optimal way to shard the embedding table(s) (i.e., the `EmbeddingBagCollection`).\n",
+ "2. Actually shard the model. This includes allocating memory for each embedding table on the appropriate device(s).\n",
+ "\n",
+ "DMP takes in everything that we've just experimented with, like a static sharding plan, a list of sharders, etc. However, it also has some nice defaults to seamlessly shard a TorchRec model. In this toy example, since we have two EmbeddingTables and one GPU, TorchRec will place both on the single GPU.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "e0e198e1-db2a-46b0-91f0-51a5ff80abbb",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000155883,
+ "executionStopTime": 1726000156073,
+ "serverExecutionDuration": 7.8761726617813,
+ "requestMsgId": "e0e198e1-db2a-46b0-91f0-51a5ff80abbb",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1207953610328397"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ypVUDwpzcspH",
+ "outputId": "26a7a957-c231-459a-dfc8-f0c1cd6f697e"
+ },
+ "source": [
+ "ebc"
+ ],
+ "execution_count": 29,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "EmbeddingBagCollection(\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ " (user_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 29
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "73fec38d-947a-49d5-a2ba-61e3828b7117",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000156075,
+ "executionStopTime": 1726000156438,
+ "serverExecutionDuration": 165.43522849679,
+ "requestMsgId": "73fec38d-947a-49d5-a2ba-61e3828b7117",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1838328716783594"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "5EdlyWAycspH",
+ "outputId": "05e90aa2-cb83-4ddf-9da5-6aa31d6da278"
+ },
+ "source": [
+ "model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device(\"cuda\"))"
+ ],
+ "execution_count": 30,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "f8d87a4e-6a7a-4a02-92f9-9baa794266af",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000156441,
+ "executionStopTime": 1726000156665,
+ "serverExecutionDuration": 6.8417005240917,
+ "requestMsgId": "f8d87a4e-6a7a-4a02-92f9-9baa794266af",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1059040285804352"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "b5NgRErjcspH",
+ "outputId": "8f4de40c-3a3b-43e5-d645-814bf03dab0b"
+ },
+ "source": [
+ "out = model(kjt)\n",
+ "out.wait()"
+ ],
+ "execution_count": 31,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "execution_count": 31
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "e7e02648-dee7-4b3a-8953-47e8b8771c3b",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000156669,
+ "executionStopTime": 1726000156885,
+ "serverExecutionDuration": 5.4804161190987,
+ "requestMsgId": "e7e02648-dee7-4b3a-8953-47e8b8771c3b",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "3346626825643095"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "VJrSysjjcspH",
+ "outputId": "2920a3d0-dd96-43ea-ab0d-627de00d1e42"
+ },
+ "source": [
+ "model"
+ ],
+ "execution_count": 32,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "DistributedModelParallel(\n",
+ " (_dmp_wrapped_module): ShardedEmbeddingBagCollection(\n",
+ " (lookups): \n",
+ " GroupedPooledEmbeddingsLookup(\n",
+ " (_emb_modules): ModuleList(\n",
+ " (0): BatchedFusedEmbeddingBag(\n",
+ " (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (_input_dists): \n",
+ " TwSparseFeaturesDist(\n",
+ " (_dist): KJTAllToAll()\n",
+ " )\n",
+ " (_output_dists): \n",
+ " TwPooledEmbeddingDist(\n",
+ " (_dist): PooledEmbeddingsAllToAll()\n",
+ " )\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): Module()\n",
+ " (user_table): Module()\n",
+ " )\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 32
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "4b6171d5-ae60-4cc8-a47a-f01236c02e6c",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "BLM673eTcspH"
+ },
+ "source": [
+ "### Sharding Best Practices\n",
+ "\n",
+ "Currently, our configuration is only sharding on 1 GPU (or rank), which is trivial: just place all the tables on a single GPU's memory. However, in real production use cases, embedding tables are **typically sharded on hundreds of GPUs**, with different sharding methods such as table-wise, row-wise, and column-wise. It is incredibly important to determine a proper sharding configuration (to prevent out-of-memory issues) while keeping it balanced not only in terms of memory but also compute for optimal performance. A sketch of how such preferences can be expressed to the planner follows below.\n",
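+ "\n",
+ "One way to steer the planner is with per-table `ParameterConstraints`. This is a minimal sketch, not a recipe: the `world_size=8` topology is illustrative, and `ebc`/`sharder` are the modules from earlier in this tutorial.\n",
+ "\n",
+ "```python\n",
+ "from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology\n",
+ "from torchrec.distributed.planner.types import ParameterConstraints\n",
+ "\n",
+ "planner = EmbeddingShardingPlanner(\n",
+ "    topology=Topology(world_size=8, compute_device=\"cuda\"),\n",
+ "    constraints={\n",
+ "        # e.g. force this table to be sharded row-wise across the ranks\n",
+ "        \"product_table\": ParameterConstraints(sharding_types=[\"row_wise\"]),\n",
+ "    },\n",
+ ")\n",
+ "plan = planner.plan(ebc, [sharder])\n",
+ "```"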
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Adding in the Optimizer\n",
+ "\n",
+ "Remember that TorchRec modules are hyperoptimized for large-scale distributed training. An important optimization concerns the optimizer. **TorchRec modules provide a seamless API to fuse the backwards pass and optimizer step in training, providing a significant optimization in performance and decreasing the memory used, alongside granularity in assigning distinct optimizers to distinct model parameters.**\n",
+ "\n",
+ "#### Optimizer Classes\n",
+ "\n",
+ "TorchRec uses `CombinedOptimizer`, which contains a collection of `KeyedOptimizers`. A `CombinedOptimizer` makes it easy to handle multiple optimizers for various subgroups in the model. A `KeyedOptimizer` extends `torch.optim.Optimizer` and is initialized with a dictionary of parameters, which it exposes. Each `TBE` module in an `EmbeddingBagCollection` will have its own `KeyedOptimizer`, and these are combined into one `CombinedOptimizer`.\n",
+ "\n",
+ "#### Fused optimizer in TorchRec\n",
+ "\n",
+ "Using `DistributedModelParallel`, the **optimizer is fused, which means that the optimizer update is done in the backward pass**. This is an optimization in TorchRec and FBGEMM, where the embedding gradients are not materialized and the optimizer update is instead applied directly to the parameters. This brings significant memory savings, as embedding gradients are typically the size of the parameters themselves.\n",
+ "\n",
+ "You can, however, choose to make the optimizer `dense`, which does not apply this optimization and lets you inspect the embedding gradients or apply computations to them as you wish. A dense optimizer in this case would be your [canonical PyTorch model training loop with an optimizer](https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html).\n",
+ "\n",
+ "Once the optimizer is created through `DistributedModelParallel`, you still need to manage an optimizer for the other parameters not associated with TorchRec embedding modules. To find the other parameters, use `in_backward_optimizer_filter(model.named_parameters())`.\n",
+ "\n",
+ "Apply an optimizer to those parameters as you would a normal Torch optimizer and combine this and the `model.fused_optimizer` into one `CombinedOptimizer` that you can use in your training loop to `zero_grad` and `step` through.\n",
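+ "\n",
+ "A minimal sketch of that combination (assuming `model` is the DMP-wrapped model from above and that it also owns some dense, non-TorchRec parameters; with no dense parameters the filter returns an empty dict and the dense optimizer is unnecessary):\n",
+ "\n",
+ "```python\n",
+ "# Hedged sketch: combine the fused embedding optimizer with a dense optimizer\n",
+ "from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper\n",
+ "from torchrec.optim.optimizers import in_backward_optimizer_filter\n",
+ "\n",
+ "dense_optimizer = KeyedOptimizerWrapper(\n",
+ "    dict(in_backward_optimizer_filter(model.named_parameters())),\n",
+ "    lambda params: torch.optim.Adam(params, lr=1e-3),\n",
+ ")\n",
+ "optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])\n",
+ "\n",
+ "# optimizer.zero_grad() / optimizer.step() now cover both parameter groups\n",
+ "```\n",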
+ "\n",
+ "#### Let's add an optimizer to our EmbeddingBagCollection\n",
+ "We will do this in two ways, which are equivalent, but give you options depending on your preferences:\n",
+ "1. Passing optimizer kwargs through fused parameters (`fused_params`) in the sharder\n",
+ "2. Through `apply_optimizer_in_backward`\n",
+ "\n",
+ "Note: `apply_optimizer_in_backward` converts the optimizer parameters to `fused_params` to pass to the `TBE` in the `EmbeddingBagCollection`/`EmbeddingCollection`."
+ ],
+ "metadata": {
+ "id": "zFhggkUCmd7f"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Approach 1: passing optimizer kwargs through fused parameters\n",
+ "from torchrec.optim.optimizers import in_backward_optimizer_filter\n",
+ "from fbgemm_gpu.split_embedding_configs import EmbOptimType\n",
+ "\n",
+ "\n",
+ "# The fused params we will initialize the sharder with\n",
+ "fused_params = {\n",
+ " \"optimizer\": EmbOptimType.EXACT_ROWWISE_ADAGRAD,\n",
+ " \"learning_rate\": 0.02,\n",
+ " \"eps\": 0.002,\n",
+ "}\n",
+ "\n",
+ "# Init sharder with fused_params\n",
+ "sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)\n",
+ "\n",
+ "# We'll use the same plan and unsharded EBC as before, but this time with our new sharder\n",
+ "sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[\"\"], env, torch.device(\"cuda\"))\n",
+ "\n",
+ "# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.\n",
+ "# We can also look at the TBE logs of the cell (if shown) to see that our new optimizer is indeed being applied\n",
+ "print(f\"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}\")\n",
+ "print(f\"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}\")\n",
+ "\n",
+ "print(f\"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "h5BCEFidmnEw",
+ "outputId": "202c64f7-ae95-4b0d-9f53-16138a680d7d"
+ },
+ "execution_count": 33,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Original Sharded EBC fused optimizer: : EmbeddingFusedOptimizer (\n",
+ "Parameter Group 0\n",
+ " lr: 0.01\n",
+ ")\n",
+ "Sharded EBC with fused parameters fused optimizer: : EmbeddingFusedOptimizer (\n",
+ "Parameter Group 0\n",
+ " lr: 0.02\n",
+ ")\n",
+ "Type of optimizer: \n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward\n",
+ "import copy\n",
+ "# Approach 2: applying optimizer through apply_optimizer_in_backward\n",
+ "# Note: we need to call apply_optimizer_in_backward on the unsharded model first and then shard it\n",
+ "\n",
+ "# We can achieve the same result as we did with the previous approach\n",
+ "ebc_apply_opt = copy.deepcopy(ebc)\n",
+ "optimizer_kwargs = {\"lr\": 0.5}\n",
+ "\n",
+ "for name, param in ebc_apply_opt.named_parameters():\n",
+ " print(f\"{name=}\")\n",
+ " apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)\n",
+ "\n",
+ "sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[\"\"], env, torch.device(\"cuda\"))\n",
+ "\n",
+ "# Now when we print the optimizer, we will see our new learning rate; you can also verify momentum through the TBE logs if they are printed\n",
+ "print(sharded_ebc_apply_opt.fused_optimizer)\n",
+ "print(type(sharded_ebc_apply_opt.fused_optimizer))"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "T-xx724MmoKv",
+ "outputId": "0f58fb18-f423-4c84-ee57-d37bdba28eb8"
+ },
+ "execution_count": 34,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "name='embedding_bags.product_table.weight'\n",
+ "name='embedding_bags.user_table.weight'\n",
+ ": EmbeddingFusedOptimizer (\n",
+ "Parameter Group 0\n",
+ " lr: 0.5\n",
+ ")\n",
+ "\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ ":1: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.\n",
+ " from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# We can also use the filter to check for other parameters that aren't associated with the \"fused\" optimizer(s)\n",
+ "# Practically, these are just the non-TorchRec module parameters. Since our module is just a TorchRec EBC,\n",
+ "# there are no other parameters that aren't associated with TorchRec\n",
+ "print(\"Non Fused Model Parameters:\")\n",
+ "print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "UEyhlmbmlwsW",
+ "outputId": "f6219673-d14d-444e-a451-98f33ddeb54d"
+ },
+ "execution_count": 35,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Non Fused Model Parameters:\n",
+ "dict_keys([])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Here we do a dummy backwards call and see that parameter updates for fused\n",
+ "# optimizers happen as a result of the backward pass\n",
+ "\n",
+ "ebc_output = sharded_ebc_fused_params(kjt).wait().values()\n",
+ "loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)\n",
+ "print(f\"First Iteration Loss: {loss}\")\n",
+ "\n",
+ "loss.backward()\n",
+ "\n",
+ "ebc_output = sharded_ebc_fused_params(kjt).wait().values()\n",
+ "loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)\n",
+ "# We don't call an optimizer.step(), so for the loss to have changed here,\n",
+ "# that means the parameters were updated during the backward pass, which is what the\n",
+ "# fused optimizer automatically handles for us\n",
+ "print(f\"Second Iteration Loss: {loss}\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Bga-zM2OfnMW",
+ "outputId": "6c5d45f2-c479-4932-b39d-5ff8abe27d3c"
+ },
+ "execution_count": 36,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "First Iteration Loss: 255.94378662109375\n",
+ "Second Iteration Loss: 245.72166442871094\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "e3bdc895-54c4-4fc6-9175-28dd75021c6a",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "Xc-RUDwDcspH"
+ },
+ "source": [
+ "## Inference\n",
+ "\n",
+ "Now that we are able to train distributed embeddings, how can we take the trained model and optimize it for inference? Inference is typically very sensitive to **performance and size of the model**. Running just the trained model in a Python environment is incredibly inefficient. There are two key differences between inference and training environments:\n",
+ "* **Quantization**: Inference models are typically quantized, where model parameters lose precision for lower latency in predictions and reduced model size. For example, each embedding weight may go from FP32 (4 bytes) in the trained model to INT8 (1 byte). This is also necessary given the vast scale of embedding tables, as we want to use as few devices as possible for inference to minimize latency.\n",
+ "* **C++ environment**: Inference latency is a big deal, so in order to ensure ample performance, the model is typically run in a C++ environment (and in situations where there is no Python runtime, such as on-device)\n",
+ "\n",
+ "TorchRec provides primitives for making a TorchRec model inference-ready, with:\n",
+ "* APIs for quantizing the model, introducing optimizations automatically with FBGEMM TBE\n",
+ "* sharding embeddings for distributed inference\n",
+ "* compiling the model to [TorchScript](https://pytorch.org/docs/stable/jit.html) (compatible with C++)\n",
+ "\n",
+ "In this section, we will go over this entire workflow of:\n",
+ "* Quantizing the model\n",
+ "* Sharding the quantized model\n",
+ "* Compiling the sharded quantized model into TorchScript"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "aae8ef10-f7a4-421a-b71c-f177ff74e96a",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "outputsInitialized": true,
+ "executionStartTime": 1726000156892,
+ "executionStopTime": 1726000157069,
+ "serverExecutionDuration": 7.4504055082798,
+ "requestMsgId": "aae8ef10-f7a4-421a-b71c-f177ff74e96a",
+ "customOutput": null,
+ "output": {
+ "id": "456742254014129"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "8JypsUNmcspH",
+ "outputId": "0a745234-d316-4850-d84a-f08b0f045595"
+ },
+ "source": [
+ "ebc"
+ ],
+ "execution_count": 37,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "EmbeddingBagCollection(\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ " (user_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 37
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "30694976-da54-48d6-922e-ca53f22c385f",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000157071,
+ "executionStopTime": 1726000157317,
+ "serverExecutionDuration": 2.9501467943192,
+ "requestMsgId": "30694976-da54-48d6-922e-ca53f22c385f",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "id": "t2plfyrWcspH"
+ },
+ "source": [
+ "class InferenceModule(torch.nn.Module):\n",
+ " def __init__(self, ebc: torchrec.EmbeddingBagCollection):\n",
+ " super().__init__()\n",
+ " self.ebc_ = ebc\n",
+ "\n",
+ " def forward(self, kjt: KeyedJaggedTensor):\n",
+ " return self.ebc_(kjt)"
+ ],
+ "execution_count": 38,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "2a4a83f1-449d-493e-8f24-7c1975ecad9d",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000157320,
+ "executionStopTime": 1726000157494,
+ "serverExecutionDuration": 3.8229525089264,
+ "requestMsgId": "2a4a83f1-449d-493e-8f24-7c1975ecad9d",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1619365005294308"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "5FRioGEmcspH",
+ "outputId": "e4d9efd3-2427-4602-bb28-2f30c4f3f985"
+ },
+ "source": [
+ "module = InferenceModule(ebc)\n",
+ "for name, param in module.named_parameters():\n",
+ " # Here, the parameters should still be FP32, as we are using a standard EBC\n",
+ " # FP32 is the default, regularly used for training\n",
+ " print(name, param.shape, param.dtype)"
+ ],
+ "execution_count": 39,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "ebc_.embedding_bags.product_table.weight torch.Size([4096, 64]) torch.float32\n",
+ "ebc_.embedding_bags.user_table.weight torch.Size([4096, 64]) torch.float32\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "665352e2-208f-4951-8601-282d036b0e4e",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "OSTy4SU8cspH"
+ },
+ "source": [
+ "### Quantization\n",
+ "\n",
+ "As you can see above, the normal EBC contains embedding table weights in FP32 precision (32 bits for each weight). Here, we will use the TorchRec inference library to quantize the embedding weights of the model to INT8."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "796919b4-f9dd-4d14-a40e-f20668c8257b",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000157499,
+ "executionStopTime": 1726000157696,
+ "serverExecutionDuration": 14.22468572855,
+ "requestMsgId": "796919b4-f9dd-4d14-a40e-f20668c8257b",
+ "customOutput": null,
+ "outputsInitialized": true,
+ "output": {
+ "id": "560049189691202"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "oV-KPRqDcspH",
+ "outputId": "e91220a1-a26c-4f2a-e7c8-e91a3d56d8dc"
+ },
+ "source": [
+ "from torch import quantization as quant\n",
+ "from torchrec.modules.embedding_configs import QuantConfig\n",
+ "from torchrec.quant.embedding_modules import (\n",
+ " EmbeddingBagCollection as QuantEmbeddingBagCollection,\n",
+ ")\n",
+ "\n",
+ "\n",
+ "quant_dtype = torch.int8\n",
+ "\n",
+ "\n",
+ "qconfig = QuantConfig(\n",
+ " # dtype of the result of the embedding lookup, post activation\n",
+ " # torch.float generally for compatibility with the rest of the model\n",
+ " # as the rest of the model here usually isn't quantized\n",
+ " activation=quant.PlaceholderObserver.with_args(dtype=torch.float),\n",
+ " # quantized type for embedding weights, aka parameters to actually quantize\n",
+ " weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),\n",
+ ")\n",
+ "qconfig_spec = {\n",
+ " # Map of module type to qconfig\n",
+ " torchrec.EmbeddingBagCollection: qconfig,\n",
+ "}\n",
+ "mapping = {\n",
+ " # Map of module type to quantized module type\n",
+ " torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,\n",
+ "}\n",
+ "\n",
+ "\n",
+ "module = InferenceModule(ebc)\n",
+ "\n",
+ "# Quantize the module\n",
+ "qebc = quant.quantize_dynamic(\n",
+ " module,\n",
+ " qconfig_spec=qconfig_spec,\n",
+ " mapping=mapping,\n",
+ " inplace=False,\n",
+ ")\n",
+ "\n",
+ "\n",
+ "print(f\"Quantized EBC: {qebc}\")"
+ ],
+ "execution_count": 40,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Quantized EBC: InferenceModule(\n",
+ " (ebc_): QuantizedEmbeddingBagCollection(\n",
+ " (_kjt_to_jt_dict): ComputeKJTToJTDict()\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): Module()\n",
+ " (user_table): Module()\n",
+ " )\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "c1fdd88b-73af-47a8-8aec-4f9422051ee7",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000157700,
+ "executionStopTime": 1726000157862,
+ "serverExecutionDuration": 4.0535479784012,
+ "requestMsgId": "c1fdd88b-73af-47a8-8aec-4f9422051ee7",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "id": "fAztesVacspI"
+ },
+ "source": [
+ "kjt = kjt.to(\"cpu\")"
+ ],
+ "execution_count": 41,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "f5f911e8-ab78-4fd7-b4a1-7a545b5bd24b",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000157865,
+ "executionStopTime": 1726000158060,
+ "serverExecutionDuration": 9.1104581952095,
+ "requestMsgId": "f5f911e8-ab78-4fd7-b4a1-7a545b5bd24b",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "434299789062153"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Wnpwa0TmcspI",
+ "outputId": "88007466-88ce-4f6b-b7e8-22d042e5378b"
+ },
+ "source": [
+ "qebc(kjt)"
+ ],
+ "execution_count": 42,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "execution_count": 42
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "99559efa-baaa-4de1-91d3-7899f87fe659",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000158063,
+ "executionStopTime": 1726000158228,
+ "serverExecutionDuration": 3.4465603530407,
+ "requestMsgId": "99559efa-baaa-4de1-91d3-7899f87fe659",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "499581679596627"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "UUs5fXNncspI",
+ "outputId": "79d814ed-7772-4bc3-ba84-f5f0d0a45e36"
+ },
+ "source": [
+ "# Once quantized, the weights go from parameters -> buffers, as they are no longer trainable\n",
+ "for name, buffer in qebc.named_buffers():\n",
+ " # The shapes of the tables should be the same but the dtype should be int8 now\n",
+ " # post quantization\n",
+ " print(name, buffer.shape, buffer.dtype)"
+ ],
+ "execution_count": 43,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "ebc_.embedding_bags.product_table.weight torch.Size([4096, 80]) torch.uint8\n",
+ "ebc_.embedding_bags.user_table.weight torch.Size([4096, 80]) torch.uint8\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "2b1a9c89-b921-4a35-9f64-0c63b09a2579",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "fdM7UihocspI"
+ },
+ "source": [
+ "### Shard\n",
+ "\n",
+ "Here we shard the quantized TorchRec model. This ensures we are using the performant FBGEMM TBE module for the embedding lookups. We use one device to be consistent with training (1 TBE)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "19c18bbb-6376-468a-a6dc-8346d30ceb48",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000158234,
+ "executionStopTime": 1726000158552,
+ "serverExecutionDuration": 108.51271077991,
+ "requestMsgId": "19c18bbb-6376-468a-a6dc-8346d30ceb48",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "882684747065056"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "mha4FntncspI",
+ "outputId": "40a5557c-531f-4726-f06a-5f73546a8fe0"
+ },
+ "source": [
+ "from torchrec import distributed as trec_dist\n",
+ "from torchrec.distributed.shard import _shard_modules\n",
+ "\n",
+ "\n",
+ "sharded_qebc = _shard_modules(\n",
+ " module=qebc,\n",
+ " device=torch.device(\"cpu\"),\n",
+ " env=trec_dist.ShardingEnv.from_local(\n",
+ " 1,\n",
+ " 0,\n",
+ " ),\n",
+ ")\n",
+ "\n",
+ "\n",
+ "print(f\"Sharded Quantized EBC: {sharded_qebc}\")"
+ ],
+ "execution_count": 44,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Sharded Quantized EBC: InferenceModule(\n",
+ " (ebc_): ShardedQuantEmbeddingBagCollection(\n",
+ " (lookups): \n",
+ " InferGroupedPooledEmbeddingsLookup()\n",
+ " (_output_dists): ModuleList()\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): Module()\n",
+ " (user_table): Module()\n",
+ " )\n",
+ " (_input_dist_module): ShardedQuantEbcInputDist()\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "f00ae63f-0ac4-49c0-93fe-32d7fac76693",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000158555,
+ "executionStopTime": 1726000159111,
+ "serverExecutionDuration": 345.11629864573,
+ "requestMsgId": "f00ae63f-0ac4-49c0-93fe-32d7fac76693",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "876807203893705"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "0iBD90t3cspI",
+ "outputId": "86a17997-aa50-4427-cf2b-55f4c0aef456"
+ },
+ "source": [
+ "sharded_qebc(kjt)"
+ ],
+ "execution_count": 45,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "execution_count": 45
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "897037bb-9d81-4a33-aea1-de1691217d41",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "08ue1zeVcspI"
+ },
+ "source": [
+ "### Compilation\n",
+ "Now we have the optimized eager TorchRec inference model. The next step is to ensure that this model is loadable in C++, as currently it is only runnable in a Python runtime.\n",
+ "\n",
+ "The recommended method of compilation at Meta is twofold: [torch.fx tracing](https://pytorch.org/docs/stable/fx.html) (generating an intermediate representation of the model) and converting the result to TorchScript, which is C++ compatible."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "bdab6e95-3a71-4c3d-b188-115873f1f5d5",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000159118,
+ "executionStopTime": 1726000159308,
+ "serverExecutionDuration": 28.788283467293,
+ "requestMsgId": "bdab6e95-3a71-4c3d-b188-115873f1f5d5",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "491668137118498"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "SRzo1jljcspI",
+ "outputId": "e0f94cf0-c5cf-4480-b0d9-d5dd2abb459b"
+ },
+ "source": [
+ "from torchrec.fx import Tracer\n",
+ "\n",
+ "\n",
+ "tracer = Tracer(leaf_modules=[\"IntNBitTableBatchedEmbeddingBagsCodegen\"])\n",
+ "\n",
+ "graph = tracer.trace(sharded_qebc)\n",
+ "gm = torch.fx.GraphModule(sharded_qebc, graph)\n",
+ "\n",
+ "print(\"Graph Module Created!\")"
+ ],
+ "execution_count": 46,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Graph Module Created!\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "909178d6-4dae-45da-9c39-6827019f53a3",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000159312,
+ "executionStopTime": 1726000159490,
+ "serverExecutionDuration": 2.2248737514019,
+ "requestMsgId": "909178d6-4dae-45da-9c39-6827019f53a3",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1555501808508272"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "NsgzbUdHcspI",
+ "outputId": "c5b67630-19c7-46df-c0ab-216d24309603"
+ },
+ "source": [
+ "print(gm.code)"
+ ],
+ "execution_count": 47,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n",
+ "torch.fx._symbolic_trace.wrap(\"torchrec_distributed_quant_embeddingbag_flatten_feature_lengths\")\n",
+ "torch.fx._symbolic_trace.wrap(\"torchrec_fx_utils__fx_marker\")\n",
+ "torch.fx._symbolic_trace.wrap(\"torchrec_distributed_quant_embedding_kernel__unwrap_kjt\")\n",
+ "torch.fx._symbolic_trace.wrap(\"torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference\")\n",
+ "\n",
+ "def forward(self, kjt : torchrec_sparse_jagged_tensor_KeyedJaggedTensor):\n",
+ " flatten_feature_lengths = torchrec_distributed_quant_embeddingbag_flatten_feature_lengths(kjt); kjt = None\n",
+ " _fx_marker = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_BEGIN', flatten_feature_lengths)\n",
+ " split = flatten_feature_lengths.split([2])\n",
+ " getitem = split[0]; split = None\n",
+ " to = getitem.to(device(type='cuda', index=0), non_blocking = True); getitem = None\n",
+ " _fx_marker_1 = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_END', flatten_feature_lengths); flatten_feature_lengths = None\n",
+ " _unwrap_kjt = torchrec_distributed_quant_embedding_kernel__unwrap_kjt(to); to = None\n",
+ " getitem_1 = _unwrap_kjt[0]\n",
+ " getitem_2 = _unwrap_kjt[1]\n",
+ " getitem_3 = _unwrap_kjt[2]; _unwrap_kjt = None\n",
+ " _tensor_constant0 = self._tensor_constant0\n",
+ " _tensor_constant1 = self._tensor_constant1\n",
+ " bounds_check_indices = torch.ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_1, getitem_2, 1, _tensor_constant1, None); _tensor_constant0 = _tensor_constant1 = None\n",
+ " _tensor_constant2 = self._tensor_constant2\n",
+ " _tensor_constant3 = self._tensor_constant3\n",
+ " _tensor_constant4 = self._tensor_constant4\n",
+ " _tensor_constant5 = self._tensor_constant5\n",
+ " _tensor_constant6 = self._tensor_constant6\n",
+ " _tensor_constant7 = self._tensor_constant7\n",
+ " _tensor_constant8 = self._tensor_constant8\n",
+ " _tensor_constant9 = self._tensor_constant9\n",
+ " int_nbit_split_embedding_codegen_lookup_function = torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(dev_weights = _tensor_constant2, uvm_weights = _tensor_constant3, weights_placements = _tensor_constant4, weights_offsets = _tensor_constant5, weights_tys = _tensor_constant6, D_offsets = _tensor_constant7, total_D = 128, max_int2_D = 0, max_int4_D = 0, max_int8_D = 64, max_float16_D = 0, max_float32_D = 0, indices = getitem_1, offsets = getitem_2, pooling_mode = 0, indice_weights = None, output_dtype = 0, lxu_cache_weights = _tensor_constant8, lxu_cache_locations = _tensor_constant9, row_alignment = 16, max_float8_D = 0, fp8_exponent_bits = -1, fp8_exponent_bias = -1); _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 = _tensor_constant6 = _tensor_constant7 = getitem_1 = getitem_2 = _tensor_constant8 = _tensor_constant9 = None\n",
+ " embeddings_cat_empty_rank_handle_inference = torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference([int_nbit_split_embedding_codegen_lookup_function], dim = 1, device = 'cuda:0', dtype = torch.float32); int_nbit_split_embedding_codegen_lookup_function = None\n",
+ " to_1 = embeddings_cat_empty_rank_handle_inference.to(device(type='cpu')); embeddings_cat_empty_rank_handle_inference = None\n",
+ " keyed_tensor = torchrec_sparse_jagged_tensor_KeyedTensor(keys = ['product', 'user'], length_per_key = [64, 64], values = to_1, key_dim = 1); to_1 = None\n",
+ " return keyed_tensor\n",
+ " \n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "ec77b6ea-f5b1-4c08-9cb9-93faf6a57532",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000159494,
+ "executionStopTime": 1726000160206,
+ "serverExecutionDuration": 540.64276814461,
+ "requestMsgId": "ec77b6ea-f5b1-4c08-9cb9-93faf6a57532",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "978016470760577"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "CjjJLc6pcspI",
+ "outputId": "be3a9486-e4b5-43f6-aed0-711e827a0040"
+ },
+ "source": [
+ "scripted_gm = torch.jit.script(gm)\n",
+ "print(\"Scripted Graph Module Created!\")"
+ ],
+ "execution_count": 48,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Scripted Graph Module Created!\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "9eb089f1-2771-419d-a48e-3b7330c0a1e4",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000160212,
+ "executionStopTime": 1726000160395,
+ "serverExecutionDuration": 2.8529539704323,
+ "requestMsgId": "9eb089f1-2771-419d-a48e-3b7330c0a1e4",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1020643789855657"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "BWKPRaI3cspI",
+ "outputId": "273181a2-7c91-4167-e814-4a07b51c6b10"
+ },
+ "source": [
+ "print(scripted_gm.code)"
+ ],
+ "execution_count": 49,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "def forward(self,\n",
+ " kjt: __torch__.torchrec.sparse.jagged_tensor.KeyedJaggedTensor) -> __torch__.torchrec.sparse.jagged_tensor.KeyedTensor:\n",
+ " _0 = __torch__.torchrec.distributed.quant_embeddingbag.flatten_feature_lengths\n",
+ " _1 = __torch__.torchrec.fx.utils._fx_marker\n",
+ " _2 = __torch__.torchrec.distributed.quant_embedding_kernel._unwrap_kjt\n",
+ " _3 = __torch__.torchrec.distributed.embedding_lookup.embeddings_cat_empty_rank_handle_inference\n",
+ " flatten_feature_lengths = _0(kjt, )\n",
+ " _fx_marker = _1(\"KJT_ONE_TO_ALL_FORWARD_BEGIN\", flatten_feature_lengths, )\n",
+ " split = (flatten_feature_lengths).split([2], )\n",
+ " getitem = split[0]\n",
+ " to = (getitem).to(torch.device(\"cuda\", 0), True, None, )\n",
+ " _fx_marker_1 = _1(\"KJT_ONE_TO_ALL_FORWARD_END\", flatten_feature_lengths, )\n",
+ " _unwrap_kjt = _2(to, )\n",
+ " getitem_1 = (_unwrap_kjt)[0]\n",
+ " getitem_2 = (_unwrap_kjt)[1]\n",
+ " _tensor_constant0 = self._tensor_constant0\n",
+ " _tensor_constant1 = self._tensor_constant1\n",
+ " ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_1, getitem_2, 1, _tensor_constant1)\n",
+ " _tensor_constant2 = self._tensor_constant2\n",
+ " _tensor_constant3 = self._tensor_constant3\n",
+ " _tensor_constant4 = self._tensor_constant4\n",
+ " _tensor_constant5 = self._tensor_constant5\n",
+ " _tensor_constant6 = self._tensor_constant6\n",
+ " _tensor_constant7 = self._tensor_constant7\n",
+ " _tensor_constant8 = self._tensor_constant8\n",
+ " _tensor_constant9 = self._tensor_constant9\n",
+ " int_nbit_split_embedding_codegen_lookup_function = ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(_tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5, _tensor_constant6, _tensor_constant7, 128, 0, 0, 64, 0, 0, getitem_1, getitem_2, 0, None, 0, _tensor_constant8, _tensor_constant9, 16)\n",
+ " _4 = [int_nbit_split_embedding_codegen_lookup_function]\n",
+ " embeddings_cat_empty_rank_handle_inference = _3(_4, 1, \"cuda:0\", 6, )\n",
+ " to_1 = torch.to(embeddings_cat_empty_rank_handle_inference, torch.device(\"cpu\"))\n",
+ " _5 = [\"product\", \"user\"]\n",
+ " _6 = [64, 64]\n",
+ " keyed_tensor = __torch__.torchrec.sparse.jagged_tensor.KeyedTensor.__new__(__torch__.torchrec.sparse.jagged_tensor.KeyedTensor)\n",
+ " _7 = (keyed_tensor).__init__(_5, _6, to_1, 1, None, None, )\n",
+ " return keyed_tensor\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "9a1dda10-b9cf-4d9f-b068-51ae3ce3ffc1",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "DQiGRYOgcspI"
+ },
+ "source": [
+ "## Congrats!\n",
+ "\n",
+ "You have now gone from training a distributed RecSys model all the way to making it inference-ready. The [TorchRec inference directory](https://github.com/pytorch/torchrec/tree/main/torchrec/inference) has a full example of how to load a TorchRec TorchScript model into C++ for inference."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ebXfh7oW9fHH",
+ "originalKey": "4ca6a593-9ac9-4e2f-bc9a-8c8a1887ad41",
+ "outputsInitialized": false,
+ "language": "markdown",
+ "showInput": false
+ },
+ "source": [
+ "## More resources\n",
+ "For more information, please see our [dlrm](https://github.com/facebookresearch/dlrm/tree/main/torchrec_dlrm/) example, which includes multi-node training on the Criteo Terabyte dataset, using Meta’s [DLRM](https://arxiv.org/abs/1906.00091)."
+ ]
+ }
+ ]
+}
diff --git a/Torchrec_Introduction.ipynb b/Torchrec_Introduction.ipynb
index 5da052875..970fbf8cc 100644
--- a/Torchrec_Introduction.ipynb
+++ b/Torchrec_Introduction.ipynb
@@ -19,44 +19,11 @@
"source": [
"## **Installation**\n",
"Requirements:\n",
- "- python >= 3.7\n",
+ "- python >= 3.9\n",
"\n",
"We highly recommend CUDA when using TorchRec. If using CUDA:\n",
- "- cuda >= 11.0\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "BB2K68OYUJ_t"
- },
- "outputs": [],
- "source": [
- "# install conda to make installying pytorch with cudatoolkit 11.3 easier. \n",
- "!wget https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh\n",
- "!chmod +x Miniconda3-py37_4.9.2-Linux-x86_64.sh\n",
- "!bash ./Miniconda3-py37_4.9.2-Linux-x86_64.sh -b -f -p /usr/local"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "sFYvP95xaAER"
- },
- "outputs": [],
- "source": [
- "# install pytorch with cudatoolkit 11.6\n",
- "!conda install pytorch pytorch-cuda=11.6 -c pytorch-nightly -c nvidia -y"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "7iY7Uv11mJYK"
- },
- "source": [
+ "- cuda >= 11.8\n",
+ "\n",
"Installing TorchRec will also install [FBGEMM](https://github.com/pytorch/fbgemm), a collection of CUDA kernels and GPU enabled operations to run "
]
},
@@ -64,53 +31,14 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
- "id": "tUnIw-ZREQJy"
- },
- "outputs": [],
- "source": [
- "# install torchrec\n",
- "!pip3 install torchrec-nightly"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "b6EHgotRXFQh"
- },
- "source": [
- "The following steps are needed for the Colab runtime to detect the added shared libraries. The runtime searches for shared libraries in /usr/lib, so we copy over the libraries which were installed in /usr/local/lib/. **This is a very necessary step, only in the colab runtime**. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "_P45pDteRcWj"
- },
- "outputs": [],
- "source": [
- "!cp /usr/local/lib/lib* /usr/lib/"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "n5_X2WOAYG3c"
- },
- "source": [
- "\\**Restart your runtime at this point for the newly installed packages to be seen.** Run the step below immediately after restarting so that python knows where to look for packages. **Always run this step after restarting the runtime.**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "8cktNrh8R9rC"
+ "id": "sFYvP95xaAER"
},
"outputs": [],
"source": [
- "import sys\n",
- "sys.path = ['', '/env/python', '/usr/local/lib/python37.zip', '/usr/local/lib/python3.7', '/usr/local/lib/python3.7/lib-dynload', '/usr/local/lib/python3.7/site-packages']"
+ "!pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U\n",
+ "!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/nightly/cu121\n",
+ "!pip3 install torchmetrics==1.0.3\n",
+ "!pip3 install torchrec --index-url https://download.pytorch.org/whl/nightly/cu121"
]
},
{
@@ -236,11 +164,11 @@
"metadata": {},
"outputs": [],
"source": [
- "from torchrec.optim.apply_overlapped_optimizer import apply_overlapped_optimizer\n",
+ "from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward\n",
"\n",
- "apply_overlapped_optimizer(\n",
- " ebc.parameters(),\n",
- " optimizer_type=torch.optim.SGD,\n",
+ "apply_optimizer_in_backward(\n",
+ " optimizer_class=torch.optim.SGD,\n",
+ " params=ebc.parameters(),\n",
" optimizer_kwargs={\"lr\": 0.02},\n",
")"
]
@@ -625,35 +553,46 @@
}
],
"metadata": {
- "accelerator": "GPU",
- "colab": {
- "background_execution": "on",
- "collapsed_sections": [],
- "machine_shape": "hm",
- "name": "Torchrec Introduction.ipynb",
- "provenance": []
- },
- "interpreter": {
- "hash": "d4204deb07d30e7517ec64733b2d65f24aff851b061e21418071854b06459363"
- },
- "kernelspec": {
- "display_name": "Python 3.7.13 ('torchrec': conda)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
+ "custom": {
+ "cells": [],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "background_execution": "on",
+ "collapsed_sections": [],
+ "machine_shape": "hm",
+ "name": "Torchrec Introduction.ipynb",
+ "provenance": []
+ },
+ "fileHeader": "",
+ "fileUid": "c9a29462-2509-4adb-a539-0318cf56bb00",
+ "interpreter": {
+ "hash": "d4204deb07d30e7517ec64733b2d65f24aff851b061e21418071854b06459363"
+ },
+ "isAdHoc": false,
+ "kernelspec": {
+ "display_name": "Python 3.7.13 ('torchrec': conda)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ }
},
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.13"
- }
+ "nbformat": 4,
+ "nbformat_minor": 0
+ },
+ "indentAmount": 2
},
"nbformat": 4,
- "nbformat_minor": 0
+ "nbformat_minor": 2
}
diff --git a/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp b/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp
index b0f6a4bde..23001de25 100644
--- a/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp
+++ b/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp
@@ -12,7 +12,7 @@
namespace torchrec {
void BM_MixedLFULRUStrategy(benchmark::State& state) {
size_t num_ext_values = state.range(0);
- std::vector ext_values(num_ext_values);
+ std::vector ext_values(num_ext_values);
MixedLFULRUStrategy strategy;
for (auto& v : ext_values) {
diff --git a/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_evict_benchmark.cpp b/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_evict_benchmark.cpp
index 1f4183d74..f8e2a0cd2 100644
--- a/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_evict_benchmark.cpp
+++ b/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_evict_benchmark.cpp
@@ -18,13 +18,13 @@ class RecordIterator {
public:
RecordIterator(Container::const_iterator begin, Container::const_iterator end)
: begin_(begin), end_(end) {}
- std::optional> operator()() {
+ std::optional operator()() {
if (begin_ == end_) {
return std::nullopt;
}
- TransformerRecord record{};
- record.global_id_ = next_global_id_++;
- record.lxu_record_ = *reinterpret_cast(&(*begin_++));
+ record_t record{};
+ record.global_id = next_global_id_++;
+ record.lxu_record = *reinterpret_cast(&(*begin_++));
return record;
}
@@ -52,8 +52,8 @@ class RandomizeMixedLXUSet {
std::uniform_int_distribution time_dist(0, max_time - 1);
for (size_t i = 0; i < n; ++i) {
MixedLFULRUStrategy::Record record{};
- record.freq_power_ = freq_dist(engine) + min_freq;
- record.time_ = time_dist(engine);
+ record.freq_power = freq_dist(engine) + min_freq;
+ record.time = time_dist(engine);
records_.emplace_back(record);
}
}
@@ -68,8 +68,9 @@ class RandomizeMixedLXUSet {
void BM_MixedLFULRUStrategyEvict(benchmark::State& state) {
RandomizeMixedLXUSet lxuSet(state.range(0), state.range(1), state.range(2));
+ MixedLFULRUStrategy strategy;
for (auto _ : state) {
- MixedLFULRUStrategy::evict(lxuSet.Iterator(), state.range(3));
+ strategy.evict(lxuSet.Iterator(), state.range(3));
}
}
diff --git a/benchmarks/cpp/dynamic_embedding/naive_id_transformer_benchmark.cpp b/benchmarks/cpp/dynamic_embedding/naive_id_transformer_benchmark.cpp
index 17faadb67..ae66fbfcc 100644
--- a/benchmarks/cpp/dynamic_embedding/naive_id_transformer_benchmark.cpp
+++ b/benchmarks/cpp/dynamic_embedding/naive_id_transformer_benchmark.cpp
@@ -13,8 +13,7 @@
namespace torchrec {
static void BM_NaiveIDTransformer(benchmark::State& state) {
- using Tag = int32_t;
- NaiveIDTransformer transformer(2e8);
+ NaiveIDTransformer transformer(2e8);
torch::Tensor global_ids = torch::empty({1024, 1024}, torch::kLong);
torch::Tensor cache_ids = torch::empty_like(global_ids);
for (auto _ : state) {
diff --git a/benchmarks/ebc_benchmarks.py b/benchmarks/ebc_benchmarks.py
index 082f05bc0..98a0ffcbd 100644
--- a/benchmarks/ebc_benchmarks.py
+++ b/benchmarks/ebc_benchmarks.py
@@ -5,13 +5,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import argparse
import sys
from typing import List, Tuple
+# pyre-fixme[21]: Could not find module `ebc_benchmarks_utils`.
+import ebc_benchmarks_utils
import torch
-from fbgemm_gpu.split_table_batched_embeddings_ops import EmbeddingLocation
-from torchrec.github.benchmarks import ebc_benchmarks_utils
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.fused_embedding_modules import FusedEmbeddingBagCollection
@@ -251,5 +254,9 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
return parser.parse_args(argv)
-if __name__ == "__main__":
+def invoke_main() -> None:
main(sys.argv[1:])
+
+
+if __name__ == "__main__":
+ invoke_main() # pragma: no cover
diff --git a/benchmarks/ebc_benchmarks_utils.py b/benchmarks/ebc_benchmarks_utils.py
index 74f7f67d8..b15ec2b66 100644
--- a/benchmarks/ebc_benchmarks_utils.py
+++ b/benchmarks/ebc_benchmarks_utils.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import time
from typing import Dict, List, Optional, Tuple
diff --git a/contrib/dynamic_embedding/README.md b/contrib/dynamic_embedding/README.md
new file mode 100644
index 000000000..975963f35
--- /dev/null
+++ b/contrib/dynamic_embedding/README.md
@@ -0,0 +1,80 @@
+# TorchRec Dynamic Embedding
+
+This folder contains the extension to support dynamic embedding for torchrec. Specifically, this extension enable torchrec to attach an external PS, so that when local GPU embedding is not big enough, we could pull/evict embeddings from/to the PS.
+
+## Installation
+
+After install torchrec, please clone the torchrec repo and manually install the dynamic embedding:
+
+```bash
+git clone git@github.com:pytorch/torchrec.git
+cd contrib/dynamic_embedding
+python setup.py install
+```
+
+And the dynamic embedding will be installed as a separate package named `torchrec_dynamic_embedding`.
+
+Notice that for C++20 supports we recommend gcc version higher or equal to 10. Conda users could install the lastest gcc utilities with:
+
+```bash
+conda install gxx_linux-64
+```
+
+We incorporate `gtest` for the C++ code and use unittest for the python APIs. The tests make sure that the implementation does not have any precision loss. Please turn on the `TDE_WITH_TESTING` in `setup.py` to run tests. Note that for the python test, one needs to set the environment variable `TDE_MEMORY_IO_PATH` to the path of the compiled `memory_io.so`.
+
+## Usage
+
+The dynamic embedding extension has only one api, `tde.wrap`, when wrapping the dataloader and model with it, we will automatically pipeline the data processing and model training. And example of `tde.wrap` is:
+
+```python
+import torchrec_dynamic_embedding as tde
+
+class Model(nn.Module):
+ def __init__(self, config1, config2):
+ super().__init__()
+ self.emb1 = EmbeddingCollection(tables=config1, device=torch.device("meta"))
+ self.emb2 = EmbeddingCollection(tables=config2, device=torch.device("meta"))
+ ...
+
+ def forward(self, kjt1, kjt2):
+ ...
+
+m = Model(config1, config2)
+m = DistributedModelParallel(m)
+dataloader = tde.wrap(
+ "redis://127.0.0.1:6379/?prefix=model",
+ dataloader,
+ m,
+ # configs of the embedding collections in the model
+ { "emb1": config1, "emb2": config2 })
+
+for label, kjt1, kjt2 in dataloader:
+ output = m(kjt1, kjt2)
+ ...
+```
+
+The internal of `tde.wrap` is in `src/torchrec_dynamic_embedding/dataloader.py`, where we will attach hooks to the embedding tensor as well as creating the dataloader thread for pipelining.
+
+## Custom PS Extension
+
+The dynamic embedding extension supports connecting with your PS cluster. To write your own PS extension, you need to create an dynamic library (`*.so`) with these 4 functions and 1 variable:
+
+```c++
+const char* IO_type = "your-ps";
+
+void* IO_Initialize(const char* cfg);
+
+void IO_Finalize(void* instance);
+
+void IO_Pull(void* instance, IOPullParameter cfg);
+
+void IO_Push(void* instance, IOPushParameter cfg);
+```
+
+And then use the following python API to register it:
+
+```python
+torch.ops.tde.register_io(so_path)
+```
+
+After that, you can use your own PS extension by passing the corresponding URL to `tde.wrap`: the protocol name is the `IO_type` and the string after `"://"` is passed to `IO_Initialize` (`"type://cfg"`).
diff --git a/contrib/dynamic_embedding/setup.py b/contrib/dynamic_embedding/setup.py
index b4a6e968d..c5e269200 100644
--- a/contrib/dynamic_embedding/setup.py
+++ b/contrib/dynamic_embedding/setup.py
@@ -1,7 +1,16 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
import os
import sys
import torch
+from setuptools import find_packages
from skbuild import setup
@@ -40,7 +49,7 @@
setup(
name="torchrec_dynamic_embedding",
package_dir={"": "src"},
- packages=["torchrec_dynamic_embedding"],
+ packages=find_packages("src"),
cmake_args=[
"-DCMAKE_BUILD_TYPE=Release",
f"-DTDE_TORCH_BASE_DIR={os.path.dirname(torch.__file__)}",
diff --git a/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer.h b/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer.h
index 6d6a328f3..bada56602 100644
--- a/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer.h
+++ b/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer.h
@@ -1,3 +1,11 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
#pragma once
#include
#include
@@ -30,7 +38,7 @@ class CachelineIDTransformerIterator {
continue;
}
TransformerRecord result{};
- result.global_id_ = -record.global_id_not_;
+ result.global_id_ = ~record.global_id_not_;
result.cache_id_ = record.cache_id_;
result.lxu_record_ = record.lxu_record_;
return result;
diff --git a/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer_impl.h b/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer_impl.h
index 5cf36a3b7..6e28fe170 100644
--- a/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer_impl.h
+++ b/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer_impl.h
@@ -1,3 +1,11 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
#pragma once
#include
#include
@@ -64,38 +72,40 @@ inline bool CachelineIDTransformer<
int64_t global_id_not = ~global_id;
CacheValue* group_begin = &cache_values_[group_id * group_size_];
int64_t k = 0;
+ int64_t empty_slot = -1;
+ int64_t cache_id = -1;
for (; k < group_size_; k++, intra_id++) {
intra_id %= group_size_;
auto& cache_value = group_begin[intra_id];
// tricky but fast :p
int64_t xor_value = cache_value.global_id_not_ ^ global_id_not;
- if (xor_value > 0) {
- continue;
- }
- int64_t cache_id;
if (xor_value == 0) { // found
cache_id = cache_value.cache_id_;
cache_value.lxu_record_ =
update(cache_value.lxu_record_, global_id, cache_id);
- } else { // empty slot
+ break;
+ } else if (xor_value < 0 && empty_slot < 0) { // empty slot
+ empty_slot = intra_id;
+ }
+ }
+ if (cache_id < 0) {
+ if (empty_slot >= 0) {
// The transformer is full.
if (C10_UNLIKELY(bitmap_.Full())) {
return false;
}
+ auto& cache_value = group_begin[empty_slot];
cache_id = bitmap_.NextFreeBit();
cache_value.global_id_not_ = global_id_not;
cache_value.cache_id_ = cache_id;
cache_value.lxu_record_ = update(std::nullopt, global_id, cache_id);
fetch(global_id, cache_id);
+ } else {
+ return false;
}
- cache_ids[i] = cache_id;
- break;
- }
-
- if (k == group_size_) {
- return false;
}
+ cache_ids[i] = cache_id;
}
return true;
}
@@ -115,14 +125,14 @@ inline void CachelineIDTransformer<
for (const int64_t global_id : global_ids) {
auto [group_id, intra_id] = FindGroupIndex(global_id);
+ int64_t global_id_not = ~global_id;
for (int64_t k = 0; k < group_size_; k++) {
int64_t offset = group_id * group_size_ + (intra_id + k) % group_size_;
auto& cache_value = cache_values_[offset];
// tricky but fast :p
- int64_t global_id_not = ~global_id;
int64_t xor_value = global_id_not ^ cache_value.global_id_not_;
if (xor_value < 0) { // not exist
- break;
+ continue;
} else if (xor_value == 0) { // found slot
bitmap_.FreeBit(cache_value.cache_id_);
cache_value.global_id_not_ = 0;
diff --git a/contrib/dynamic_embedding/src/tde/ps.cpp b/contrib/dynamic_embedding/src/tde/ps.cpp
index 4e3bdecb9..e9fcb3c25 100644
--- a/contrib/dynamic_embedding/src/tde/ps.cpp
+++ b/contrib/dynamic_embedding/src/tde/ps.cpp
@@ -1,3 +1,11 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
#include "tde/ps.h"
#include "tde/details/io.h"
@@ -17,9 +25,13 @@ c10::intrusive_ptr PS::Fetch(
if (cache_ids_to_fetch_or_evict_.empty()) {
return c10::make_intrusive(time, c10::intrusive_ptr());
}
- fetch_notifications_.emplace_back(time, c10::make_intrusive());
- c10::intrusive_ptr notification =
- fetch_notifications_.back().second;
+ c10::intrusive_ptr notification;
+ {
+ std::unique_lock lock_fetch(fetch_notifications_mutex_);
+ fetch_notifications_.emplace_back(
+ time, c10::make_intrusive());
+ notification = fetch_notifications_.back().second;
+ }
uint32_t num_os_ids = os_ids_.size();
io_.Pull(
table_name_,
@@ -97,9 +109,7 @@ void PS::Evict(torch::Tensor ids_to_evict) {
notification.Done();
// The shared data for all chunks.
std::vector offsets;
- offsets.reserve(num_ids_per_chunk_ * num_os_ids * col_ids.size() + 1);
- std::vector data(
- num_ids_per_chunk_ * num_os_ids * col_ids.size() * col_size_);
+ offsets.resize(num_ids_per_chunk_ * num_os_ids * col_ids.size() + 1);
for (uint32_t i = 0; i < num_ids_to_fetch; i += num_ids_per_chunk_) {
uint32_t num_ids_in_chunk = std::min(
@@ -107,22 +117,22 @@ void PS::Evict(torch::Tensor ids_to_evict) {
uint32_t data_size = num_ids_in_chunk * num_os_ids * col_ids.size();
uint32_t offsets_size = num_ids_in_chunk * num_os_ids * col_ids.size() + 1;
- offsets.clear();
- offsets.emplace_back(0);
+ std::vector all_tensors;
for (uint32_t j = i; j < i + num_ids_in_chunk; ++j) {
int64_t cache_id = cache_ids_to_fetch_or_evict_[j];
std::vector tensors = GetTensorViews(cache_id);
- for (uint32_t k : os_ids_) {
- // this cause 2 copy. is this avoidable?
- torch::Tensor tensor = tensors[k].cpu();
- // need to change this when considering col
- memcpy(
- reinterpret_cast(data.data()) + offsets.back(),
- tensor.data_ptr(),
- tensor.numel() * tensor.element_size());
- offsets.emplace_back(
- offsets.back() + tensor.numel() * tensor.element_size());
- }
+ all_tensors.insert(all_tensors.end(), tensors.begin(), tensors.end());
+ }
+ torch::Tensor data = torch::cat(all_tensors, 0).cpu();
+ TORCH_CHECK(data.numel() == data_size * col_size_);
+
+ // to prevent the original data from being prematurely recycled
+ auto data_shared_ptr = std::make_shared(data);
+
+ offsets[0] = 0;
+ for (uint32_t j = 0; j < all_tensors.size(); ++j) {
+ offsets[j + 1] =
+ offsets[j] + all_tensors[j].numel() * all_tensors[j].element_size();
}
// waiting for the Push of last chunk finishes.
notification.Wait();
@@ -133,21 +143,30 @@ void PS::Evict(torch::Tensor ids_to_evict) {
col_ids,
os_ids_,
tcb::span{
- reinterpret_cast(data.data()), data_size * sizeof(float)},
+ reinterpret_cast(data_shared_ptr->data_ptr()),
+ data_size * sizeof(float)},
tcb::span{offsets.data(), offsets_size},
- [¬ification] { notification.Done(); });
+ [¬ification, data_shared_ptr] { notification.Done(); });
}
notification.Wait();
}
void PS::SyncFetch(int64_t time) {
- while (!fetch_notifications_.empty()) {
- auto& [t, notification] = fetch_notifications_.front();
- if (t != time && time >= 0) {
+ std::unique_lock lock(
+ fetch_notifications_mutex_, std::defer_lock);
+
+ while (true) {
+ lock.lock();
+ if (fetch_notifications_.empty() ||
+ fetch_notifications_.front().first != time && time >= 0) {
+ lock.unlock();
break;
}
- notification->Wait();
+ auto notification = fetch_notifications_.front().second;
fetch_notifications_.pop_front();
+ lock.unlock();
+
+ notification->Wait();
}
}
diff --git a/contrib/dynamic_embedding/src/tde/ps.h b/contrib/dynamic_embedding/src/tde/ps.h
index aed24fdc6..cad040f4c 100644
--- a/contrib/dynamic_embedding/src/tde/ps.h
+++ b/contrib/dynamic_embedding/src/tde/ps.h
@@ -1,3 +1,11 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
#pragma once
#include
#include
@@ -99,6 +107,7 @@ class PS : public torch::CustomClassHolder {
void Filter(const torch::Tensor& tensor);
std::mutex mu_;
+ std::mutex fetch_notifications_mutex_;
std::string table_name_;
c10::intrusive_ptr shards_;
int64_t col_size_;
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/dataloader.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/dataloader.py
index 1a7645fd3..96911a449 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/dataloader.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/dataloader.py
@@ -1,3 +1,10 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
import queue
import threading
from typing import Dict, List, Union
@@ -34,6 +41,9 @@ def transform_loop(dataloader, transform_fn, out_queue, done_event):
# save memory
del transformed_data
+ if not done_event.is_set():
+ done_event.set()
+
class DataLoaderIter:
def __init__(self, dataloader, transform_fn, num_prefetch=0):
@@ -55,6 +65,8 @@ def __del__(self):
self._done_event.set()
def _get_data(self):
+ if self._done_event.is_set():
+ raise StopIteration
if not self._transform_thread.is_alive():
raise RuntimeError("Transform thread exited unexpectedly")
data, handles = self._data_queue.get()
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/__init__.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/__init__.py
index cdf9a93a9..594213334 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/__init__.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/__init__.py
@@ -1 +1,15 @@
-from .comm import default_group, gather_kjts
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#!/usr/bin/env python3
+
+from .comm import (
+ broadcast_ids_to_evict,
+ broadcast_transform_result,
+ gather_global_ids,
+ scatter_cache_ids,
+)
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/comm.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/comm.py
index b6da3be0a..f15f3ce18 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/comm.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/comm.py
@@ -1,127 +1,98 @@
-from dataclasses import dataclass
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
from typing import List, Optional
import torch
import torch.distributed as dist
-from torchrec import KeyedJaggedTensor
-
__all__ = []
-_group = [None]
-
-
-@dataclass
-class GatherResult:
- values_list: List[List[torch.Tensor]]
- offset_per_key_list: List[List[torch.Tensor]]
-
-
-def default_group(group):
- if group is not None:
- return group
-
- if _group[0] is None:
- _group[0] = dist.new_group(backend="gloo")
-
- return _group[0]
-
-def gather_tensor_list(
- tensors: List[torch.Tensor], numel_lists: List[List[int]], dst=0, group=None
-) -> Optional[List[List[torch.Tensor]]]:
- rank = dist.get_rank()
+def gather_global_ids(global_ids: List[torch.Tensor], group):
world_size = dist.get_world_size()
- if dst == rank:
- if len(numel_lists) != world_size:
- raise ValueError("dst rank should know size of tensors on all ranks")
- else:
- if len(numel_lists) != 1:
- raise ValueError("non dst rank should pass its own tensor sizes")
-
- group = default_group(group)
- dtype = tensors[0].dtype
- device = tensors[0].device
+ rank = dist.get_rank()
- concated_tensor = torch.cat(tensors)
+ concat_global_ids = torch.cat(global_ids)
- if rank == dst:
- # gather can only accept same-size tensors.
- max_numel = max(sum(numel_list) for numel_list in numel_lists)
- concat_results = [
- torch.empty(max_numel, dtype=dtype, device=device)
- for _ in range(world_size)
- ]
- dist.gather(
- concated_tensor,
- gather_list=concat_results,
- dst=dst,
- group=group,
- async_op=False,
- )
+ concat_numel = torch.tensor(concat_global_ids.numel(), dtype=torch.int64)
+ concat_numel_list = [torch.tensor(0, dtype=torch.int64) for _ in range(world_size)]
+ dist.all_gather(concat_numel_list, concat_numel, group=group, async_op=False)
- results = []
- for i in range(world_size):
- splited_tensors = []
- offset = 0
- for numel in numel_lists[i]:
- splited_tensors.append(concat_results[i][offset : offset + numel])
- offset += numel
- results.append(splited_tensors)
+ max_numel = max(concat_numel_list)
+ concat_global_ids.resize_(max_numel)
- return results
+ if rank == 0:
+ concat_global_ids_list = [
+ torch.empty_like(concat_global_ids) for _ in range(world_size)
+ ]
+ dist.gather(concat_global_ids, concat_global_ids_list, 0, group, async_op=False)
+ return [
+ concat_global_ids_list[i][: concat_numel_list[i]] for i in range(world_size)
+ ], concat_numel_list
else:
- dist.gather(
- concated_tensor, gather_list=None, dst=dst, group=group, async_op=False
- )
-
- return None
+ dist.gather(concat_global_ids, None, 0, group, async_op=False)
+ return None, concat_numel_list
-def gather_kjts(kjts: List[KeyedJaggedTensor], dst=0, group=None) -> GatherResult:
+def scatter_cache_ids(
+ cache_ids_list: Optional[List[torch.Tensor]], concat_numel_list: List[int], group
+):
world_size = dist.get_world_size()
- if world_size == 1:
- return GatherResult(
- values_list=[[kjt.values()] for kjt in kjts],
- offset_per_key_list=[[kjt.offset_per_key()] for kjt in kjts],
- )
-
rank = dist.get_rank()
- group = default_group(group)
- offset_per_key_list = [torch.tensor(kjt.offset_per_key()) for kjt in kjts]
- values_list = [kjt.values() for kjt in kjts]
+ max_numel = max(concat_numel_list)
- offset_numel_list = [tensor.numel() for tensor in offset_per_key_list]
- values_numel_list = [tensor.numel() for tensor in values_list]
-
- if rank == dst:
- global_offset_numel_list = [offset_numel_list] * world_size
- offset_results = gather_tensor_list(
- offset_per_key_list,
- numel_lists=global_offset_numel_list,
- dst=dst,
- group=group,
- )
-
- global_values_numel_list = [
- [offset[-1].item() for offset in offsets] for offsets in offset_results
+ concat_cache_ids = torch.empty(max_numel, dtype=torch.int64)
+ if rank == 0:
+ concat_cache_ids_list = [concat_cache_ids] + [
+ cache_ids.resize_(max_numel)
+ for cache_ids in cache_ids_list[-world_size + 1 :]
]
-
- values_result = gather_tensor_list(
- values_list, numel_lists=global_values_numel_list, dst=dst, group=group
- )
-
- return GatherResult(
- values_list=values_result, offset_per_key_list=offset_results
- )
+ assert len(concat_cache_ids_list) == world_size
+ dist.scatter(concat_cache_ids, concat_cache_ids_list, group=group)
else:
- gather_tensor_list(
- offset_per_key_list, numel_lists=[offset_numel_list], dst=dst, group=group
+ dist.scatter(concat_cache_ids, None, group=group)
+ offset = 0
+ for cache_ids in cache_ids_list:
+ cache_ids[:] = concat_cache_ids[offset : offset + cache_ids.numel()]
+ offset += cache_ids.numel()
+
+
+def broadcast_transform_result(
+ success: bool, ids_to_fetch: Optional[torch.Tensor], group
+):
+ if dist.get_rank() == 0:
+ success_and_numel = torch.tensor(
+ [1 if success else 0, ids_to_fetch.numel()], dtype=torch.int64
)
+ dist.broadcast(success_and_numel, src=0, group=group)
+ else:
+ success_and_numel = torch.tensor([0, 0], dtype=torch.int64)
+ dist.broadcast(success_and_numel, src=0, group=group)
+ success, numel = success_and_numel.tolist()
+ success = success != 0
+ ids_to_fetch = torch.empty((numel // 2, 2), dtype=torch.int64)
+
+ if ids_to_fetch.numel() > 0:
+ dist.broadcast(ids_to_fetch, src=0, group=group)
+ return success, ids_to_fetch
- gather_tensor_list(
- values_list, numel_lists=[values_numel_list], dst=dst, group=group
- )
- return None
+def broadcast_ids_to_evict(ids, group):
+ if dist.get_rank() == 0:
+ numel = torch.tensor(ids.numel(), dtype=torch.int64)
+ dist.broadcast(numel, src=0, group=group)
+ else:
+ numel = torch.tensor(0, dtype=torch.int64)
+ dist.broadcast(numel, src=0, group=group)
+ numel = numel.item()
+ ids = torch.empty((numel // 2, 2), dtype=torch.int64)
+
+ if numel > 0:
+ dist.broadcast(ids, src=0, group=group)
+ return ids
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer.py
index 6d7169183..735412d35 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer.py
@@ -1,3 +1,10 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
import json
import os
@@ -33,9 +40,10 @@ def transform(self, global_ids: TensorList, cache_ids: TensorList, time: int):
"""
Transform `global_ids` and store the results in `cache_ids`.
"""
- return self._transformer.transform(
+ result = self._transformer.transform(
global_ids.tensor_list, cache_ids.tensor_list, time
)
+ return result.success, result.ids_to_fetch
def evict(self, num_to_evict):
"""
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_collection.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_collection.py
index ced6780df..1132d6be6 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_collection.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_collection.py
@@ -1,8 +1,24 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#!/usr/bin/env python3
+
from typing import List, Tuple, Union
import torch
+import torch.distributed as dist
from torchrec import EmbeddingBagConfig, EmbeddingConfig, KeyedJaggedTensor
+from .distributed import (
+ broadcast_ids_to_evict,
+ broadcast_transform_result,
+ gather_global_ids,
+ scatter_cache_ids,
+)
from .id_transformer import IDTransformer, TensorList
from .ps import PSCollection
@@ -27,7 +43,7 @@ def __init__(
tables: list of `Embedding(Bag)Config` or `EmbeddingBagConfig` one passed to
`Embedding(Bag)Collection`.
eviction_config: config of the eviction strategy for IDTransformers.
- transformer_config: config of the transform strategy for IDTransformers.
+ transform_config: config of the transform strategy for IDTransformers.
ps_collection: `PSCollection` of the collection, if `None`, won't do eviction or fetch.
By default, IDTransformerCollection will evict half the ids when full.
"""
@@ -46,19 +62,61 @@ def __init__(
for feature_name in config.feature_names:
if feature_name in feature_names:
raise ValueError(f"Shared feature not allowed yet.")
- self._transformers.append(
- IDTransformer(
+ # only rank 0 will have the id transformer
+ # and other ranks will gather their to rank 0.
+ if dist.get_rank() == 0:
+ transformer = IDTransformer(
num_embedding=config.num_embeddings,
eviction_config=eviction_config,
transform_config=transform_config,
)
- )
+ else:
+ transformer = None
+ self._transformers.append(transformer)
self._feature_names: List[List[str]] = [
config.feature_names for config in tables
]
self._ever_evicted = False
self._time = 0
+ if dist.get_world_size() > 1:
+ self._pg = dist.new_group(backend="gloo")
+ self._stream = torch.cuda.Stream()
+
+ def _transform(
+ self, transformer, global_ids: List[torch.Tensor], cache_ids: List[torch.Tensor]
+ ):
+ with torch.cuda.stream(self._stream):
+ total_numel = sum([tensor.numel() for tensor in global_ids])
+ if total_numel > 1e6:
+ all_tensor = torch.cat(global_ids).to("cuda:0")
+ unique_all_tensor, index = torch.unique(all_tensor, return_inverse=True)
+ unique_all_tensor = unique_all_tensor.to("cpu")
+ all_cache = torch.empty_like(unique_all_tensor)
+ success, ids_to_fetch = transformer.transform(
+ TensorList([unique_all_tensor]),
+ TensorList([all_cache]),
+ self._time,
+ )
+ del all_tensor
+ all_tensor = torch.take(all_cache.to("cuda:0"), index)
+ offset = 0
+ for tensor in cache_ids:
+ numel = tensor.numel()
+ tensor.copy_(all_tensor[offset : offset + numel])
+ offset += numel
+ assert (
+ total_numel == offset
+ ), f"total_numel not equal offset, {total_numel} vs {offset}"
+ else:
+ # broadcast result
+ success, ids_to_fetch = transformer.transform(
+ TensorList(global_ids),
+ TensorList(cache_ids),
+ self._time,
+ )
+ return success, ids_to_fetch
+
def transform(
self, global_features: KeyedJaggedTensor
) -> Tuple[KeyedJaggedTensor, List[torch.classes.tde.FetchHandle]]:
@@ -92,47 +150,128 @@ def transform(
for idx in feature_indices
]
- result = transformer.transform(
- TensorList(global_ids), TensorList(cache_ids), self._time
- )
- if self._ps_collection is not None:
- table_name = self._table_names[i]
- ps = self._ps_collection[table_name]
- if result.ids_to_fetch.numel() > 0:
- handle = ps.fetch(
- result.ids_to_fetch,
- self._time,
- self._ever_evicted,
- self._configs[i].get_weight_init_min(),
- self._configs[i].get_weight_init_max(),
- )
- fetch_handles.append(handle)
- if not result.success:
- # TODO(zilinzhu): make this configurable
- ids_to_evict = transformer.evict(transformer._num_embedding // 2)
- ps.evict(ids_to_evict)
- self._ever_evicted = True
-
- # retry after eviction.
- result = transformer.transform(
- TensorList(global_ids), TensorList(cache_ids), self._time
+ if dist.get_world_size() > 1:
+ concat_global_ids, concat_numel_list = gather_global_ids(
+ global_ids, self._pg
+ )
+ if dist.get_rank() == 0:
+ global_ids = global_ids + concat_global_ids[1:]
+ cache_ids = cache_ids + [
+ torch.empty_like(tensor) for tensor in concat_global_ids[1:]
+ ]
+
+ success, ids_to_fetch = self._transform(
+ transformer, global_ids, cache_ids
)
- if not result.success:
- raise RuntimeError(
- "Failed to transform global ids after eviction. "
- f"Maybe the num_embedding of table {table_name} is too small?"
+ else:
+ success, ids_to_fetch = True, None
+ success, ids_to_fetch = broadcast_transform_result(
+ success, ids_to_fetch, self._pg
+ )
+
+ if self._ps_collection is not None:
+ table_name = self._table_names[i]
+ ps = self._ps_collection[table_name]
+ if ids_to_fetch.numel() > 0:
+ handle = ps.fetch(
+ ids_to_fetch,
+ self._time,
+ self._ever_evicted,
+ self._configs[i].get_weight_init_min(),
+ self._configs[i].get_weight_init_max(),
)
- if result.ids_to_fetch is not None:
- fetch_handles.append(
- ps.fetch(
- result.ids_to_fetch,
+ fetch_handles.append(handle)
+ if not success:
+ # TODO(zilinzhu): make this configurable
+ # broadcast ids_to_evict
+ if dist.get_rank() == 0:
+ ids_to_evict = transformer.evict(
+ transformer._num_embedding // 2
+ )
+ else:
+ ids_to_evict = None
+ ids_to_evict = broadcast_ids_to_evict(ids_to_evict, self._pg)
+
+ ps.evict(ids_to_evict)
+ self._ever_evicted = True
+
+ # retry after eviction.
+ # broadcast result
+ if dist.get_rank() == 0:
+ success, ids_to_fetch = transformer.transform(
+ TensorList(global_ids),
+ TensorList(cache_ids),
self._time,
- self._ever_evicted,
- self._configs[i].get_weight_init_min(),
- self._configs[i].get_weight_init_max(),
)
+ else:
+ success, ids_to_fetch = True, None
+ success, ids_to_fetch = broadcast_transform_result(
+ success, ids_to_fetch, self._pg
)
+ if not success:
+ raise RuntimeError(
+ "Failed to transform global ids after eviction. "
+ f"Maybe the num_embedding of table {table_name} is too small?"
+ )
+ if ids_to_fetch.numel() > 0:
+ fetch_handles.append(
+ ps.fetch(
+ ids_to_fetch,
+ self._time,
+ self._ever_evicted,
+ self._configs[i].get_weight_init_min(),
+ self._configs[i].get_weight_init_max(),
+ )
+ )
+
+ scatter_cache_ids(cache_ids, concat_numel_list, self._pg)
+ else:
+ success, ids_to_fetch = self._transform(
+ transformer, global_ids, cache_ids
+ )
+ if self._ps_collection is not None:
+ table_name = self._table_names[i]
+ ps = self._ps_collection[table_name]
+ if ids_to_fetch.numel() > 0:
+ handle = ps.fetch(
+ ids_to_fetch,
+ self._time,
+ self._ever_evicted,
+ self._configs[i].get_weight_init_min(),
+ self._configs[i].get_weight_init_max(),
+ )
+ fetch_handles.append(handle)
+ if not success:
+ # TODO(zilinzhu): make this configurable
+ ids_to_evict = transformer.evict(
+ transformer._num_embedding // 2
+ )
+ ps.evict(ids_to_evict)
+ self._ever_evicted = True
+
+ # retry after eviction.
+ success, ids_to_fetch = transformer.transform(
+ TensorList(global_ids),
+ TensorList(cache_ids),
+ self._time,
+ )
+ if not success:
+ raise RuntimeError(
+ "Failed to transform global ids after eviction. "
+ f"Maybe the num_embedding of table {table_name} is too small?"
+ )
+ if ids_to_fetch is not None:
+ fetch_handles.append(
+ ps.fetch(
+ ids_to_fetch,
+ self._time,
+ self._ever_evicted,
+ self._configs[i].get_weight_init_min(),
+ self._configs[i].get_weight_init_max(),
+ )
+ )
+
cache_values = KeyedJaggedTensor(
keys=global_features.keys(),
values=cache_values,
@@ -147,5 +286,18 @@ def save(self):
return
for i, transformer in enumerate(self._transformers):
table_name = self._table_names[i]
- ids = transformer.save()
+ if dist.get_world_size() > 1:
+ if dist.get_rank() == 0:
+ ids = transformer.save()
+ numel = torch.tensor(ids.numel())
+ dist.broadcast(numel, src=0, group=self._pg)
+ dist.broadcast(ids, src=0, group=self._pg)
+ else:
+ numel = torch.tensor(0)
+ dist.broadcast(numel, src=0, group=self._pg)
+ ids = torch.empty((numel // 2, 2), dtype=torch.int64)
+ dist.broadcast(ids, src=0, group=self._pg)
+ else:
+ ids = transformer.save()
+
self._ps_collection[table_name].evict(ids)
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_group.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_group.py
index c26b2ad4f..9c888a15f 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_group.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_group.py
@@ -57,8 +57,8 @@ def __init__(
configs or embeddingbag configs. The plan of `module` should contain the module path
in `configs_dict`.
eviction_config: configuration for eviction policy. Default is `{"type": "mixed_lru_lfu"}`
- transformer_config: configuration for the transformer. Default is `{"type": "naive"}`
- parallel: Whether the IDTransformerCollections will run paralell. When set to True,
+ transform_config: configuration for the transformer. Default is `{"type": "naive"}`
+ parallel: Whether the IDTransformerCollections will run in parallel. When set to True,
IDTransformerGroup will start a thread for each IDTransformerCollection.
Example:
@@ -166,12 +166,6 @@ def __contains__(self, path):
"""
return path in self._id_transformer_collections
- def __contains__(self, path):
- """
- Check if there is transformer for the path.
- """
- return path in self._id_transformer_collections
-
def __del__(self):
"""
Stop the parallel threads
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/ps.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/ps.py
index 690e89b97..763f3e8ef 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/ps.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/ps.py
@@ -1,10 +1,15 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
import os
from typing import Callable, Dict, List, Tuple, Union
import torch
-import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
-from torchrec.distributed.model_parallel import DistributedModelParallel as DMP
from torchrec.distributed.types import ParameterSharding
from .tensor_list import TensorList
@@ -54,7 +59,6 @@ def __init__(
# This assumes all shard have the same column size.
col_size = shard.tensor.shape[1]
elif isinstance(tensors[0], torch.Tensor):
- tensors
shards.append(
0,
0,
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/tensor_list.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/tensor_list.py
index fc11cce9a..591eda4c3 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/tensor_list.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/tensor_list.py
@@ -1,4 +1,9 @@
-import json
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
import os
from typing import List
diff --git a/contrib/dynamic_embedding/tests/test_integral_precision.py b/contrib/dynamic_embedding/tests/test_integral_precision.py
index 842a2b72d..965bb092c 100644
--- a/contrib/dynamic_embedding/tests/test_integral_precision.py
+++ b/contrib/dynamic_embedding/tests/test_integral_precision.py
@@ -1,3 +1,10 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
import unittest
import torch
@@ -14,6 +21,7 @@
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec_dynamic_embedding.id_transformer_group import IDTransformerGroup
from utils import init_dist, register_memory_io
@@ -93,7 +101,7 @@ def get_dmp(model):
model = DMP(module=model, device=device, plan=plan, sharders=sharders)
dense_optimizer = KeyedOptimizerWrapper(
- dict(model.named_parameters()),
+ dict(in_backward_optimizer_filter(model.named_parameters())),
lambda params: torch.optim.Adam(params, lr=1e-1),
)
optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
diff --git a/contrib/dynamic_embedding/tests/test_ps_collection.py b/contrib/dynamic_embedding/tests/test_ps_collection.py
index a7fffe1e0..d293749e1 100644
--- a/contrib/dynamic_embedding/tests/test_ps_collection.py
+++ b/contrib/dynamic_embedding/tests/test_ps_collection.py
@@ -1,4 +1,9 @@
-import os
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
import unittest
import torch
diff --git a/contrib/dynamic_embedding/tools/before_linux_build.sh b/contrib/dynamic_embedding/tools/before_linux_build.sh
index 7cf1d154f..527ee2128 100755
--- a/contrib/dynamic_embedding/tools/before_linux_build.sh
+++ b/contrib/dynamic_embedding/tools/before_linux_build.sh
@@ -1,13 +1,20 @@
#!/usr/bin/env bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
set -xe
distro=rhel7
arch=x86_64
-CUDA_VERSION="${CUDA_VERSION:-11.6}"
+CUDA_VERSION="${CUDA_VERSION:-11.8}"
CUDA_MAJOR_VERSION=$(echo "${CUDA_VERSION}" | tr '.' ' ' | awk '{print $1}')
CUDA_MINOR_VERSION=$(echo "${CUDA_VERSION}" | tr '.' ' ' | awk '{print $2}')
+yum install -y yum-utils
yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$distro/$arch/cuda-$distro.repo
yum install -y \
cuda-toolkit-"${CUDA_MAJOR_VERSION}"-"${CUDA_MINOR_VERSION}" \
diff --git a/contrib/dynamic_embedding/tools/build_wheels.sh b/contrib/dynamic_embedding/tools/build_wheels.sh
index 51cd42500..3647547de 100755
--- a/contrib/dynamic_embedding/tools/build_wheels.sh
+++ b/contrib/dynamic_embedding/tools/build_wheels.sh
@@ -6,6 +6,8 @@ export CIBW_BEFORE_BUILD="tools/before_linux_build.sh"
# all kinds of CPython.
export CIBW_BUILD=${CIBW_BUILD:-"cp39-manylinux_x86_64"}
+export CIBW_MANYLINUX_X86_64_IMAGE=${CIBW_MANYLINUX_X86_64_IMAGE:-"manylinux_2_28"}
+
# Do not auditwheels since tde uses torch's shared libraries.
export CIBW_REPAIR_WHEEL_COMMAND="tools/repair_wheel.sh {wheel} {dest_dir}"
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 000000000..fb00038dc
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,23 @@
+Docs
+==========
+
+
+## Building the docs
+
+To build and preview the docs run the following commands:
+
+```bash
+cd docs
+pip3 install -r requirements.txt
+make html
+python3 -m http.server 8082 --bind ::
+```
+
+Now you should be able to view the docs in your browser at the link provided in your terminal.
+
+To reload the preview after making changes, rerun:
+
+```bash
+make html
+python3 -m http.server 8082 --bind ::
+```
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 851694777..96d50ad98 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,4 +1,7 @@
-sphinx
+sphinx==5.0.0
+pyre-extensions
+sphinx-design
+sphinx_copybutton
# torch
# PyTorch Theme
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css
new file mode 100644
index 000000000..55fee00b4
--- /dev/null
+++ b/docs/source/_static/css/custom.css
@@ -0,0 +1,114 @@
+/**
+* Copyright (c) Meta Platforms, Inc. and affiliates.
+* All rights reserved.
+*
+* This source code is licensed under the BSD-style license found in the
+* LICENSE file in the root directory of this source tree.
+*/
+
+/* sphinx-design styles for cards/tabs */
+
+:root {
+ --sd-color-info: #ee4c2c;
+ --sd-color-info-highlight: #ee4c2c;
+ --sd-color-primary: #6c6c6d;
+ --sd-color-primary-highligt: #f3f4f7;
+ --sd-color-card-border-hover: #ee4c2c;
+ --sd-color-card-border: #f3f4f7;
+ --sd-color-card-background: #fff;
+ --sd-color-card-text: inherit;
+ --sd-color-card-header: transparent;
+ --sd-color-card-footer: transparent;
+ --sd-color-tabs-label-active: #ee4c2c;
+ --sd-color-tabs-label-hover: #ee4c2c;
+ --sd-color-tabs-label-inactive: #6c6c6d;
+ --sd-color-tabs-underline-active: #ee4c2c;
+ --sd-color-tabs-underline-hover: #fabdbd;
+ --sd-color-tabs-underline-inactive: transparent;
+ --sd-color-tabs-overline: rgb(222, 222, 222);
+ --sd-color-tabs-underline: rgb(222, 222, 222);
+}
+
+.sd-text-info {
+ color: #ee4c2c;
+}
+
+.sd-card-img-top {
+ background: #ee4c2c;
+ height: 5px !important;
+}
+
+.sd-card {
+ position: relative;
+ background-color: #fff;
+ opacity: 1.0;
+ border-radius: 0px;
+ width: 30%;
+ border: none;
+ padding-bottom: 0px;
+}
+
+.sd-card-img {
+ opacity: 0.5;
+ width: 200px;
+ padding: 0px;
+}
+
+.sd-card-img:hover {
+ opacity: 1.0;
+ background-color: #f3f4f7;
+}
+
+
+.sd-card:after {
+ display: block;
+ opacity: 1;
+ content: '';
+ border-bottom: solid 1px #ee4c2c;
+ background-color: #fff;
+ transform: scaleX(0);
+ transition: transform .250s ease-in-out;
+ transform-origin: 0% 50%;
+}
+
+.sd-card:hover {
+ background-color: #fff;
+ opacity: 1;
+ border-top: 1px solid #f3f4f7;
+ border-left: 1px solid #f3f4f7;
+ border-right: 1px solid #f3f4f7;
+}
+
+.sd-card:hover:after {
+ transform: scaleX(1);
+}
+
+.card-prerequisites:hover {
+ transition: none;
+ border: none;
+}
+
+.card-prerequisites:hover:after {
+ transition: none;
+ transform: none;
+}
+
+.card-prerequisites:after {
+ display: block;
+ content: '';
+ border-bottom: none;
+ background-color: #fff;
+ transform: none;
+ transition: none;
+ transform-origin: none;
+}
+
+details.sd-dropdown {
+ font-weight: 300;
+ width: auto;
+}
+
+.center-content {
+ display: flex;
+ justify-content: center;
+}
diff --git a/docs/source/_static/img/card-background.svg b/docs/source/_static/img/card-background.svg
new file mode 100644
index 000000000..d97193223
--- /dev/null
+++ b/docs/source/_static/img/card-background.svg
@@ -0,0 +1,13 @@
+
+
diff --git a/docs/source/_static/img/full_training_loop.png b/docs/source/_static/img/full_training_loop.png
new file mode 100644
index 000000000..221e7c387
Binary files /dev/null and b/docs/source/_static/img/full_training_loop.png differ
diff --git a/docs/source/_static/img/fused_backward_optimizer.png b/docs/source/_static/img/fused_backward_optimizer.png
new file mode 100644
index 000000000..3c3d3593e
Binary files /dev/null and b/docs/source/_static/img/fused_backward_optimizer.png differ
diff --git a/docs/source/_static/img/fused_embedding_tables.png b/docs/source/_static/img/fused_embedding_tables.png
new file mode 100644
index 000000000..6874e4d75
Binary files /dev/null and b/docs/source/_static/img/fused_embedding_tables.png differ
diff --git a/docs/source/_static/img/model_parallel.png b/docs/source/_static/img/model_parallel.png
new file mode 100644
index 000000000..81d18e9e1
Binary files /dev/null and b/docs/source/_static/img/model_parallel.png differ
diff --git a/docs/source/_static/img/sharding.png b/docs/source/_static/img/sharding.png
new file mode 100644
index 000000000..653eb7cdb
Binary files /dev/null and b/docs/source/_static/img/sharding.png differ
diff --git a/docs/source/_static/img/torchrec_forward.png b/docs/source/_static/img/torchrec_forward.png
new file mode 100644
index 000000000..156f13e8f
Binary files /dev/null and b/docs/source/_static/img/torchrec_forward.png differ
diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html
new file mode 100644
index 000000000..dbc08744a
--- /dev/null
+++ b/docs/source/_templates/layout.html
@@ -0,0 +1,10 @@
+{% extends "!layout.html" %}
+
+{% block footer %}
+{{ super() }}
+
+
+
+{% endblock %}
diff --git a/docs/source/concepts.rst b/docs/source/concepts.rst
new file mode 100644
index 000000000..b1d7f3776
--- /dev/null
+++ b/docs/source/concepts.rst
@@ -0,0 +1,311 @@
+.. meta::
+ :description: TorchRec Concepts
+ :keywords: recommendation systems, sharding, distributed training, torchrec, embedding bags, embeddings, keyedjaggedtensor, row wise, table wise, column wise, table row wise, planner, sharder
+
+###################
+TorchRec Concepts
+###################
+
+In this section, we will learn about the key concepts of TorchRec,
+designed to optimize large-scale recommendation systems using PyTorch.
+We will learn how each concept works in detail and review how it is used
+with the rest of TorchRec.
+
+TorchRec has specific input/output data types for its modules to
+efficiently represent sparse features, including:
+
+- **JaggedTensor:** a wrapper around the lengths/offsets and values
+ tensors for a single sparse feature.
+- **KeyedJaggedTensor:** efficiently represents multiple sparse
+ features; you can think of it as multiple ``JaggedTensor``\s.
+- **KeyedTensor:** a wrapper around ``torch.Tensor`` that allows access
+ to tensor values through keys.
+
+The canonical ``torch.Tensor`` is highly inefficient for representing
+sparse data. TorchRec introduces these new data types because they
+provide efficient storage and representation of sparse input data. As
+you will see later on, the ``KeyedJaggedTensor`` makes communication of
+input data in a distributed environment very efficient, leading to one
+of the key performance advantages that TorchRec provides.
+
+In the end-to-end training loop, TorchRec comprises the following
+main components:
+
+- **Planner:** Takes in the configuration of the embedding tables and
+ the environment setup, and generates an optimized sharding plan for
+ the model.
+
+- **Sharder:** Shards the model according to the sharding plan, with
+ different sharding strategies including data-parallel, table-wise,
+ row-wise, table-wise-row-wise, column-wise, and
+ table-wise-column-wise sharding.
+
+- **DistributedModelParallel:** Combines the sharder and optimizer, and
+ provides an entry point for training the model in a distributed
+ manner.
+
+**************
+JaggedTensor
+**************
+
+A ``JaggedTensor`` represents a sparse feature through lengths, values,
+and offsets. It is called "jagged" because it efficiently represents
+data with variable-length sequences. In contrast, a canonical
+``torch.Tensor`` assumes that each sequence has the same length, which
+is often not the case with real-world data. A ``JaggedTensor``
+facilitates the representation of such data without padding, making it
+highly efficient.
+
+Key Components:
+
+- ``Lengths``: A list of integers representing the number of elements
+ for each entity.
+
+- ``Offsets``: A list of integers representing the starting index of
+ each sequence in the flattened values tensor. These provide an
+ alternative to lengths.
+
+- ``Values``: A 1D tensor containing the actual values for each entity,
+ stored contiguously.
+
+Here is a simple example demonstrating what each of these components
+looks like:
+
+.. code:: python
+
+ # User interactions:
+ # - User 1 interacted with 2 items
+ # - User 2 interacted with 3 items
+ # - User 3 interacted with 1 item
+ lengths = torch.tensor([2, 3, 1])
+ offsets = torch.tensor([0, 2, 5, 6]) # Start index of each user's interactions, plus the total length
+ values = torch.Tensor([101, 102, 201, 202, 203, 301]) # Item IDs interacted with
+ jt = JaggedTensor(lengths=lengths, values=values)
+ # OR
+ jt = JaggedTensor(offsets=offsets, values=values)
+
+*******************
+KeyedJaggedTensor
+*******************
+
+A ``KeyedJaggedTensor`` extends the functionality of ``JaggedTensor`` by
+introducing keys (which are typically feature names) to label different
+groups of features, for example, user features and item features. This
+is the data type used in ``forward`` of ``EmbeddingBagCollection`` and
+``EmbeddingCollection`` as they are used to represent multiple features
+in a table.
+
+A ``KeyedJaggedTensor`` has an implied batch size, which is the length
+of the ``lengths`` tensor divided by the number of keys (features). The
+example below has a batch size of 2. Similar to a ``JaggedTensor``, the
+``offsets`` and ``lengths`` function in the same manner. You can also
+access the ``lengths``, ``offsets``, and ``values`` of a feature by
+accessing the key from the ``KeyedJaggedTensor``.
+
+.. code:: python
+
+ keys = ["user_features", "item_features"]
+ # Lengths of interactions:
+ # - User features: 2 users, with 2 and 3 interactions respectively
+ # - Item features: 2 items, with 1 and 2 interactions respectively
+ lengths = torch.tensor([2, 3, 1, 2])
+ values = torch.Tensor([11, 12, 21, 22, 23, 101, 102, 201])
+ # Create a KeyedJaggedTensor
+ kjt = KeyedJaggedTensor(keys=keys, lengths=lengths, values=values)
+ # Access the features by key
+ print(kjt["user_features"])
+ # Outputs user features
+ print(kjt["item_features"])
+
+*********
+Planner
+*********
+
+The TorchRec planner helps determine the best sharding configuration for
+a model. It evaluates multiple possibilities for sharding embedding
+tables and optimizes for performance. The planner performs the
+following:
+
+- Assesses the memory constraints of the hardware.
+- Estimates compute requirements based on memory fetches, such as
+ embedding lookups.
+- Addresses data-specific factors.
+- Considers other hardware specifics, such as bandwidth, to generate an
+ optimal sharding plan.
+
+To ensure accurate consideration of these factors, the Planner can
+incorporate data about the embedding tables, constraints, hardware
+information, and topology to help in generating an optimal plan.
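+
+As a rough, non-prescriptive sketch of how a plan can be generated
+ahead of time (the table configuration, world size, and compute device
+below are assumptions for illustration):
+
+.. code:: python
+
+ import torch
+ from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
+ from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+ from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
+
+ # A single table on the meta device; the planner only needs the configs.
+ ebc = EmbeddingBagCollection(
+     tables=[
+         EmbeddingBagConfig(
+             name="t_user",
+             embedding_dim=64,
+             num_embeddings=4096,
+             feature_names=["user_features"],
+         )
+     ],
+     device=torch.device("meta"),
+ )
+
+ planner = EmbeddingShardingPlanner(
+     topology=Topology(world_size=2, compute_device="cuda")
+ )
+ plan = planner.plan(ebc, [EmbeddingBagCollectionSharder()])
+ print(plan)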
+
+*****************************
+Sharding of EmbeddingTables
+*****************************
+
+The TorchRec sharder provides multiple sharding strategies for various
+use cases. We outline some of the sharding strategies below, how they
+work, and their benefits and limitations. Generally, we recommend using
+the TorchRec planner to generate a sharding plan for you, as it will
+find the optimal sharding strategy for each embedding table in your
+model.
+
+Each sharding strategy determines how to split the table: whether it
+should be cut up and how, and whether to keep one or several copies of
+some tables. Each piece of a table produced by sharding, whether it is
+a whole embedding table or a part of it, is referred to as a shard.
+
+.. figure:: _static/img/sharding.png
+ :alt: Visualizing the difference of sharding types offered in TorchRec
+ :align: center
+
+ *Figure 1: Visualizing the placement of table shards under different sharding schemes offered in TorchRec*
+
+Here is the list of all sharding types available in TorchRec:
+
+- Table-wise (TW): as the name suggests, the embedding table is kept as
+ a whole piece and placed on one rank.
+
+- Column-wise (CW): the table is split along the ``emb_dim`` dimension,
+ for example, ``emb_dim=256`` is split into 4 shards: ``[64, 64, 64,
+ 64]``.
+
+- Row-wise (RW): the table is split along the ``hash_size`` dimension,
+ usually split evenly among all the ranks.
+
+- Table-wise-row-wise (TWRW): the table is placed on one host and split
+ row-wise among the ranks on that host.
+
+- Grid-shard (GS): a table is CW sharded and each CW shard is placed
+ TWRW on a host.
+
+- Data parallel (DP): each rank keeps a copy of the table.
+
+Once sharded, the modules are converted to sharded versions of
+themselves, known as ``ShardedEmbeddingCollection`` and
+``ShardedEmbeddingBagCollection`` in TorchRec. These modules handle the
+communication of input data, embedding lookups, and gradients.
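+
+If you want to nudge the planner toward a particular sharding type for
+a given table, one option is to pass per-table constraints when
+constructing the planner; the table name and sharding type below are
+purely illustrative:
+
+.. code:: python
+
+ from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
+ from torchrec.distributed.planner.types import ParameterConstraints
+
+ # Pin the (hypothetical) table "t_user" to row-wise sharding and let the
+ # planner decide for every other table.
+ planner = EmbeddingShardingPlanner(
+     topology=Topology(world_size=2, compute_device="cuda"),
+     constraints={"t_user": ParameterConstraints(sharding_types=["row_wise"])},
+ )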
+
+****************************************************
+Distributed Training with TorchRec Sharded Modules
+****************************************************
+
+With many sharding strategies available, how do we determine which one
+to use? There is a cost associated with each sharding scheme, which in
+conjunction with model size and number of GPUs determines which sharding
+strategy is best for a model.
+
+Without sharding, where each GPU keeps a copy of the embedding table
+(DP), the main cost is computation: each GPU looks up the embedding
+vectors in its own memory in the forward pass and updates the gradients
+in the backward pass.
+
+With sharding, there is an added communication cost: each GPU needs to
+ask the other GPUs for embedding vector lookups and communicate the
+computed gradients as well. This is typically referred to as ``all2all``
+communication. In TorchRec, for input data on a given GPU, we determine
+where the embedding shard for each part of the data is located and send
+it to the target GPU. That target GPU then returns the embedding vectors
+back to the original GPU. In the backward pass, the gradients are sent
+back to the target GPU and the shards are updated accordingly with the
+optimizer.
+
+As described above, sharding requires us to communicate the input data
+and embedding lookups. TorchRec handles this in three main stages; we
+will refer to this as the sharded embedding module forward, which is
+used in training and inference of a TorchRec model:
+
+- Feature All to All/Input distribution (``input_dist``)
+
+ - Communicate input data (in the form of a ``KeyedJaggedTensor``) to
+ the appropriate device containing relevant embedding table shard
+
+- Embedding Lookup
+
+ - Lookup embeddings with new input data formed after feature all to
+ all exchange
+
+- Embedding All to All/Output Distribution (``output_dist``)
+
+ - Communicate embedding lookup data back to the appropriate device
+ that asked for it (in accordance with the input data the device
+ received)
+
+- The backward pass does the same operations but in reverse order.
+
+The diagram below demonstrates how it works:
+
+.. figure:: _static/img/torchrec_forward.png
+ :alt: Visualizing the forward pass including the input_dist, lookup, and output_dist of a sharded TorchRec module
+ :align: center
+
+ *Figure 2: Forward pass of a table wise sharded table including the input_dist, lookup, and output_dist of a sharded TorchRec module*
+
+**************************
+DistributedModelParallel
+**************************
+
+All of the above culminates in the main entry point that TorchRec uses
+to shard the model and integrate the plan. At a high level,
+``DistributedModelParallel`` does the following (a minimal usage sketch
+follows the list):
+
+- Initializes the environment by setting up process groups and
+ assigning device type.
+
+- Uses the default sharders if none are provided; the defaults include
+ ``EmbeddingBagCollectionSharder``.
+
+- Takes in the provided sharding plan; if none is provided, it
+ generates one.
+
+- Creates a sharded version of modules and replaces the original
+ modules with them, for example, converts ``EmbeddingCollection`` to
+ ``ShardedEmbeddingCollection``.
+
+- By default, wraps the ``DistributedModelParallel`` with
+ ``DistributedDataParallel`` to make the module both model and data
+ parallel.
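+
+The sketch below assumes the process group has already been initialized
+(for example via ``torchrun``) and that ``model`` is defined elsewhere
+and contains TorchRec embedding modules:
+
+.. code:: python
+
+ import torch
+ import torch.distributed as dist
+ from torchrec.distributed.model_parallel import DistributedModelParallel
+
+ # Assumes dist.init_process_group(...) has already been called and that
+ # `model` contains EmbeddingBagCollection(s) / EmbeddingCollection(s).
+ device = torch.device(f"cuda:{dist.get_rank()}")
+ model = DistributedModelParallel(module=model, device=device)
+ # With no plan or sharders given, DMP falls back to the default sharders
+ # and generates a sharding plan as described above.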
+
+***********
+Optimizer
+***********
+
+TorchRec modules provide a seamless API to fuse the backward pass and
+optimizer step in training, which significantly improves performance
+and decreases the memory used, while also providing granularity in
+assigning distinct optimizers to distinct model parameters.
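+
+A sketch of how this fits together in a training script, mirroring the
+pattern used in the TorchRec tests (``model`` is assumed to be a
+``DistributedModelParallel``-wrapped module defined elsewhere):
+
+.. code:: python
+
+ import torch
+ from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
+ from torchrec.optim.optimizers import in_backward_optimizer_filter
+
+ # Sparse (embedding) parameters are updated by the fused optimizer during
+ # the backward pass, so only the remaining dense parameters need a
+ # conventional optimizer step.
+ dense_optimizer = KeyedOptimizerWrapper(
+     dict(in_backward_optimizer_filter(model.named_parameters())),
+     lambda params: torch.optim.Adam(params, lr=1e-1),
+ )
+ optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])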
+
+.. figure:: _static/img/fused_backward_optimizer.png
+ :alt: Visualizing fusing of optimizer in backward to update sparse embedding table
+ :align: center
+
+ *Figure 3: Fusing embedding backward with sparse optimizer*
+
+***********
+Inference
+***********
+
+Inference environments are different from training: they are very
+sensitive to performance and the size of the model. There are two key
+differences that TorchRec inference optimizes for:
+
+- **Quantization:** inference models are quantized for lower latency
+ and reduced model size. This optimization lets us use as few devices
+ as possible for inference to minimize latency.
+
+- **C++ environment:** to minimize latency even further, the model is
+ run in a C++ environment.
+
+TorchRec provides the following to make a TorchRec model
+inference-ready:
+
+- APIs for quantizing the model, including automatic optimizations
+ with FBGEMM TBE
+- Sharding embeddings for distributed inference
+- Compiling the model to TorchScript (compatible with C++)
+
+*********
+See Also
+*********
+
+- `TorchRec Interactive Notebook using the concepts
+ `_
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 4011be287..0f64cefe7 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -30,19 +30,34 @@
# -- Project information -----------------------------------------------------
project = "TorchRec"
-copyright = "2022, Meta"
+copyright = "2024, Meta"
author = "Meta"
+try:
+ # pyre-ignore
+ version = "1.0.0" # TODO: Hardcode stable version for now
+except Exception:
+ # when run internally, we don't have a version yet
+ version = "0.0.0"
# The full version, including alpha/beta/rc tags
-release = "0.0.1"
-
+# First 3 as format is 0.x.x.*
+release = ".".join(version.split(".")[:3])
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
-extensions = ["sphinx.ext.napoleon", "sphinx.ext.autodoc"]
+extensions = [
+ "sphinx.ext.napoleon",
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.doctest",
+ "sphinx.ext.intersphinx",
+ "sphinx_copybutton",
+ "sphinx.ext.mathjax",
+ "sphinx_design",
+]
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
@@ -62,7 +77,21 @@
html_theme = "pytorch_sphinx_theme"
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
+# Theme options are theme-specific and customize the look and feel of a theme
+# further. For a list of options available for each theme, see the
+# documentation.
+#
+html_theme_options = {
+ "pytorch_project": "torchrec",
+ "display_version": True,
+ "logo_only": True,
+ "collapse_navigation": False,
+ "includehidden": True,
+}
+
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
+
+html_css_files = ["css/custom.css"]
diff --git a/docs/source/datatypes-api-reference.rst b/docs/source/datatypes-api-reference.rst
new file mode 100644
index 000000000..8454a15f2
--- /dev/null
+++ b/docs/source/datatypes-api-reference.rst
@@ -0,0 +1,22 @@
+Data Types
+-------------------
+
+
+TorchRec contains data types for representing embedding inputs, otherwise known as sparse features.
+Sparse features are typically indices that are meant to be fed into embedding tables. For a given
+batch, the number of embedding lookup indices is variable. Therefore, there is a need for a **jagged**
+dimension to represent the variable number of embedding lookup indices in a batch.
+
+This section covers the classes for the 3 TorchRec data types for representing sparse features:
+**JaggedTensor**, **KeyedJaggedTensor**, and **KeyedTensor**.
+
+.. automodule:: torchrec.sparse.jagged_tensor
+
+.. autoclass:: JaggedTensor
+ :members:
+
+.. autoclass:: KeyedJaggedTensor
+ :members:
+
+.. autoclass:: KeyedTensor
+ :members:
diff --git a/docs/source/high-level-arch.rst b/docs/source/high-level-arch.rst
new file mode 100644
index 000000000..2c2263b6c
--- /dev/null
+++ b/docs/source/high-level-arch.rst
@@ -0,0 +1,129 @@
+.. meta::
+ :description: TorchRec High Level Architecture
+ :keywords: recommendation systems, sharding, distributed training, torchrec, architecture
+
+##################################
+ TorchRec High Level Architecture
+##################################
+
+In this section, you will learn about the high-level architecture of
+TorchRec, designed to optimize large-scale recommendation systems using
+PyTorch. You will learn how TorchRec employs model parallelism to
+distribute complex models across multiple GPUs, enhancing memory
+management and GPU utilization, as well as get introduced to TorchRec's
+base components and sharding strategies.
+
+In effect, TorchRec provides parallelism primitives allowing hybrid
+data parallelism/model parallelism, embedding table sharding, a planner
+to generate sharding plans, pipelined training, and more.
+
+****************************************************
+ TorchRec's Parallelism Strategy: Model Parallelism
+****************************************************
+
+As modern deep learning models have scaled, distributed deep learning
+has become necessary to train models in a reasonable amount of time. In
+this paradigm, two main approaches have been developed: data parallelism
+and model parallelism. TorchRec focuses on the latter for the sharding
+of embedding tables.
+
+.. figure:: _static/img/model_parallel.png
+ :alt: Visualizing the difference of sharding a model in model parallel or data parallel approach
+ :align: center
+
+ *Figure 1. Comparison between model parallelism and data parallelism approach*
+
+As you can see in the diagram above, model parallelism and data
+parallelism are two approaches to distributing workloads across
+multiple GPUs:
+
+- **Model Parallelism**
+
+ - Divide the model into segments and distribute them across GPUs
+ - Each segment processes data independently
+ - Suitable for large models that don't fit on a single GPU
+
+- **Data Parallelism**
+
+ - Distribute a copy of the entire model to each GPU
+ - Each GPU processes a subset of the data and contributes to the
+ overall computation
+ - Effective for models that fit on a single GPU but need to handle
+ large datasets
+
+- **Benefits of Model Parallelism**
+
+ - Optimizes memory usage and computational efficiency for large
+ models
+ - Particularly beneficial for recommendation systems with large
+ embedding tables
+ - Enables parallel computation of embeddings in DLRM-type
+ architectures
+
+******************
+ Embedding Tables
+******************
+
+For TorchRec to figure out what to recommend, we need to be able to
+represent entities and their relationships; this is what embeddings are
+used for. Embeddings are vectors of real numbers in a high-dimensional
+space that represent meaning in complex data like words, images, or
+users. An embedding table is an aggregation of multiple embeddings into
+one matrix. Most commonly, embedding tables are represented as a 2D
+matrix with dimensions (B, N), where:
+
+- *B* is the number of embeddings stored by the table
+- *N* is the number of dimensions per embedding.
+
+Each of the *B* rows can also be referred to as an ID (representing
+information such as a movie title, user, or ad). When accessing an ID, we
+are returned the corresponding embedding vector, which has the size of the
+embedding dimension *N*.
+
+There is also the choice of pooling embeddings. Often, we look up
+multiple rows for a given feature, which raises the question of what to
+do with the multiple embedding vectors that are returned. Pooling is a
+common technique where we combine the embedding vectors, usually through
+a sum or mean of the rows, to produce one embedding vector. This is the
+main difference between PyTorch's ``nn.Embedding`` and
+``nn.EmbeddingBag``.
+
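+As a small, hypothetical illustration of the difference (plain PyTorch, not
+TorchRec-specific), ``nn.Embedding`` returns one vector per looked-up ID while
+``nn.EmbeddingBag`` pools them into a single vector per bag:
+
+.. code-block:: python
+
+    import torch
+    import torch.nn as nn
+
+    emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
+    emb_bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")
+
+    ids = torch.tensor([[1, 2, 3]])  # one bag containing three ids
+
+    print(emb(ids).shape)      # torch.Size([1, 3, 4]) - one vector per id
+    print(emb_bag(ids).shape)  # torch.Size([1, 4]) - three vectors summed into one
+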
+PyTorch represents embeddings through ``nn.Embedding`` and
+``nn.EmbeddingBag``. Building on these modules, TorchRec introduces
+``EmbeddingCollection`` and ``EmbeddingBagCollection``, which are
+collections of the corresponding PyTorch modules. This extension enables
+TorchRec to batch tables and perform lookups on multiple embeddings in a
+single kernel call, improving efficiency.
+
+Here is the end-to-end flow diagram that describes how embeddings are
+used in the training process for recommendation models:
+
+.. figure:: _static/img/full_training_loop.png
+ :alt: Demonstrating the full training loop from embedding lookup to optimizer update in backward
+ :align: center
+
+ *Figure 2. TorchRec End-to-end Embedding Flow*
+
+In the diagram above, we show the general TorchRec end-to-end embedding
+lookup process:
+
+- In the forward pass we do the embedding lookup and pooling
+- In the backward pass we compute the gradients of the output lookups
+ and pass them into the optimizer to update the embedding tables
+
+**Note that the embedding gradients are grayed out since we do not
+fully materialize them in memory and instead fuse them with the
+optimizer update. This results in a significant memory reduction, which
+we detail later in the optimizer concepts section.**
+
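+As a rough, hypothetical sketch of this fusion (the utility shown here,
+``_apply_optimizer_in_backward``, is a private PyTorch API and may change
+between versions), attaching an optimizer directly to embedding parameters
+looks like this:
+
+.. code-block:: python
+
+    import torch
+    import torchrec
+    from torch.distributed.optim import (
+        _apply_optimizer_in_backward as apply_optimizer_in_backward,
+    )
+
+    ebc = torchrec.EmbeddingBagCollection(
+        device=torch.device("cpu"),
+        tables=[
+            torchrec.EmbeddingBagConfig(
+                name="product_table",
+                embedding_dim=16,
+                num_embeddings=4096,
+                feature_names=["product"],
+            )
+        ],
+    )
+
+    # Attach the update rule directly to the embedding parameters: the optimizer
+    # step for these parameters runs during backward, as gradients become
+    # available, rather than as a separate step after backward completes
+    apply_optimizer_in_backward(torch.optim.SGD, ebc.parameters(), {"lr": 0.02})
+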
+We recommend going through the TorchRec Concepts page to get an
+understanding of the fundamentals of how everything ties together
+end-to-end. It contains lots of useful information to get the most out
+of TorchRec.
+
+**********
+ See also
+**********
+
+- `What is Distributed Data Parallel (DDP) Tutorial
+ <https://pytorch.org/tutorials/beginner/ddp_series_theory.html>`_
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 09f3a6f2c..c6fa49282 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -3,51 +3,109 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
+.. meta::
+ :description: TorchRec documentation homepage
+ :keywords: recommendation systems, sharding, distributed training
+
Welcome to the TorchRec documentation!
======================================
-TorchRec is a PyTorch domain library built to provide common
-sparsity & parallelism primitives needed for large-scale recommender
-systems (RecSys). It allows authors to train models with large
-embedding tables sharded across many GPUs.
+TorchRec is a specialized library within the PyTorch ecosystem,
+tailored for building, scaling, and deploying large-scale
+**recommendation systems**, a niche not directly addressed by standard
+PyTorch. TorchRec offers advanced features such as complex sharding
+techniques for massive embedding tables, and enhanced distributed
+training capabilities.
+
+Getting Started
+---------------
+
+Topics in this section will help you get started with TorchRec.
+
+.. grid:: 3
+
+ .. grid-item-card:: :octicon:`file-code;1em`
+ TorchRec Overview
+ :img-top: _static/img/card-background.svg
+ :link: overview.html
+ :link-type: url
+
+ A short intro to TorchRec and why you need it.
+
+ .. grid-item-card:: :octicon:`file-code;1em`
+ Set up TorchRec
+ :img-top: _static/img/card-background.svg
+ :link: setup-torchrec.html
+ :link-type: url
+
+ Learn how to install and start using TorchRec
+ in your environment.
+
+ .. grid-item-card:: :octicon:`file-code;1em`
+ Getting Started with TorchRec Tutorial
+ :img-top: _static/img/card-background.svg
+ :link: https://colab.research.google.com/github/pytorch/torchrec/blob/main/Torchrec_Introduction.ipynb
+ :link-type: url
+
+ Follow our interactive step-by-step tutorial
+ to learn how to use TorchRec in a real-life
+ example.
+
+
-For installation instructions, visit
+How to Contribute
+-----------------
-https://github.com/pytorch/torchrec#readme
+We welcome contributions and feedback from the PyTorch community!
+If you are interested in helping improve the TorchRec project, here is
+how you can contribute:
-Tutorial
---------
-In this tutorial, we introduce the primary torchRec
-API called DistributedModelParallel, or DMP.
-Like pytorch’s DistributedDataParallel,
-DMP wraps a model to enable distributed training.
+1. **Visit Our** `GitHub Repository <https://github.com/pytorch/torchrec>`__:
+ There you can find the source code, issues, and ongoing projects.
-* `Tutorial Source `_
-* Open in `Google Colab `_
+2. **Submit Feedback or Issues**: If you encounter any bugs or have
+ suggestions for improvements, please submit an issue through the
+ `GitHub issue tracker <https://github.com/pytorch/torchrec/issues>`__.
-TorchRec API
-------------
+3. **Propose changes**: Fork the repository and submit pull requests.
+ Whether it's fixing a bug, adding new features, or improving
+ documentation, your contributions are always welcome! Please make sure to
+ review our `CONTRIBUTING.md <https://github.com/pytorch/torchrec/blob/main/CONTRIBUTING.md>`__.
+
+|
+|
+
+.. container:: center-content
+
+ .. button-link:: https://github.com/pytorch/torchrec
+ :color: info
+
+ :octicon:`mark-github` Go to TorchRec Repo
+
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Introduction
+ :hidden:
+
+ overview.rst
+ high-level-arch.rst
+ concepts.rst
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Getting Started
+ :hidden:
+
+ setup-torchrec.rst
.. toctree::
:maxdepth: 2
- :caption: Contents:
-
- torchrec.datasets.rst
- torchrec.datasets.scripts.rst
- torchrec.distributed.rst
- torchrec.distributed.planner.rst
- torchrec.distributed.sharding.rst
- torchrec.fx.rst
- torchrec.inference.rst
- torchrec.models.rst
- torchrec.modules.rst
- torchrec.optim.rst
- torchrec.quant.rst
- torchrec.sparse.rst
-
-Indices and tables
-==================
-
-* :ref:`genindex`
-* :ref:`modindex`
-* :ref:`search`
+ :caption: API References
+ :hidden:
+
+ datatypes-api-reference.rst
+ modules-api-reference.rst
+ planner-api-reference.rst
+ model-parallel-api-reference.rst
+ inference-api-reference.rst
diff --git a/docs/source/inference-api-reference.rst b/docs/source/inference-api-reference.rst
new file mode 100644
index 000000000..0e10067c7
--- /dev/null
+++ b/docs/source/inference-api-reference.rst
@@ -0,0 +1,19 @@
+Inference
+----------------------------------
+
+TorchRec provides easy-to-use APIs for transforming an authored TorchRec model
+into an optimized inference model for distributed inference, via eager module swaps.
+
+This transforms TorchRec modules like ``EmbeddingBagCollection`` in the model into
+quantized, sharded versions that can be compiled using torch.fx and TorchScript
+for inference in a C++ environment.
+
+The intended use is calling ``quantize_inference_model`` on the model followed by
+``shard_quant_model``.
+
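+A minimal sketch of that flow is shown below (the toy model and table sizes are
+illustrative, and it assumes a CUDA device is available since the model is sharded
+for GPU inference by default; see the function references below for the full signatures):
+
+.. code-block:: python
+
+    import torch
+    import torchrec
+    from torchrec.inference.modules import quantize_inference_model, shard_quant_model
+
+    class Model(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.ebc = torchrec.EmbeddingBagCollection(
+                device=torch.device("cpu"),
+                tables=[
+                    torchrec.EmbeddingBagConfig(
+                        name="product_table",
+                        embedding_dim=16,
+                        num_embeddings=4096,
+                        feature_names=["product"],
+                    )
+                ],
+            )
+
+        def forward(self, kjt):
+            return self.ebc(kjt)
+
+    model = Model()
+
+    # Swap TorchRec modules for their quantized counterparts (eager module swap)
+    quant_model = quantize_inference_model(model)
+
+    # Shard the quantized model for distributed inference
+    sharded_model, _ = shard_quant_model(quant_model)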
+
+.. automodule:: torchrec.inference.modules
+
+.. autofunction:: quantize_inference_model
+.. autofunction:: shard_quant_model
diff --git a/docs/source/model-parallel-api-reference.rst b/docs/source/model-parallel-api-reference.rst
new file mode 100644
index 000000000..196fd27fb
--- /dev/null
+++ b/docs/source/model-parallel-api-reference.rst
@@ -0,0 +1,10 @@
+Model Parallel
+----------------------------------
+
+``DistributedModelParallel`` is the main API for distributed training with TorchRec optimizations.
+
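+A minimal sketch of wrapping a model (here just an ``EmbeddingBagCollection``)
+with ``DistributedModelParallel`` is shown below; it assumes the script is
+launched with ``torchrun`` so that the process group environment variables are set:
+
+.. code-block:: python
+
+    import torch
+    import torch.distributed as dist
+    import torchrec
+    from torchrec.distributed.model_parallel import DistributedModelParallel
+
+    # torchrun sets RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, MASTER_PORT
+    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
+
+    ebc = torchrec.EmbeddingBagCollection(
+        device=torch.device("meta"),  # materialized by DMP after sharding
+        tables=[
+            torchrec.EmbeddingBagConfig(
+                name="product_table",
+                embedding_dim=16,
+                num_embeddings=4096,
+                feature_names=["product"],
+            )
+        ],
+    )
+
+    model = DistributedModelParallel(
+        module=ebc,
+        device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
+    )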
+
+.. automodule:: torchrec.distributed.model_parallel
+
+.. autoclass:: DistributedModelParallel
+ :members:
diff --git a/docs/source/modules-api-reference.rst b/docs/source/modules-api-reference.rst
new file mode 100644
index 000000000..98cff7ad8
--- /dev/null
+++ b/docs/source/modules-api-reference.rst
@@ -0,0 +1,30 @@
+Modules
+----------------------------------
+
+Standard TorchRec modules represent collections of embedding tables:
+
+* ``EmbeddingBagCollection`` is a collection of ``torch.nn.EmbeddingBag``
+* ``EmbeddingCollection`` is a collection of ``torch.nn.Embedding``
+
+These modules are constructed through standardized config classes:
+
+* ``EmbeddingBagConfig`` for ``EmbeddingBagCollection``
+* ``EmbeddingConfig`` for ``EmbeddingCollection``
+
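+For example, a minimal sketch of constructing an ``EmbeddingBagCollection`` from
+two configs (the table names and sizes here are illustrative):
+
+.. code-block:: python
+
+    import torch
+    from torchrec.modules.embedding_configs import EmbeddingBagConfig
+    from torchrec.modules.embedding_modules import EmbeddingBagCollection
+
+    ebc = EmbeddingBagCollection(
+        device=torch.device("cpu"),
+        tables=[
+            EmbeddingBagConfig(
+                name="user_table",
+                embedding_dim=32,
+                num_embeddings=10000,
+                feature_names=["user_id"],
+            ),
+            EmbeddingBagConfig(
+                name="item_table",
+                embedding_dim=32,
+                num_embeddings=50000,
+                feature_names=["item_id"],
+            ),
+        ],
+    )
+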
+.. automodule:: torchrec.modules.embedding_configs
+
+.. autoclass:: EmbeddingBagConfig
+ :show-inheritance:
+
+.. autoclass:: EmbeddingConfig
+ :show-inheritance:
+
+.. autoclass:: BaseEmbeddingConfig
+
+.. automodule:: torchrec.modules.embedding_modules
+
+.. autoclass:: EmbeddingBagCollection
+ :members:
+
+.. autoclass:: EmbeddingCollection
+ :members:
diff --git a/docs/source/overview.rst b/docs/source/overview.rst
new file mode 100644
index 000000000..0098c14df
--- /dev/null
+++ b/docs/source/overview.rst
@@ -0,0 +1,23 @@
+.. _overview_label:
+
+==================
+TorchRec Overview
+==================
+
+TorchRec is the PyTorch recommendation system library, designed to provide common primitives
+for creating state-of-the-art personalization models and a path to production. TorchRec is
+widely adopted in many Meta production recommendation system models for training and inference workflows.
+
+Why TorchRec?
+------------------
+
+TorchRec is designed to address the unique challenges of building, scaling, and deploying massive
+recommendation system models, which is not a focus of regular PyTorch. More specifically,
+TorchRec provides the following primitives for general recommendation systems:
+
+- **Specialized Components**: TorchRec provides simple, specialized modules that are common in authoring recommendation systems, with a focus on embedding tables
+- **Advanced Sharding Techniques**: TorchRec provides flexible and customizable methods for sharding massive embedding tables: Row-Wise, Column-Wise, Table-Wise, and so on. TorchRec can automatically determine the best plan for a device topology for efficient training and memory balance
+- **Distributed Training**: While PyTorch supports basic distributed training, TorchRec extends these capabilities with more sophisticated model parallelism techniques specifically designed for the massive scale of recommendation systems
+- **Incredibly Optimized**: TorchRec training and inference components are heavily optimized on top of FBGEMM. After all, TorchRec powers some of the largest recommendation system models at Meta
+- **Frictionless Path to Deployment**: TorchRec provides simple APIs for transforming a trained model for inference and loading it into a C++ environment for highly optimized inference
+- **Integration with PyTorch Ecosystem**: TorchRec is built on top of PyTorch, meaning it integrates seamlessly with existing PyTorch code, tools, and workflows. This allows developers to leverage their existing knowledge and codebase while utilizing advanced features for recommendation systems. By being a part of the PyTorch ecosystem, TorchRec benefits from the robust community support, continuous updates, and improvements that come with PyTorch.
diff --git a/docs/source/planner-api-reference.rst b/docs/source/planner-api-reference.rst
new file mode 100644
index 000000000..e6cc7f4d2
--- /dev/null
+++ b/docs/source/planner-api-reference.rst
@@ -0,0 +1,50 @@
+Planner
+----------------------------------
+
+The TorchRec Planner is responsible for determining the most performant, balanced
+sharding plan for distributed training and inference.
+
+The main API for generating a sharding plan is ``EmbeddingShardingPlanner.plan``
+
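+For example, a minimal sketch of planning for a hypothetical two-GPU topology
+(the table config and topology below are illustrative):
+
+.. code-block:: python
+
+    import torch
+    import torchrec
+    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
+
+    ebc = torchrec.EmbeddingBagCollection(
+        device=torch.device("meta"),
+        tables=[
+            torchrec.EmbeddingBagConfig(
+                name="product_table",
+                embedding_dim=64,
+                num_embeddings=100000,
+                feature_names=["product"],
+            )
+        ],
+    )
+
+    planner = EmbeddingShardingPlanner(
+        topology=Topology(world_size=2, compute_device="cuda")
+    )
+    plan = planner.plan(module=ebc, sharders=[EmbeddingBagCollectionSharder()])
+    print(plan)
+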
+.. automodule:: torchrec.distributed.types
+
+.. autoclass:: ShardingPlan
+ :members:
+
+.. automodule:: torchrec.distributed.planner.planners
+
+.. autoclass:: EmbeddingShardingPlanner
+ :members:
+
+.. automodule:: torchrec.distributed.planner.enumerators
+
+.. autoclass:: EmbeddingEnumerator
+ :members:
+
+.. automodule:: torchrec.distributed.planner.partitioners
+
+.. autoclass:: GreedyPerfPartitioner
+ :members:
+
+
+.. automodule:: torchrec.distributed.planner.storage_reservations
+
+.. autoclass:: HeuristicalStorageReservation
+ :members:
+
+.. automodule:: torchrec.distributed.planner.proposers
+
+.. autoclass:: GreedyProposer
+ :members:
+
+
+.. automodule:: torchrec.distributed.planner.shard_estimators
+
+.. autoclass:: EmbeddingPerfEstimator
+ :members:
+
+
+.. automodule:: torchrec.distributed.planner.shard_estimators
+
+.. autoclass:: EmbeddingStorageEstimator
+ :members:
diff --git a/docs/source/setup-torchrec.rst b/docs/source/setup-torchrec.rst
new file mode 100644
index 000000000..7a2cbf969
--- /dev/null
+++ b/docs/source/setup-torchrec.rst
@@ -0,0 +1,139 @@
+
+===================
+Setting up TorchRec
+===================
+
+In this section, we will:
+
+* Understand requirements for using TorchRec
+* Set up an environment to integrate TorchRec
+* Run basic TorchRec code
+
+
+System Requirements
+-------------------
+
+TorchRec is routinely tested only on AWS Linux and should work in similar environments.
+The table below shows the compatibility matrix that is currently tested:
+
+.. list-table::
+ :widths: 25 75
+ :header-rows: 0
+
+ * - Python Version
+ - 3.9, 3.10, 3.11, 3.12
+ * - Compute Platform
+ - CPU, CUDA 11.8, CUDA 12.1, CUDA 12.4
+
+Aside from those requirements, TorchRec's core dependencies are PyTorch and FBGEMM.
+If your system is compatible with both of those libraries, then it should be able to run TorchRec.
+
+* `PyTorch requirements `_
+* `FBGEMM requirements `_
+
+
+Version Compatibility
+---------------------
+
+TorchRec and FBGEMM have matching version numbers that are tested together upon release:
+
+* TorchRec 1.0 is compatible with FBGEMM 1.0
+* TorchRec 0.8 is compatible with FBGEMM 0.8
+* TorchRec 0.8 may not be compatible with FBGEMM 0.7
+
+Furthermore, TorchRec and FBGEMM are released only when a new PyTorch release happens.
+Therefore, specific versions of TorchRec and FBGEMM should correspond to a specific PyTorch version:
+
+* TorchRec 1.0 is compatible with PyTorch 2.5
+* TorchRec 0.8 is compatible with PyTorch 2.4
+* TorchRec 0.8 may not be compatible with PyTorch 2.3
+
+Installation
+------------
+Below we show installations for CUDA 12.1 as an example. For CPU, CUDA 11.8, or CUDA 12.4, swap ``cu121`` for ``cpu``, ``cu118``, or ``cu124`` respectively.
+
+.. tab-set::
+
+ .. tab-item:: **Stable via pytorch.org**
+
+ .. code-block:: bash
+
+ pip install torch --index-url https://download.pytorch.org/whl/cu121
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/cu121
+ pip install torchmetrics==1.0.3
+ pip install torchrec --index-url https://download.pytorch.org/whl/cu121
+
+ .. tab-item:: **Stable via PyPI (Only for CUDA 12.4)**
+
+ .. code-block:: bash
+
+ pip install torch
+ pip install fbgemm-gpu
+ pip install torchrec
+
+ .. tab-item:: **Nightly**
+
+ .. code-block:: bash
+
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/cu121
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu121
+ pip install torchmetrics==1.0.3
+ pip install torchrec --index-url https://download.pytorch.org/whl/nightly/cu121
+
+ .. tab-item:: **Building From Source**
+
+ You also have the ability to build TorchRec from source to develop with the latest
+ changes in TorchRec. To build from source, check out this `reference `_.
+
+
+Run a Simple TorchRec Example
+------------------------------
+Now that we have TorchRec properly set up, let's run some TorchRec code!
+Below, we'll run a simple forward pass using the ``KeyedJaggedTensor`` data type and the ``EmbeddingBagCollection`` module:
+
+.. code-block:: python
+
+    import torch
+
+    import torchrec
+    from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
+
+    ebc = torchrec.EmbeddingBagCollection(
+        device="cpu",
+        tables=[
+            torchrec.EmbeddingBagConfig(
+                name="product_table",
+                embedding_dim=16,
+                num_embeddings=4096,
+                feature_names=["product"],
+                pooling=torchrec.PoolingType.SUM,
+            ),
+            torchrec.EmbeddingBagConfig(
+                name="user_table",
+                embedding_dim=16,
+                num_embeddings=4096,
+                feature_names=["user"],
+                pooling=torchrec.PoolingType.SUM,
+            )
+        ]
+    )
+
+    product_jt = JaggedTensor(
+        values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
+    )
+    user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))
+
+    # Q1: How many batches are there, and which values are in the first batch for product_jt and user_jt?
+    kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})
+
+    print("Call EmbeddingBagCollection Forward: ", ebc(kjt))
+
+Save the above code to a file named ``torchrec_example.py``. Then, you should be able to
+execute it from your terminal with:
+
+.. code-block:: bash
+
+ python torchrec_example.py
+
+You should see the output ``KeyedTensor`` with the resulting embeddings.
+Congrats! You have correctly installed and run your first TorchRec program!
diff --git a/docs/source/torchrec.datasets.rst b/docs/source/torchrec.datasets.rst
deleted file mode 100644
index a16be3c4c..000000000
--- a/docs/source/torchrec.datasets.rst
+++ /dev/null
@@ -1,36 +0,0 @@
-torchrec.datasets
-=================
-
-.. automodule:: torchrec.datasets
-
-torchrec.datasets.criteo
-------------------------
-
-.. automodule:: torchrec.datasets.criteo
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.datasets.movielens
----------------------------
-
-.. automodule:: torchrec.datasets.movielens
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.datasets.random
-------------------------
-
-.. automodule:: torchrec.datasets.random
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.datasets.utils
------------------------
-
-.. automodule:: torchrec.datasets.utils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.datasets.scripts.rst b/docs/source/torchrec.datasets.scripts.rst
deleted file mode 100644
index 947250d67..000000000
--- a/docs/source/torchrec.datasets.scripts.rst
+++ /dev/null
@@ -1,22 +0,0 @@
-torchrec.datasets.scripts
-=========================
-
-.. automodule:: torchrec.datasets.scripts
-
-
-
-torchrec.datasets.scripts.contiguous\_preproc\_criteo
------------------------------------------------------
-
-.. automodule:: torchrec.datasets.scripts.contiguous_preproc_criteo
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.datasets.scripts.npy\_preproc\_criteo
-----------------------------------------------
-
-.. automodule:: torchrec.datasets.scripts.npy_preproc_criteo
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.distributed.planner.rst b/docs/source/torchrec.distributed.planner.rst
deleted file mode 100644
index 82200125e..000000000
--- a/docs/source/torchrec.distributed.planner.rst
+++ /dev/null
@@ -1,93 +0,0 @@
-torchrec.distributed.planner
-============================
-
-.. automodule:: torchrec.distributed.planner
-
-
-torchrec.distributed.planner.constants
---------------------------------------
-
-.. automodule:: torchrec.distributed.planner.constants
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.enumerators
-----------------------------------------
-
-.. automodule:: torchrec.distributed.planner.enumerators
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.partitioners
------------------------------------------
-
-.. automodule:: torchrec.distributed.planner.partitioners
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.perf\_models
------------------------------------------
-
-.. automodule:: torchrec.distributed.planner.perf_models
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.planners
--------------------------------------
-
-.. automodule:: torchrec.distributed.planner.planners
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.proposers
---------------------------------------
-
-.. automodule:: torchrec.distributed.planner.proposers
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.shard\_estimators
-----------------------------------------------
-
-.. automodule:: torchrec.distributed.planner.shard_estimators
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.stats
-----------------------------------
-
-.. automodule:: torchrec.distributed.planner.stats
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.storage\_reservations
---------------------------------------------------
-
-.. automodule:: torchrec.distributed.planner.storage_reservations
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.types
-----------------------------------
-
-.. automodule:: torchrec.distributed.planner.types
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.utils
-----------------------------------
-
-.. automodule:: torchrec.distributed.planner.utils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.distributed.rst b/docs/source/torchrec.distributed.rst
deleted file mode 100644
index 9fdcb4282..000000000
--- a/docs/source/torchrec.distributed.rst
+++ /dev/null
@@ -1,126 +0,0 @@
-torchrec.distributed
-====================
-
-.. automodule:: torchrec.distributed
-
-
-torchrec.distributed.collective\_utils
---------------------------------------
-
-.. automodule:: torchrec.distributed.collective_utils
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.comm
--------------------------
-
-.. automodule:: torchrec.distributed.comm
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.comm\_ops
-------------------------------
-
-.. automodule:: torchrec.distributed.comm_ops
- :members:
- :undoc-members:
- :show-inheritance:
-
-
-torchrec.distributed.dist\_data
--------------------------------
-
-.. automodule:: torchrec.distributed.dist_data
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.embedding
-------------------------------
-
-.. automodule:: torchrec.distributed.embedding
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.embedding\_lookup
---------------------------------------
-
-.. automodule:: torchrec.distributed.embedding_lookup
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.embedding\_sharding
-----------------------------------------
-
-.. automodule:: torchrec.distributed.embedding_sharding
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.embedding\_types
--------------------------------------
-
-.. automodule:: torchrec.distributed.embedding_types
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.embeddingbag
----------------------------------
-
-.. automodule:: torchrec.distributed.embeddingbag
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.grouped\_position\_weighted
-------------------------------------------------
-
-.. automodule:: torchrec.distributed.grouped_position_weighted
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.model\_parallel
-------------------------------------
-
-.. automodule:: torchrec.distributed.model_parallel
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.quant\_embeddingbag
-----------------------------------------
-
-.. automodule:: torchrec.distributed.quant_embeddingbag
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.train\_pipeline
-------------------------------------
-
-.. automodule:: torchrec.distributed.train_pipeline
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.types
---------------------------
-
-.. automodule:: torchrec.distributed.types
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.utils
---------------------------
-
-.. automodule:: torchrec.distributed.utils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.distributed.sharding.rst b/docs/source/torchrec.distributed.sharding.rst
deleted file mode 100644
index d0e7162ae..000000000
--- a/docs/source/torchrec.distributed.sharding.rst
+++ /dev/null
@@ -1,63 +0,0 @@
-torchrec.distributed.sharding
-=============================
-
-.. automodule:: torchrec.distributed.sharding
-
-
-torchrec.distributed.sharding.cw\_sharding
-------------------------------------------
-
-.. automodule:: torchrec.distributed.sharding.cw_sharding
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.dist\_data
--------------------------------
-
-.. automodule:: torchrec.distributed.dist_data
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.sharding.dp\_sharding
-------------------------------------------
-
-.. automodule:: torchrec.distributed.sharding.dp_sharding
- :members:
- :undoc-members:
- :show-inheritance:
-
-
-torchrec.distributed.sharding.rw\_sharding
-------------------------------------------
-
-.. automodule:: torchrec.distributed.sharding.rw_sharding
- :members:
- :undoc-members:
- :show-inheritance:
-
-
-torchrec.distributed.sharding.tw\_sharding
-------------------------------------------
-
-.. automodule:: torchrec.distributed.sharding.tw_sharding
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.sharding.twcw\_sharding
---------------------------------------------
-
-.. automodule:: torchrec.distributed.sharding.twcw_sharding
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.sharding.twrw\_sharding
---------------------------------------------
-
-.. automodule:: torchrec.distributed.sharding.twrw_sharding
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.fx.rst b/docs/source/torchrec.fx.rst
deleted file mode 100644
index b06656fbc..000000000
--- a/docs/source/torchrec.fx.rst
+++ /dev/null
@@ -1,20 +0,0 @@
-torchrec.fx
-===========
-
-.. automodule:: torchrec.fx
-
-torchrec.fx.tracer
-------------------
-
-.. automodule:: torchrec.fx.tracer
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: torchrec.fx
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.inference.rst b/docs/source/torchrec.inference.rst
deleted file mode 100644
index 846a520ec..000000000
--- a/docs/source/torchrec.inference.rst
+++ /dev/null
@@ -1,28 +0,0 @@
-torchrec.inference
-===========
-
-.. automodule:: torchrec.inference
-
-torchrec.inference.model_packager
-------------------
-
-.. automodule:: torchrec.inference.model_packager
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.inference.modules
-------------------
-
-.. automodule:: torchrec.inference.modules
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: torchrec.inference
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.models.rst b/docs/source/torchrec.models.rst
deleted file mode 100644
index 4d3ad3040..000000000
--- a/docs/source/torchrec.models.rst
+++ /dev/null
@@ -1,28 +0,0 @@
-torchrec.models
-===============
-
-.. automodule:: torchrec.models
-
-torchrec.models.deepfm
-----------------------
-
-.. automodule:: torchrec.models.deepfm
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.models.dlrm
---------------------
-
-.. automodule:: torchrec.models.dlrm
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: torchrec.models
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.modules.rst b/docs/source/torchrec.modules.rst
deleted file mode 100644
index 95dcf35db..000000000
--- a/docs/source/torchrec.modules.rst
+++ /dev/null
@@ -1,85 +0,0 @@
-torchrec.modules
-================
-
-.. automodule:: torchrec.modules
-
-torchrec.modules.activation
----------------------------
-
-.. automodule:: torchrec.modules.activation
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.modules.crossnet
--------------------------
-
-.. automodule:: torchrec.modules.crossnet
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.modules.deepfm
------------------------
-
-.. automodule:: torchrec.modules.deepfm
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.modules.embedding\_configs
------------------------------------
-
-.. automodule:: torchrec.modules.embedding_configs
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.modules.embedding\_modules
------------------------------------
-
-.. automodule:: torchrec.modules.embedding_modules
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.modules.feature\_processor
------------------------------------
-
-.. automodule:: torchrec.modules.feature_processor
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.modules.lazy\_extension
---------------------------------
-
-.. automodule:: torchrec.modules.lazy_extension
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.modules.mlp
---------------------
-
-.. automodule:: torchrec.modules.mlp
- :members:
- :undoc-members:
- :show-inheritance:
-
-
-torchrec.modules.utils
-----------------------
-
-.. automodule:: torchrec.modules.utils
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: torchrec.modules
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.optim.rst b/docs/source/torchrec.optim.rst
deleted file mode 100644
index df3a147d0..000000000
--- a/docs/source/torchrec.optim.rst
+++ /dev/null
@@ -1,44 +0,0 @@
-torchrec.optim
-==============
-
-.. automodule:: torchrec.optim
-
-torchrec.optim.clipping
------------------------
-
-.. automodule:: torchrec.optim.clipping
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.optim.fused
---------------------
-
-.. automodule:: torchrec.optim.fused
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.optim.keyed
---------------------
-
-.. automodule:: torchrec.optim.keyed
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.optim.warmup
----------------------
-
-.. automodule:: torchrec.optim.warmup
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: torchrec.optim
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.quant.rst b/docs/source/torchrec.quant.rst
deleted file mode 100644
index df4b34bd4..000000000
--- a/docs/source/torchrec.quant.rst
+++ /dev/null
@@ -1,20 +0,0 @@
-torchrec.quant
-==============
-
-.. automodule:: torchrec.quant
-
-torchrec.quant.embedding\_modules
----------------------------------
-
-.. automodule:: torchrec.quant.embedding_modules
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: torchrec.quant
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.sparse.rst b/docs/source/torchrec.sparse.rst
deleted file mode 100644
index efbd3eaec..000000000
--- a/docs/source/torchrec.sparse.rst
+++ /dev/null
@@ -1,20 +0,0 @@
-torchrec.sparse
-===============
-
-.. automodule:: torchrec.sparse
-
-torchrec.sparse.jagged\_tensor
-------------------------------
-
-.. automodule:: torchrec.sparse.jagged_tensor
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: torchrec.sparse
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/examples/bert4rec/bert4rec_main.py b/examples/bert4rec/bert4rec_main.py
index 1f3f1319b..f5c8ccff5 100644
--- a/examples/bert4rec/bert4rec_main.py
+++ b/examples/bert4rec/bert4rec_main.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
#!/usr/bin/env python3
@@ -26,6 +28,7 @@
from torchrec.distributed.model_parallel import DistributedModelParallel as DMP
from torchrec.distributed.types import ModuleSharder
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from tqdm import tqdm
@@ -497,11 +500,12 @@ def main(argv: List[str]) -> None:
],
)
dense_optimizer = KeyedOptimizerWrapper(
- dict(model.named_parameters()),
+ dict(in_backward_optimizer_filter(model.named_parameters())),
lambda params: optim.Adam(
params, lr=args.lr, weight_decay=args.weight_decay
),
)
+
optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
else:
device_ids = [rank] if backend == "nccl" else None
diff --git a/examples/bert4rec/bert4rec_metrics.py b/examples/bert4rec/bert4rec_metrics.py
index 09140d62e..6ad244fdc 100644
--- a/examples/bert4rec/bert4rec_metrics.py
+++ b/examples/bert4rec/bert4rec_metrics.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
from typing import Dict, List
import torch
diff --git a/examples/bert4rec/data/bert4rec_movielens_datasets.py b/examples/bert4rec/data/bert4rec_movielens_datasets.py
index fb912126d..18144cc48 100644
--- a/examples/bert4rec/data/bert4rec_movielens_datasets.py
+++ b/examples/bert4rec/data/bert4rec_movielens_datasets.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import random
from collections import Counter
from pathlib import Path
diff --git a/examples/bert4rec/data/tests/test_bert4rec_movielens_datasets.py b/examples/bert4rec/data/tests/test_bert4rec_movielens_datasets.py
index daca1c4e7..06c8deef8 100644
--- a/examples/bert4rec/data/tests/test_bert4rec_movielens_datasets.py
+++ b/examples/bert4rec/data/tests/test_bert4rec_movielens_datasets.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import unittest
from ..bert4rec_movielens_datasets import Bert4RecPreprocsser, get_raw_dataframe
diff --git a/examples/bert4rec/dataloader/bert4rec_movielens_dataloader.py b/examples/bert4rec/dataloader/bert4rec_movielens_dataloader.py
index af7a3b22f..a9303d218 100644
--- a/examples/bert4rec/dataloader/bert4rec_movielens_dataloader.py
+++ b/examples/bert4rec/dataloader/bert4rec_movielens_dataloader.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
from typing import Any, Dict, Tuple
import pandas as pd
diff --git a/examples/bert4rec/dataloader/tests/test_bert4rec_movielens_dataloader.py b/examples/bert4rec/dataloader/tests/test_bert4rec_movielens_dataloader.py
index 0a5526af3..a47c0fe7d 100644
--- a/examples/bert4rec/dataloader/tests/test_bert4rec_movielens_dataloader.py
+++ b/examples/bert4rec/dataloader/tests/test_bert4rec_movielens_dataloader.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import unittest
from ...data.bert4rec_movielens_datasets import Bert4RecPreprocsser, get_raw_dataframe
diff --git a/examples/bert4rec/models/bert4rec.py b/examples/bert4rec/models/bert4rec.py
index 0ca8b83b0..b90ac4305 100644
--- a/examples/bert4rec/models/bert4rec.py
+++ b/examples/bert4rec/models/bert4rec.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import copy
import math
diff --git a/examples/bert4rec/models/tests/test_bert4rec.py b/examples/bert4rec/models/tests/test_bert4rec.py
index ba46c299d..5149f188b 100644
--- a/examples/bert4rec/models/tests/test_bert4rec.py
+++ b/examples/bert4rec/models/tests/test_bert4rec.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
#!/usr/bin/env python3
import unittest
diff --git a/examples/bert4rec/tests/test_bert4rec_main.py b/examples/bert4rec/tests/test_bert4rec_main.py
index d53198c6e..705f4b64c 100644
--- a/examples/bert4rec/tests/test_bert4rec_main.py
+++ b/examples/bert4rec/tests/test_bert4rec_main.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import os
import tempfile
import unittest
diff --git a/examples/datasets/README b/examples/datasets/README
deleted file mode 100644
index 74210008a..000000000
--- a/examples/datasets/README
+++ /dev/null
@@ -1 +0,0 @@
-Datasets under this directory are prototyping with libaries under active development (e.g. TorchArrow DataFrame) which may not have best performance yet, and subject to change.
diff --git a/examples/datasets/criteo_dataframes.py b/examples/datasets/criteo_dataframes.py
deleted file mode 100644
index dec708602..000000000
--- a/examples/datasets/criteo_dataframes.py
+++ /dev/null
@@ -1,101 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Iterable, List, Tuple, Union
-
-import numpy as np
-import torch.utils.data.datapipes as dp
-import torcharrow as ta
-import torcharrow.dtypes as dt
-from torch.utils.data import IterDataPipe
-from torchrec.datasets.criteo import (
- CAT_FEATURE_COUNT,
- CriteoIterDataPipe,
- DEFAULT_CAT_NAMES,
- DEFAULT_INT_NAMES,
- INT_FEATURE_COUNT,
-)
-from torchrec.datasets.utils import safe_cast
-
-DTYPE = dt.Struct(
- [
- dt.Field("labels", dt.int8),
- dt.Field(
- "dense_features",
- dt.Struct(
- [
- dt.Field(int_name, dt.Int32(nullable=True))
- for int_name in DEFAULT_INT_NAMES
- ]
- ),
- ),
- dt.Field(
- "sparse_features",
- dt.Struct(
- [
- dt.Field(cat_name, dt.Int32(nullable=True))
- for cat_name in DEFAULT_CAT_NAMES
- ]
- ),
- ),
- ]
-)
-
-
-def _torcharrow_row_mapper(
- row: List[str],
-) -> Tuple[int, Tuple[int, ...], Tuple[int, ...]]:
- # TODO: Fix safe_cast type annotation
- label = int(safe_cast(row[0], int, 0))
- dense = tuple(
- (int(safe_cast(row[i], int, 0)) for i in range(1, 1 + INT_FEATURE_COUNT))
- )
- sparse = tuple(
- (
- int(safe_cast(row[i], str, "0") or "0", 16)
- for i in range(
- 1 + INT_FEATURE_COUNT, 1 + INT_FEATURE_COUNT + CAT_FEATURE_COUNT
- )
- )
- )
- # TorchArrow doesn't support uint32, but we can save memory
- # by not using int64. Numpy will automatically handle sparse values >= 2 ** 31.
- sparse = tuple(np.array(sparse, dtype=np.int32).tolist())
-
- return (label, dense, sparse)
-
-
-def criteo_dataframes_from_tsv(
- paths: Union[str, Iterable[str]],
- *,
- batch_size: int = 128,
-) -> IterDataPipe:
- """
- Load Criteo dataset (Kaggle or Terabyte) as TorchArrow DataFrame streams from TSV file(s)
-
- This implementaiton is inefficient and is used for prototype and test only.
-
- Args:
- paths (str or Iterable[str]): local paths to TSV files that constitute
- the Kaggle or Criteo 1TB dataset.
-
- Example::
-
- datapipe = criteo_dataframes_from_tsv(
- ["/home/datasets/criteo/day_0.tsv", "/home/datasets/criteo/day_1.tsv"]
- )
- for df in datapipe:
- print(df)
- """
- if isinstance(paths, str):
- paths = [paths]
-
- datapipe = CriteoIterDataPipe(paths, row_mapper=_torcharrow_row_mapper)
- datapipe = dp.iter.Batcher(datapipe, batch_size)
- datapipe = dp.iter.Mapper(datapipe, lambda batch: ta.dataframe(batch, dtype=DTYPE))
-
- return datapipe
diff --git a/examples/datasets/tests/test_criteo_dataframes.py b/examples/datasets/tests/test_criteo_dataframes.py
deleted file mode 100644
index 18cd5079b..000000000
--- a/examples/datasets/tests/test_criteo_dataframes.py
+++ /dev/null
@@ -1,80 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import contextlib
-
-import torcharrow as ta
-from torch.utils.data import IterDataPipe
-from torchrec.datasets.test_utils.criteo_test_utils import CriteoTest
-
-from ..criteo_dataframes import criteo_dataframes_from_tsv
-
-
-class CriteoDataFramesTest(CriteoTest):
- BATCH_SIZE = 4
-
- def test_single_file(self) -> None:
- with self._create_dataset_tsv() as dataset_pathname:
- dataset = criteo_dataframes_from_tsv(
- dataset_pathname, batch_size=self.BATCH_SIZE
- )
-
- self._validate_dataset(dataset, 10)
-
- def test_multiple_files(self) -> None:
- with contextlib.ExitStack() as stack:
- pathnames = [
- stack.enter_context(self._create_dataset_tsv()) for _ in range(3)
- ]
- dataset = criteo_dataframes_from_tsv(pathnames, batch_size=self.BATCH_SIZE)
-
- self._validate_dataset(dataset, 30)
-
- def _validate_dataset(
- self, dataset: IterDataPipe, expected_total_length: int
- ) -> None:
- last_batch = False
- total_length = 0
-
- for df in dataset:
- self.assertFalse(last_batch)
- self.assertTrue(isinstance(df, ta.DataFrame))
- self.assertLessEqual(len(df), self.BATCH_SIZE)
-
- total_length += len(df)
- if len(df) < self.BATCH_SIZE:
- last_batch = True
-
- self._validate_dataframe(df)
- self.assertEqual(total_length, expected_total_length)
-
- def _validate_dataframe(self, df: ta.DataFrame, train: bool = True) -> None:
- if train:
- self.assertEqual(len(df.columns), 3)
- labels = df["labels"]
- for label_val in labels:
- self.assertTrue(
- self.LABEL_VAL_RANGE[0] <= label_val <= self.LABEL_VAL_RANGE[1]
- )
- else:
- self.assertEqual(len(df.columns), 2)
-
- # Validations for both train and test
- dense_features = df["dense_features"]
- for idx in range(self.INT_FEATURE_COUNT):
- int_vals = dense_features[f"int_{idx}"]
- for int_val in int_vals:
- self.assertTrue(
- self.INT_VAL_RANGE[0] <= int_val <= self.INT_VAL_RANGE[1]
- )
-
- sparse_features = df["sparse_features"]
- for idx in range(self.CAT_FEATURE_COUNT):
- cat_vals = sparse_features[f"cat_{idx}"]
- for cat_val in cat_vals:
- # stored as int32
- self.assertTrue(-(2**31) <= cat_val <= 2**31 - 1)
diff --git a/examples/golden_training/tests/test_train_dlrm.py b/examples/golden_training/tests/test_train_dlrm.py
index 691d52937..25e1cc7fc 100644
--- a/examples/golden_training/tests/test_train_dlrm.py
+++ b/examples/golden_training/tests/test_train_dlrm.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import os
import tempfile
import unittest
diff --git a/examples/golden_training/train_dlrm.py b/examples/golden_training/train_dlrm.py
index 511b71bce..55514c53d 100644
--- a/examples/golden_training/train_dlrm.py
+++ b/examples/golden_training/train_dlrm.py
@@ -5,12 +5,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import os
from typing import List, Optional
import torch
from torch import distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.optim import (
+ _apply_optimizer_in_backward as apply_optimizer_in_backward,
+)
from torch.utils.data import IterableDataset
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.datasets.random import RandomRecDataset
@@ -25,8 +30,8 @@
from torchrec.models.dlrm import DLRM, DLRMTrain
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
-from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
from tqdm import tqdm
@@ -132,7 +137,7 @@ def train(
)
non_fused_optimizer = KeyedOptimizerWrapper(
- dict(model.named_parameters()),
+ dict(in_backward_optimizer_filter(model.named_parameters())),
lambda params: torch.optim.Adagrad(params, lr=learning_rate),
)
# Overlap comm/compute/device transfer during training through train_pipeline
diff --git a/examples/golden_training/train_dlrm_data_parallel.py b/examples/golden_training/train_dlrm_data_parallel.py
new file mode 100644
index 000000000..0d57f9617
--- /dev/null
+++ b/examples/golden_training/train_dlrm_data_parallel.py
@@ -0,0 +1,199 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import os
+from typing import Iterator, List, Optional, Tuple
+
+import torch
+from torch import distributed as dist, nn
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.optim import (
+ _apply_optimizer_in_backward as apply_optimizer_in_backward,
+)
+from torch.utils.data import IterableDataset
+from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
+from torchrec.datasets.random import RandomRecDataset
+from torchrec.distributed import TrainPipelineSparseDist
+from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.fbgemm_qcomm_codec import (
+ CommType,
+ get_qcomm_codecs_registry,
+ QCommsConfig,
+)
+from torchrec.distributed.model_parallel import DistributedModelParallel
+from torchrec.distributed.planner import EmbeddingShardingPlanner
+from torchrec.distributed.planner.types import ParameterConstraints, ShardingPlan
+from torchrec.models.dlrm import DLRM, DLRMTrain
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
+from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
+from tqdm import tqdm
+
+
+def _get_random_dataset(
+ num_embeddings: int,
+ batch_size: int = 32,
+) -> IterableDataset:
+ return RandomRecDataset(
+ keys=DEFAULT_CAT_NAMES,
+ batch_size=batch_size,
+ hash_size=num_embeddings,
+ ids_per_feature=1,
+ num_dense=len(DEFAULT_INT_NAMES),
+ )
+
+
+@record
+def main() -> None:
+ train()
+
+
+def train(
+ num_embeddings: int = 1024**2,
+ embedding_dim: int = 128,
+ dense_arch_layer_sizes: Optional[List[int]] = None,
+ over_arch_layer_sizes: Optional[List[int]] = None,
+ learning_rate: float = 0.1,
+ num_iterations: int = 1000,
+ qcomm_forward_precision: Optional[CommType] = CommType.FP16,
+ qcomm_backward_precision: Optional[CommType] = CommType.BF16,
+) -> None:
+ """
+ Duplicate of train_dlrm.py, but manually forces one table to be data_parallel.
+ We then optimize this table with RWAdagrad, and optimize the rest of the dense params (i.e., dense, inter and over archs) with vanilla Adagrad.
+
+ Constructs and trains a DLRM model (using random dummy data). Each script is run on each process (rank) in SPMD fashion.
+ The embedding layers will be sharded across available ranks.
+
+ qcomm_forward_precision: Compression used in forwards pass. FP16 is the recommended usage. INT8 and FP8 are in development, but feel free to try them out.
+ qcomm_backward_precision: Compression used in backwards pass. We recommend using BF16 to ensure training stability.
+
+ The effects of quantized comms will be most apparent in large training jobs across multiple nodes where inter host communication is expensive.
+ """
+ if dense_arch_layer_sizes is None:
+ dense_arch_layer_sizes = [64, embedding_dim]
+ if over_arch_layer_sizes is None:
+ over_arch_layer_sizes = [64, 1]
+
+ # Init process_group , device, rank, backend
+ rank = int(os.environ["LOCAL_RANK"])
+ if torch.cuda.is_available():
+ device: torch.device = torch.device(f"cuda:{rank}")
+ backend = "nccl"
+ torch.cuda.set_device(device)
+ else:
+ device: torch.device = torch.device("cpu")
+ backend = "gloo"
+ dist.init_process_group(backend=backend)
+
+ # Construct DLRM module
+ eb_configs = [
+ EmbeddingBagConfig(
+ name=f"t_{feature_name}",
+ embedding_dim=embedding_dim,
+ num_embeddings=num_embeddings,
+ feature_names=[feature_name],
+ )
+ for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES)
+ ]
+ dlrm_model = DLRM(
+ embedding_bag_collection=EmbeddingBagCollection(
+ tables=eb_configs, device=torch.device("meta")
+ ),
+ dense_in_features=len(DEFAULT_INT_NAMES),
+ dense_arch_layer_sizes=dense_arch_layer_sizes,
+ over_arch_layer_sizes=over_arch_layer_sizes,
+ dense_device=device,
+ )
+ train_model = DLRMTrain(dlrm_model)
+ # Optional: force some tables (table 10) to be data_parallel through planner ParameterConstraints
+ planner = EmbeddingShardingPlanner(
+ constraints={
+ "t_cat_10": ParameterConstraints(compute_kernels=["dense"]),
+ }
+ )
+ plan: ShardingPlan = planner.collective_plan(train_model)
+
+ apply_optimizer_in_backward(
+ RowWiseAdagrad,
+ train_model.model.sparse_arch.parameters(),
+ {"lr": learning_rate},
+ )
+ qcomm_codecs_registry = (
+ get_qcomm_codecs_registry(
+ qcomms_config=QCommsConfig(
+ # pyre-ignore
+ forward_precision=qcomm_forward_precision,
+ # pyre-ignore
+ backward_precision=qcomm_backward_precision,
+ )
+ )
+ if backend == "nccl"
+ else None
+ )
+ sharder = EmbeddingBagCollectionSharder(qcomm_codecs_registry=qcomm_codecs_registry)
+
+ model = DistributedModelParallel(
+ plan=plan,
+ module=train_model,
+ device=device,
+ # pyre-ignore
+ sharders=[sharder],
+ )
+
+ # non fused (dense) embeddings are data parallel and use RowWiseAdagrad optimizer
+ non_fused_embedding_optimizer = KeyedOptimizerWrapper(
+ dict(
+ in_backward_optimizer_filter(
+ model.module.model.sparse_arch.named_parameters()
+ )
+ ),
+ lambda params: RowWiseAdagrad(params, lr=learning_rate),
+ )
+
+ def dense_filter(
+ named_parameters: Iterator[Tuple[str, nn.Parameter]]
+ ) -> Iterator[Tuple[str, nn.Parameter]]:
+ for fqn, param in named_parameters:
+ if "sparse" not in fqn:
+ yield fqn, param
+
+ # DLRM dense (over, inter and dense archs) are optimized with vanilla Adagrad
+ dense_optimizer = KeyedOptimizerWrapper(
+ dict(dense_filter(model.named_parameters())),
+ lambda params: torch.optim.Adagrad(params, lr=learning_rate),
+ )
+ combined_optimizer = CombinedOptimizer(
+ [
+ ("non_fused_embedding_optim", non_fused_embedding_optimizer),
+ ("dense_optim", dense_optimizer),
+ ("fused_embedding_optim", model.fused_optimizer),
+ ]
+ )
+ # Overlap comm/compute/device transfer during training through train_pipeline
+ train_pipeline = TrainPipelineSparseDist(
+ model,
+ combined_optimizer,
+ device,
+ )
+
+ # train model
+ train_iterator = iter(
+ _get_random_dataset(
+ num_embeddings=num_embeddings,
+ )
+ )
+ for _ in tqdm(range(int(num_iterations)), mininterval=5.0):
+ train_pipeline.progress(train_iterator)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/inference/README.md b/examples/inference_legacy/README.md
similarity index 100%
rename from examples/inference/README.md
rename to examples/inference_legacy/README.md
diff --git a/examples/inference/dlrm_client.py b/examples/inference_legacy/dlrm_client.py
similarity index 100%
rename from examples/inference/dlrm_client.py
rename to examples/inference_legacy/dlrm_client.py
diff --git a/examples/inference/dlrm_packager.py b/examples/inference_legacy/dlrm_packager.py
similarity index 100%
rename from examples/inference/dlrm_packager.py
rename to examples/inference_legacy/dlrm_packager.py
diff --git a/examples/inference/dlrm_predict.py b/examples/inference_legacy/dlrm_predict.py
similarity index 96%
rename from examples/inference/dlrm_predict.py
rename to examples/inference_legacy/dlrm_predict.py
index c36bb613d..e2bdcef9d 100644
--- a/examples/inference/dlrm_predict.py
+++ b/examples/inference_legacy/dlrm_predict.py
@@ -139,9 +139,11 @@ def create_predict_module(self, world_size: int) -> torch.nn.Module:
EmbeddingBagConfig(
name=f"t_{feature_name}",
embedding_dim=self.model_config.embedding_dim,
- num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx]
- if self.model_config.num_embeddings is None
- else self.model_config.num_embeddings,
+ num_embeddings=(
+ self.model_config.num_embeddings_per_feature[feature_idx]
+ if self.model_config.num_embeddings is None
+ else self.model_config.num_embeddings
+ ),
feature_names=[feature_name],
)
for feature_idx, feature_name in enumerate(
diff --git a/examples/inference/dlrm_predict_single_gpu.py b/examples/inference_legacy/dlrm_predict_single_gpu.py
similarity index 94%
rename from examples/inference/dlrm_predict_single_gpu.py
rename to examples/inference_legacy/dlrm_predict_single_gpu.py
index 753cb0dcc..ba5323247 100644
--- a/examples/inference/dlrm_predict_single_gpu.py
+++ b/examples/inference_legacy/dlrm_predict_single_gpu.py
@@ -50,9 +50,11 @@ def create_predict_module(self, world_size: int) -> torch.nn.Module:
EmbeddingBagConfig(
name=f"t_{feature_name}",
embedding_dim=self.model_config.embedding_dim,
- num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx]
- if self.model_config.num_embeddings is None
- else self.model_config.num_embeddings,
+ num_embeddings=(
+ self.model_config.num_embeddings_per_feature[feature_idx]
+ if self.model_config.num_embeddings is None
+ else self.model_config.num_embeddings
+ ),
feature_names=[feature_name],
)
for feature_idx, feature_name in enumerate(
diff --git a/examples/nvt_dataloader/nvt_binary_dataloader.py b/examples/nvt_dataloader/nvt_binary_dataloader.py
index 2286b5610..f88ff1feb 100644
--- a/examples/nvt_dataloader/nvt_binary_dataloader.py
+++ b/examples/nvt_dataloader/nvt_binary_dataloader.py
@@ -94,7 +94,8 @@ def __getitem__(self, idx: int):
"""Numerical features are returned in the order they appear in the channel spec section
For performance reasons, this is required to be the order they are saved in, as specified
by the relevant chunk in source spec.
- Categorical features are returned in the order they appear in the channel spec section"""
+ Categorical features are returned in the order they appear in the channel spec section
+ """
if idx >= self._num_entries:
raise IndexError()
diff --git a/examples/nvt_dataloader/train_torchrec.py b/examples/nvt_dataloader/train_torchrec.py
index ddf53c89b..abdf3a67d 100644
--- a/examples/nvt_dataloader/train_torchrec.py
+++ b/examples/nvt_dataloader/train_torchrec.py
@@ -19,8 +19,6 @@
import torchrec
import torchrec.distributed as trec_dist
import torchrec.optim as trec_optim
-
-from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from nvt_binary_dataloader import NvtBinaryDataloader
from pyre_extensions import none_throws
from torchrec import EmbeddingBagCollection
@@ -40,6 +38,7 @@
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.fused_embedding_modules import fuse_embedding_optimizer
from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
def parse_args(argv: List[str]) -> argparse.Namespace:
@@ -209,9 +208,11 @@ def main(argv: List[str]):
EmbeddingBagConfig(
name=f"t_{feature_name}",
embedding_dim=args.embedding_dim,
- num_embeddings=none_throws(num_embeddings_per_feature)[feature_idx]
- if num_embeddings_per_feature is not None
- else args.num_embeddings,
+ num_embeddings=(
+ none_throws(num_embeddings_per_feature)[feature_idx]
+ if num_embeddings_per_feature is not None
+ else args.num_embeddings
+ ),
feature_names=[feature_name],
)
for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES)
@@ -233,9 +234,9 @@ def main(argv: List[str]):
train_model = fuse_embedding_optimizer(
train_model,
- optimizer_type=torchrec.optim.RowWiseAdagrad
- if args.adagrad
- else torch.optim.SGD,
+ optimizer_type=(
+ torchrec.optim.RowWiseAdagrad if args.adagrad else torch.optim.SGD
+ ),
optimizer_kwargs={"learning_rate": args.learning_rate},
device=torch.device("meta"),
)
@@ -270,10 +271,12 @@ def main(argv: List[str]):
)
non_fused_optimizer = KeyedOptimizerWrapper(
- dict(model.named_parameters()),
- lambda params: torch.optim.Adagrad(params, lr=args.learning_rate)
- if args.adagrad
- else torch.optim.SGD(params, lr=args.learning_rate),
+ dict(in_backward_optimizer_filter(model.named_parameters())),
+ lambda params: (
+ torch.optim.Adagrad(params, lr=args.learning_rate)
+ if args.adagrad
+ else torch.optim.SGD(params, lr=args.learning_rate)
+ ),
)
opt = trec_optim.keyed.CombinedOptimizer(
diff --git a/examples/ray/train_torchrec.py b/examples/ray/train_torchrec.py
index 2406c6796..4ee5fdeb9 100644
--- a/examples/ray/train_torchrec.py
+++ b/examples/ray/train_torchrec.py
@@ -24,6 +24,7 @@
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
from tqdm import tqdm
@@ -110,7 +111,7 @@ def train(
# Overlap comm/compute/device transfer during training through train_pipeline
non_fused_optimizer = KeyedOptimizerWrapper(
- dict(model.named_parameters()),
+ dict(in_backward_optimizer_filter(model.named_parameters())),
lambda params: torch.optim.Adagrad(params, lr=learning_rate),
)
train_pipeline = TrainPipelineSparseDist(
diff --git a/examples/retrieval/data/dataloader.py b/examples/retrieval/data/dataloader.py
index 5727910c7..39e2ca1c4 100644
--- a/examples/retrieval/data/dataloader.py
+++ b/examples/retrieval/data/dataloader.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
from torch.utils.data import DataLoader
from torchrec.datasets.movielens import DEFAULT_RATINGS_COLUMN_NAMES
from torchrec.datasets.random import RandomRecDataset
diff --git a/examples/retrieval/knn_index.py b/examples/retrieval/knn_index.py
index a1323f1ca..9db4a6d7d 100644
--- a/examples/retrieval/knn_index.py
+++ b/examples/retrieval/knn_index.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
from typing import Optional, Union
import faiss # @manual=//faiss/python:pyfaiss_gpu
diff --git a/examples/retrieval/modules/tests/test_two_tower.py b/examples/retrieval/modules/tests/test_two_tower.py
index 9ac08b10f..237ae7ee8 100644
--- a/examples/retrieval/modules/tests/test_two_tower.py
+++ b/examples/retrieval/modules/tests/test_two_tower.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
#!/usr/bin/env python3
import unittest
diff --git a/examples/retrieval/modules/two_tower.py b/examples/retrieval/modules/two_tower.py
index 5ee2b22af..704e02b63 100644
--- a/examples/retrieval/modules/two_tower.py
+++ b/examples/retrieval/modules/two_tower.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
from typing import Any, List, Mapping, Optional, OrderedDict, Tuple, Union
import faiss # @manual=//faiss/python:pyfaiss_gpu
@@ -71,12 +73,12 @@ def __init__(
embedding_dim: int = embedding_bag_collection.embedding_bag_configs()[
0
].embedding_dim
- self._feature_names_query: List[
- str
- ] = embedding_bag_collection.embedding_bag_configs()[0].feature_names
- self._candidate_feature_names: List[
- str
- ] = embedding_bag_collection.embedding_bag_configs()[1].feature_names
+ self._feature_names_query: List[str] = (
+ embedding_bag_collection.embedding_bag_configs()[0].feature_names
+ )
+ self._candidate_feature_names: List[str] = (
+ embedding_bag_collection.embedding_bag_configs()[1].feature_names
+ )
self.ebc = embedding_bag_collection
self.query_proj = MLP(
in_size=embedding_dim, layer_sizes=layer_sizes, device=device
@@ -174,6 +176,7 @@ def __init__(
layer_sizes: List[int],
k: int,
device: Optional[torch.device] = None,
+ dtype: torch.dtype = torch.float32,
) -> None:
super().__init__()
self.embedding_dim: int = query_ebc.embedding_bag_configs()[0].embedding_dim
@@ -186,10 +189,16 @@ def __init__(
self.query_ebc = query_ebc
self.candidate_ebc = candidate_ebc
self.query_proj = MLP(
- in_size=self.embedding_dim, layer_sizes=layer_sizes, device=device
+ in_size=self.embedding_dim,
+ layer_sizes=layer_sizes,
+ device=device,
+ dtype=dtype,
)
self.candidate_proj = MLP(
- in_size=self.embedding_dim, layer_sizes=layer_sizes, device=device
+ in_size=self.embedding_dim,
+ layer_sizes=layer_sizes,
+ device=device,
+ dtype=dtype,
)
self.faiss_index: Union[faiss.GpuIndexIVFPQ, faiss.IndexIVFPQ] = faiss_index
self.k = k
@@ -212,6 +221,7 @@ def forward(self, query_kjt: KeyedJaggedTensor) -> torch.Tensor:
candidates = torch.empty(
(batch_size, self.k), device=self.device, dtype=torch.int64
)
+ query_embedding = query_embedding.to(torch.float32) # required by faiss
self.faiss_index.search(query_embedding, self.k, distances, candidates)
# candidate lookup
@@ -227,5 +237,7 @@ def forward(self, query_kjt: KeyedJaggedTensor) -> torch.Tensor:
# return logit (dot product)
return (query_embedding * candidate_embedding).sum(dim=1).squeeze()
+ # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`
+ # inconsistently.
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool) -> None:
super().load_state_dict(state_dict, strict)
diff --git a/examples/retrieval/tests/test_two_tower_retrieval.py b/examples/retrieval/tests/test_two_tower_retrieval.py
index eef1ef455..77952e0fa 100644
--- a/examples/retrieval/tests/test_two_tower_retrieval.py
+++ b/examples/retrieval/tests/test_two_tower_retrieval.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import unittest
import torch
@@ -19,11 +21,12 @@ class InferTest(unittest.TestCase):
@skip_if_asan
# pyre-ignore[56]
@unittest.skipIf(
- not torch.cuda.is_available(),
- "this test requires a GPU",
+ torch.cuda.device_count() <= 1,
+ "Not enough GPUs, this test requires at least two GPUs",
)
def test_infer_function(self) -> None:
infer(
embedding_dim=16,
layer_sizes=[16],
+ world_size=2,
)
diff --git a/examples/retrieval/tests/test_two_tower_train.py b/examples/retrieval/tests/test_two_tower_train.py
index 2b78b8426..8db7684cf 100644
--- a/examples/retrieval/tests/test_two_tower_train.py
+++ b/examples/retrieval/tests/test_two_tower_train.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import os
import tempfile
import unittest
diff --git a/examples/retrieval/two_tower_retrieval.py b/examples/retrieval/two_tower_retrieval.py
index b1b4ccb49..08a2de4b6 100644
--- a/examples/retrieval/two_tower_retrieval.py
+++ b/examples/retrieval/two_tower_retrieval.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
from typing import List, Optional
import click
@@ -18,16 +20,13 @@
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingEnv, ShardingType
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
# OSS import
try:
- # pyre-ignore[21]
- # @manual=//torchrec/github/examples/retrieval/data:dataloader
- from data.dataloader import get_dataloader
# pyre-ignore[21]
# @manual=//torchrec/github/examples/retrieval:knn_index
@@ -78,6 +77,7 @@ def infer(
faiss_device_idx: int = 0,
batch_size: int = 32,
load_dir: Optional[str] = None,
+ world_size: int = 2,
) -> None:
"""
Loads the serialized model and FAISS index from `two_tower_train.py`.
@@ -116,6 +116,7 @@ def infer(
embedding_dim=embedding_dim,
num_embeddings=num_embeddings,
feature_names=[feature_name],
+ data_type=DataType.FP16,
)
ebcs.append(
EmbeddingBagCollection(
@@ -135,7 +136,7 @@ def infer(
# pyre-ignore[16]
faiss.read_index(f"{load_dir}/faiss.index"),
)
- two_tower_sd = torch.load(f"{load_dir}/model.pt")
+ two_tower_sd = torch.load(f"{load_dir}/model.pt", weights_only=True)
retrieval_sd = convert_TwoTower_to_TwoTowerRetrieval(
two_tower_sd,
[f"t_{two_tower_column_names[0]}"],
@@ -156,7 +157,9 @@ def infer(
index.train(embeddings)
index.add(embeddings)
- retrieval_model = TwoTowerRetrieval(index, ebcs[0], ebcs[1], layer_sizes, k, device)
+ retrieval_model = TwoTowerRetrieval(
+ index, ebcs[0], ebcs[1], layer_sizes, k, device, dtype=torch.float16
+ )
constraints = {}
for feature_name in two_tower_column_names:
@@ -166,13 +169,16 @@ def infer(
)
quant_model = trec_infer.modules.quantize_embeddings(
- retrieval_model, dtype=torch.qint8, inplace=True
+ retrieval_model,
+ dtype=torch.qint8,
+ inplace=True,
+ output_dtype=torch.float16,
)
dmp = DistributedModelParallel(
module=quant_model,
device=device,
- env=ShardingEnv.from_local(world_size=2, rank=model_device_idx),
+ env=ShardingEnv.from_local(world_size=world_size, rank=model_device_idx),
init_data_parallel=False,
)
if retrieval_sd is not None:
diff --git a/examples/retrieval/two_tower_train.py b/examples/retrieval/two_tower_train.py
index 141c1ca7d..5772a9069 100644
--- a/examples/retrieval/two_tower_train.py
+++ b/examples/retrieval/two_tower_train.py
@@ -5,25 +5,26 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
#!/usr/bin/env python3
import os
-from typing import cast, List, Optional
+from typing import List, Optional
import click
import faiss # @manual=//faiss/python:pyfaiss_gpu
import faiss.contrib.torch_utils # @manual=//faiss/contrib:faiss_contrib_gpu
import torch
-from fbgemm_gpu.split_embedding_configs import EmbOptimType
-from torch import distributed as dist, nn
+from torch import distributed as dist
+from torch.distributed.optim import (
+ _apply_optimizer_in_backward as apply_optimizer_in_backward,
+)
from torchrec import inference as trec_infer
from torchrec.datasets.movielens import DEFAULT_RATINGS_COLUMN_NAMES
from torchrec.distributed import TrainPipelineSparseDist
-from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
-from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
-from torchrec.distributed.types import ModuleSharder
from torchrec.inference.state_dict_transform import (
state_dict_gather,
state_dict_to_device,
@@ -31,6 +32,7 @@
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -108,7 +110,6 @@ def train(
layer_sizes = [128, 64]
rank = int(os.environ["LOCAL_RANK"])
- world_size = int(os.environ["WORLD_SIZE"])
if torch.cuda.is_available():
device: torch.device = torch.device(f"cuda:{rank}")
backend = "nccl"
@@ -138,36 +139,14 @@ def train(
device=device,
)
two_tower_train_task = TwoTowerTrainTask(two_tower_model)
-
- fused_params = {
- "learning_rate": learning_rate,
- "optimizer": EmbOptimType.ROWWISE_ADAGRAD,
- }
- sharders = cast(
- List[ModuleSharder[nn.Module]],
- [EmbeddingBagCollectionSharder(fused_params=fused_params)],
- )
-
- # TODO: move pg to the EmbeddingShardingPlanner (out of collective_plan) and make optional
- # TODO: make Topology optional argument to EmbeddingShardingPlanner
- # TODO: give collective_plan a default sharders
- # TODO: once this is done, move defaults out of DMP and just get from ShardingPlan (eg _sharding_map should not exist - just use the plan)
- plan = EmbeddingShardingPlanner(
- topology=Topology(
- world_size=world_size,
- compute_device=device.type,
- ),
- ).collective_plan(
- module=two_tower_model,
- sharders=sharders,
- # pyre-fixme[6]: For 3rd param expected `ProcessGroup` but got
- # `Optional[ProcessGroup]`.
- pg=dist.GroupMember.WORLD,
+ apply_optimizer_in_backward(
+ RowWiseAdagrad,
+ two_tower_train_task.two_tower.ebc.parameters(),
+ {"lr": learning_rate},
)
model = DistributedModelParallel(
module=two_tower_train_task,
device=device,
- plan=plan,
)
optimizer = KeyedOptimizerWrapper(
diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md
deleted file mode 100644
index cc2613a38..000000000
--- a/examples/torcharrow/README.md
+++ /dev/null
@@ -1,33 +0,0 @@
-# Description
-
-This shows a prototype of integrating a TorchRec based training loop utilizing TorchArrow's on-the-fly preprocessing. The main motivation is to show the utilization of TorchArrow's specialized domain UDFs. Here we use `bucketize`, `firstx`, as well as `sigrid_hash` to do some last-mile preprocessing over the criteo dataset in parquet format. More recommendation domain functions can be found at [torcharrow.functional Doc](https://pytorch.org/torcharrow/beta/functional.html#recommendation-operations).
-
-These three UDFs are extensively used in Meta's RecSys preprocessing stack. Notably, these UDFs can be used to easily adjust the proprocessing script to any model changes. For example, if we wish to change the size of our embedding tables, without sigrid_hash, we would need to rerun a bulk offline preproc to ensure that all indicies are within bounds. Bucketize lets us easily convert dense features into sparse features, with flexibility of what the bucket borders are. firstx lets us easily prune sparse ids (note, that this doesn't provide any functional features, but is in the preproc script as demonstration).
-
-
-## Installations and Usage
-
-Download the criteo tsv files (see the README in the main DLRM example). Use the nvtabular script (in torchrec/datasets/scripts/nvt/) to convert the TSV files to parquet.
-
-To start, install torcharrow-nightly and torchdata:
-```
-pip install --pre torcharrow -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
-pip install torchdata
-```
-You can also build TorchArrow from source, following https://github.com/pytorch/torcharrow
-
-Usage
-
-```
-torchx run -s local_cwd dist.ddp -j 1x4 --script examples/torcharrow/run.py -- --parquet_directory /home/criteo_parquet
-```
-
-The preprocessing logic is in ```dataloader.py```
-
-## Extentions/Future work
-
-* We will eventually integrate with the up and coming DataLoader2, which will allow us to utilize a prebuilt solution to collate our dataframe batches to dense tensors, or TorchRec's KeyedJaggedTensors (rather than doing this by hand).
-* Building an easier solution/more performant to convert parquet -> IterableDataPipe[torcharrow.DataFrame] (aka ArrowDataPipe). Also currently batch sizes are not available.
-* Some functional abilities are not yet available (such as make_named_row, etc).
-* Support collation/conversion for ArrowDataPipe
-* More RecSys UDFs to come! Please let us know if you have any suggestions.
diff --git a/examples/torcharrow/dataloader.py b/examples/torcharrow/dataloader.py
deleted file mode 100644
index 3cca4666d..000000000
--- a/examples/torcharrow/dataloader.py
+++ /dev/null
@@ -1,116 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torcharrow as ta
-import torcharrow.dtypes as dt
-import torcharrow.pytorch as tap
-from torch.utils.data import DataLoader
-from torcharrow import functional
-from torchdata.datapipes.iter import FileLister
-from torchrec.datasets.criteo import (
- DEFAULT_CAT_NAMES,
- DEFAULT_INT_NAMES,
- DEFAULT_LABEL_NAME,
-)
-
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
-
-
-class _JaggedTensorConversion(tap.TensorConversion):
- # pyre-fixme[14]: `to_tensor` overrides method defined in `TensorConversion`
- # inconsistently.
- def to_tensor(self, df: ta.DataFrame):
-
- kjt_keys = df.columns
- kjt_values = []
- kjt_lengths = []
- for row in df:
- for idx, _column in enumerate(df.columns):
- value = row[idx]
- kjt_values.extend(value)
- kjt_lengths.append(len(value))
- kjt = KeyedJaggedTensor.from_lengths_sync(
- keys=kjt_keys,
- values=torch.tensor(kjt_values),
- lengths=torch.tensor(kjt_lengths),
- )
- return kjt
-
-
-class _Scalar(tap.TensorConversion):
- def to_tensor(self, df: ta.DataFrame):
- labels = torch.tensor(df)
- return labels
-
-
-def get_dataloader(
- parquet_directory, world_size, rank, num_embeddings=4096, salt=0, batch_size=16
-):
- source_dp = FileLister(parquet_directory, masks="*.parquet")
- # TODO support batch_size for load_parquet_as_df.
- # TODO use OSSArrowDataPipe once it is ready
- parquet_df_dp = source_dp.load_parquet_as_df()
-
- def preproc(df, max_idx=num_embeddings, salt=salt):
-
- for feature_name in DEFAULT_INT_NAMES:
- df[feature_name] = df[feature_name].fill_null(0)
- for feature_name in DEFAULT_CAT_NAMES:
- df[feature_name] = df[feature_name].fill_null(0)
- df[feature_name] = df[feature_name].cast(dt.int64)
-
- # construct a sprase index from a dense one
- df["bucketize_int_0"] = functional.bucketize(df["int_0"], [0.5, 1.0, 1.5]).cast(
- dt.int64
- )
-
- # flatten several columns into one
- df["dense_features"] = ta.dataframe(
- {int_name: df[int_name] for int_name in DEFAULT_INT_NAMES}
- )
-
- df["dense_features"] = (df["dense_features"] + 3).log()
-
- for cat_name in DEFAULT_CAT_NAMES + ["bucketize_int_0"]:
- # hash our embedding index into our embedding tables
- df[cat_name] = functional.sigrid_hash(df[cat_name], salt, max_idx)
- df[cat_name] = functional.array_constructor(df[cat_name])
- df[cat_name] = functional.firstx(df[cat_name], 1)
-
- df["sparse_features"] = ta.dataframe(
- {
- cat_name: df[cat_name]
- for cat_name in DEFAULT_CAT_NAMES + ["bucketize_int_0"]
- }
- )
-
- df = df[["dense_features", "sparse_features", DEFAULT_LABEL_NAME]]
-
- return df
-
- parquet_df_dp = parquet_df_dp.map(preproc).sharding_filter()
- parquet_df_dp.apply_sharding(world_size, rank)
-
- def criteo_collate(df):
- dense_features, kjt, labels = df.to_tensor(
- {
- "dense_features": tap.rec.Dense(batch_first=True),
- "sparse_features": _JaggedTensorConversion(),
- "label": _Scalar(),
- }
- )
-
- return dense_features, kjt, labels
-
- return DataLoader(
- parquet_df_dp,
- batch_size=None,
- collate_fn=criteo_collate,
- drop_last=False,
- pin_memory=True,
- )
diff --git a/examples/torcharrow/requirements.txt b/examples/torcharrow/requirements.txt
deleted file mode 100644
index 3e3c4c950..000000000
--- a/examples/torcharrow/requirements.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-click
-torcharrow
-torchdata
diff --git a/examples/torcharrow/run.py b/examples/torcharrow/run.py
deleted file mode 100644
index 1a3dca47d..000000000
--- a/examples/torcharrow/run.py
+++ /dev/null
@@ -1,126 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import os
-
-import click
-import dataloader as torcharrow_dataloader
-import torch
-import torch.distributed as dist
-from fbgemm_gpu.split_embedding_configs import EmbOptimType
-from torch.distributed.elastic.multiprocessing.errors import record
-from torchrec import EmbeddingBagCollection
-from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, INT_FEATURE_COUNT
-from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
-from torchrec.distributed.model_parallel import DistributedModelParallel
-from torchrec.models.dlrm import DLRM
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
-from torchrec.optim.keyed import KeyedOptimizerWrapper
-
-
-@record
-@click.command()
-@click.option("--batch_size", default=256)
-@click.option("--num_embeddings", default=2048)
-@click.option("--sigrid_hash_salt", default=0)
-@click.option("--parquet_directory", default="/data/criteo_preproc")
-def main(
- batch_size,
- num_embeddings,
- sigrid_hash_salt,
- parquet_directory,
-) -> None:
- rank = int(os.environ["LOCAL_RANK"])
- if torch.cuda.is_available():
- device = torch.device(f"cuda:{rank}")
- backend = "nccl"
- torch.cuda.set_device(device)
- else:
- device = torch.device("cpu")
- backend = "gloo"
- print(
- "\033[92m"
- + f"WARNING: Running in CPU mode. cuda availablility {torch.cuda.is_available()}."
- )
-
- dist.init_process_group(backend=backend)
-
- world_size = dist.get_world_size()
-
- dataloader = torcharrow_dataloader.get_dataloader(
- parquet_directory,
- world_size,
- rank,
- batch_size=batch_size,
- num_embeddings=num_embeddings,
- salt=sigrid_hash_salt,
- )
- it = iter(dataloader)
-
- model = DLRM(
- embedding_bag_collection=EmbeddingBagCollection(
- tables=[
- EmbeddingBagConfig(
- name=f"table_{cat_name}",
- embedding_dim=64,
- num_embeddings=num_embeddings,
- feature_names=[cat_name],
- )
- for cat_name in DEFAULT_CAT_NAMES + ["bucketize_int_0"]
- ],
- device=torch.device("meta"),
- ),
- dense_in_features=INT_FEATURE_COUNT,
- dense_arch_layer_sizes=[64],
- over_arch_layer_sizes=[32, 1],
- dense_device=device,
- )
-
- fused_params = {
- "learning_rate": 0.02,
- "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
- }
-
- sharded_model = DistributedModelParallel(
- module=model,
- device=device,
- sharders=[
- EmbeddingBagCollectionSharder(fused_params=fused_params),
- ],
- )
-
- optimizer = KeyedOptimizerWrapper(
- dict(model.named_parameters()),
- lambda params: torch.optim.SGD(params, lr=0.01),
- )
-
- loss_fn = torch.nn.BCEWithLogitsLoss()
-
- print_example = dist.get_rank() == 0
- for (dense_features, kjt, labels) in it:
- if print_example:
- print("Example dense_features", dense_features)
- print("Example KJT input", kjt)
- print_example = False
-
- dense_features = dense_features.to(device)
- kjt = kjt.to(device)
- labels = labels.to(device)
-
- optimizer.zero_grad()
-
- preds = sharded_model(dense_features, kjt)
- loss = loss_fn(preds.squeeze(), labels.squeeze())
- loss.sum().backward()
-
- optimizer.step()
-
- print("\033[92m" + "DLRM run with torcharrow last-mile preprocessing finished!")
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/transfer_learning/train_from_pretrained_embedding.py b/examples/transfer_learning/train_from_pretrained_embedding.py
index 6d659f67d..ac7db53b4 100644
--- a/examples/transfer_learning/train_from_pretrained_embedding.py
+++ b/examples/transfer_learning/train_from_pretrained_embedding.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import copyreg
import io
import os
@@ -222,7 +224,7 @@ def main() -> None:
dist.barrier()
-if __name__ == "__main__":
+def invoke_main() -> None:
lc = pet.LaunchConfig(
min_nodes=1,
max_nodes=1,
@@ -234,3 +236,7 @@ def main() -> None:
monitor_interval=1,
)
pet.elastic_launch(lc, entrypoint=main)()
+
+
+if __name__ == "__main__":
+ invoke_main() # pragma: no cover
diff --git a/install-requirements.txt b/install-requirements.txt
new file mode 100644
index 000000000..ed3c6aced
--- /dev/null
+++ b/install-requirements.txt
@@ -0,0 +1,6 @@
+fbgemm-gpu
+tensordict
+torchmetrics==1.0.3
+tqdm
+pyre-extensions
+iopath
diff --git a/requirements.txt b/requirements.txt
index daa72b622..6b17aeac6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,43 +1,19 @@
-arrow
-attrs
-certifi
-charset-normalizer
+black
cmake
-Cython
-distro
-docker
-docstring-parser
-fbgemm-gpu-nightly
-filelock
-fsspec
-hypothesis
-idna
+fbgemm-gpu
+hypothesis==6.70.1
iopath
-Jinja2
-MarkupSafe
-mypy-extensions
-ninja
numpy
-packaging
pandas
-portalocker
-pyarrow
-pyDeprecate
-pyparsing
-pyre-extensions==0.0.27
-python-dateutil
-pytz
-PyYAML
-requests
+pyre-extensions
scikit-build
-six
-sortedcontainers
-tabulate
-torchmetrics
+tensordict
+torchmetrics==1.0.3
torchx
tqdm
-typing-inspect
-typing_extensions
-urllib3
usort
-websocket-client
+parameterized
+
+# for tests
+# https://github.com/pytorch/pytorch/blob/b96b1e8cff029bb0a73283e6e7f6cc240313f1dc/requirements.txt#L3
+expecttest
diff --git a/rfc/RFC-0002-Flexible-Collision-Free-Embedding-Table.md b/rfc/RFC-0002-Flexible-Collision-Free-Embedding-Table.md
new file mode 100644
index 000000000..09e9069fb
--- /dev/null
+++ b/rfc/RFC-0002-Flexible-Collision-Free-Embedding-Table.md
@@ -0,0 +1,146 @@
+# RFC: Flexible Collision Free Embedding Table
+
+| Status | Published |
+| :---- | :---- |
+| Author(s) | Emma Lin, Joe Wang, Kaustubh Vartak, Dennis van der Staay, Huanyu He|
+| Updated | 04-13-2025 |
+
+## Motivation
+
+PyTorch and FBGEMM utilize fixed-size, contiguous embedding memory to handle sparse features, employing a uniform hash function to map raw IDs to a limited index space. This approach has resulted in a high rate of hash collisions and inefficient storage.
+
+The Zero Collision Hash (ZCH) technique allows models to train individual IDs uniquely, leading to notable enhancements in model freshness and user engagement. When properly configured, we've seen freshness improvements in both late-stage and early-stage ranking models. However, the remapping-based ZCH solutions have presented several scalability challenges.
+
+## Objective
+
+This RFC proposes a new embedding table format that natively supports collision-free features, enhancing the scalability and usability of embedding tables.
+
+The approach involves utilizing an extremely large hash size (e.g., 2^63) to map raw IDs to a vast ID space, significantly reducing the likelihood of collisions during hashing.
+
+Instead of mapping IDs to a fixed table size (e.g., 1 billion rows), this format reserves memory spaces for each ID, effectively eliminating collisions and achieving native collision-free results.
+
+Notably, this design eliminates the need for a remapper to find available slots for colliding IDs, streamlining the process and improving overall efficiency, embedding scalability and usability.
+
+## Design Proposal
+
+### Bucket Aware Sharding and Resharding Algorithm
+
+To address the dynamic nature of sparse IDs and their distribution, we propose introducing a bucket concept. We will provide a large default bucket number configuration and an extremely large table size. The mapping from ID to bucket ID can be done in two ways:
+
+* Interleave-based: bucket\_id \= hash\_id % total\_bucket\_number
+ This approach is similar to the sharding solution used in MPZCH, allowing for seamless migration without requiring ID resharding.
+* Chunk-based:
+ * bucket\_size \= table\_size / total\_bucket\_number,
+ * bucket\_id \= id / bucket\_size
+
+
+Both options will be configurable.
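+To make the two mapping options concrete, here is a minimal Python sketch. The names (`total_bucket_number`, `table_size`, etc.) mirror the configuration described above and are illustrative only, not an existing TorchRec API:
+```python
+def bucket_id_interleave(hash_id: int, total_bucket_number: int) -> int:
+    # Interleave-based: neighboring IDs land in different buckets.
+    return hash_id % total_bucket_number
+
+def bucket_id_chunk(hash_id: int, table_size: int, total_bucket_number: int) -> int:
+    # Chunk-based: each bucket owns a contiguous slice of the ID space.
+    bucket_size = table_size // total_bucket_number
+    return hash_id // bucket_size
+
+# Example: a 2^63 ID space split into 1000 buckets.
+print(bucket_id_interleave(12345, 1000))    # 345
+print(bucket_id_chunk(12345, 2**63, 1000))  # 0
+```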
+
+After sharding IDs into buckets, we will distribute the buckets sequentially across trainers. For example, with 1000 buckets and 100 trainers, each trainer would own 10 consecutive buckets.
+
+T1: b0-b9
+
+T2: b10-b19
+
+...
+
+When resharding is necessary, we will move buckets around instead of individual rows. For instance, reducing the number of trainers from 100 to 50 would result in each trainer owning 20 consecutive buckets.
+
+T1: b0-b19
+
+T2: b20-b39
+
+...
+
+The row count within each bucket can vary from 0 to the maximum bucket size, depending on the ID distribution. However, using a larger bucket number should lead to a more even distribution of IDs.
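+The sequential bucket-to-trainer assignment can be sketched the same way (again with purely illustrative names), assuming buckets are assigned to trainers in contiguous, equally sized ranges:
+```python
+def trainer_for_bucket(bucket_id: int, total_buckets: int, num_trainers: int) -> int:
+    # Sequential assignment: trainer 0 owns the first block of buckets,
+    # trainer 1 the next block, and so on.
+    buckets_per_trainer = total_buckets // num_trainers
+    return bucket_id // buckets_per_trainer
+
+# With 1000 buckets and 100 trainers, bucket 15 lives on trainer 1 (b10-b19).
+print(trainer_for_bucket(15, 1000, 100))  # 1
+# After resharding to 50 trainers, the same bucket moves as a unit to trainer 0 (b0-b19).
+print(trainer_for_bucket(15, 1000, 50))   # 0
+```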
+
+#### Benefit
+
+The bucket number remains unchanged when scaling the model or adjusting the number of trainers, which makes it easy to move buckets around without the additional overhead of resharding every individual ID to a new location.
+
+Resharding every ID can be an expensive operation, especially for large tables (e.g., over 1 billion rows).
+
+### Bucketized Torchrec Sharding and Input Distribution
+
+Based on the proposed sharding definition, the TorchRec sharder needs to be aware of the bucket configuration from the embedding table.
+
+Input distribution needs to take the bucket configuration into account and route each input ID to the trainer that owns its bucket.
+
+Here is the code [reference](https://github.com/pytorch/torchrec/blob/f36d26db4432bd7335f6df9e7e75d8643b7ffb04/torchrec/distributed/sharding/rw_sequence_sharding.py#L129C16-L129C36).
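+As a rough illustration of what bucket-aware input distribution means, the sketch below groups a batch of raw IDs by destination rank; the real implementation lives in the referenced sharding code and operates on KeyedJaggedTensor inputs, so treat the helper below as a hypothetical stand-in:
+```python
+from collections import defaultdict
+from typing import Dict, List
+
+def bucketize_input(
+    global_ids: List[int],
+    total_buckets: int = 1000,
+    num_trainers: int = 100,
+) -> Dict[int, List[int]]:
+    # Group each input ID under the rank that owns its bucket, so the
+    # subsequent all-to-all can send every ID to the right trainer.
+    buckets_per_trainer = total_buckets // num_trainers
+    per_rank: Dict[int, List[int]] = defaultdict(list)
+    for gid in global_ids:
+        bucket = gid % total_buckets          # interleave-based mapping
+        rank = bucket // buckets_per_trainer  # sequential bucket ownership
+        per_rank[rank].append(gid)
+    return dict(per_rank)
+
+print(bucketize_input([7, 12, 995, 1007]))
+# {0: [7, 1007], 1: [12], 99: [995]}
+```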
+
+### FBGEMM Operator Optimizations for Collision-Free Embedding Tables
+
+FBGEMM\_GPU (the FBGEMM GPU kernels library) is highly optimized for fixed-size tables backed by contiguous memory, whether in HBM, UVM, or CPU memory.
+However, the collision-free approach breaks several of FBGEMM's assumptions:
+
+* Table size is not fixed. It could grow over training iterations or shrink after eviction.
+* The embedding lookup input is no longer an embedding offset, so we need to maintain an explicit mapping from input ID to embedding value.
+* The configured table size could exceed the memory limit, while the number of IDs actually trained is finite, so we cannot reserve memory based on the table configuration alone.
+
+We’re looking for an optimized key/value (K/V) FBGEMM backend that supports flexible memory management.
+
+#### Training Operator (from [Joe Wang](mailto:wangj@meta.com))
+
+* Optimized CPU memory management with K/V format
+ * Reduce memory fragmentation
+ * Efficient memory utilization
+ * Fast lookup performance
+ * Flexible eviction policy
+* Collision-free LXU cache to avoid an extra memory copy from CPU to UVM and a UVM memory read during embedding lookup.
+ * The current LXU cache used by FBGEMM can have ID collisions. When a collision happens, prefetch cannot load the embedding value into HBM, and the lookup falls back to a UVM memory read. This can hurt training QPS in two ways:
+ * It introduces one extra CPU memory copy from CPU to UVM, since the CPU embedding data in K/V format might not be accessible from the GPU.
+ * It introduces an extra H2D data copy during embedding lookup.
+
+[Here](https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py) is the code reference for a K/V-format SSD offloading operator, which provides a backend interface for hooking up other K/V store implementations.
+We propose implementing a new K/V store backend to decouple the SSD and RocksDB dependency; the SSD backend operator can still be used for extra-large embeddings that do not fit into host memory.
+
+#### Inference Operator
+
+On top of the training operator’s functionality, the inference operator needs to support dequantization of n-bit int values after embeddings are read from the embedding store. We’d like a fast inference operator with these additional requirements:
+
+* Optimized CPU memory management with K/V format
+* Collision-free LXU cache
+* Keep the fast n-bit int data format support, with pooling and dequantization features.
+* Support decoupled large-tensor loading and reset, to allow in-place model state updates.
+ [Here](https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py) is the current inference operator, which only supports offset-based access for now.
+
+### Enhancing TorchRec's Prefetch Pipeline for Synchronized Training
+
+TorchRec offers multiple training [pipelines](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/train_pipeline/train_pipelines.py) that help overlap communication and computation, reducing embedding lookup latency and improving training QPS. Specifically, PrefetchTrainPipelineSparseDist supports synchronized training, while TrainPipelineSemiSync supports asynchronous training.
+
+Our goal is to further enhance the prefetch pipeline to enable prefetching multiple training data batches ahead of time while maintaining synchronized training and avoiding step misalignment.
+
+**Design Principles:**
+
+* Zero-Collision Cache: Enable a zero-collision cache on the GPU to hold embeddings for multiple batches without collision-induced cache misses.
+* Forward Pass: Perform embedding lookups only in HBM during forward passes to improve performance.
+* Backward Pass: Update embeddings in HBM synchronously during backward passes to ensure all embedding lookup results are up-to-date.
+* Asynchronous UVM Embedding Update: Update UVM embeddings asynchronously after embedding updates in HBM.
+
+**End Goal:**
+
+Achieve on-par performance with GPU HBM-based training while scaling up sparse embedding tables in CPU memory.
+
+### Warm Start and Transfer Learning with Collision-Free Embedding Tables
+
+Warm start, or transfer learning, is a widely used technique in industry to facilitate model iteration while maintaining on-par topline metrics. However, the introduction of collision-free embedding tables poses challenges to this process.
+
+With the ability to increase the table size and feature hash size to \~2^60, collision-free embedding tables offer improved efficiency. However, since the ID hash size changes and the sharding solution differs, resuming training from a non-zero-collision table into a zero-collision table makes the redistribution of IDs across trainers computationally expensive.
+
+#### Solution: Backfilling Embedding Values
+
+To address this challenge, we propose the following solution:
+
+* Create Two Embedding Tables: One table is copied from the old model, and the other is the new table.
+
+* Freeze Old Embedding Table: The old embedding table is set to read-only mode in the new model.
+
+* Training Forward Loop: During the forward pass, if an embedding value is not found in the new table, the underlying embedding lookup operator searches the old embedding table for a pre-trained value.
+
+ * This requires an additional all-to-all call using TorchRec to retrieve the old embedding value.
+
+ * We need to leverage the prefetch process to hide the extra latency.
+
+* Stop Backfilling Process: Once the new table is sufficiently populated, the backfilling process can be stopped.
+
+This approach enables efficient warm start and transfer learning with collision-free embedding tables, reducing the computational overhead associated with ID redistribution.
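+A minimal sketch of the backfill lookup during the forward pass is shown below. `new_table` and `old_table` stand in for the new and frozen embedding stores; all names are hypothetical, and in the real system the fallback lookup would go through an additional TorchRec all-to-all and be overlapped with prefetch:
+```python
+from typing import Dict, List
+import torch
+
+def lookup_with_backfill(
+    ids: List[int],
+    new_table: Dict[int, torch.Tensor],
+    old_table: Dict[int, torch.Tensor],
+    embedding_dim: int,
+) -> torch.Tensor:
+    rows = []
+    for gid in ids:
+        if gid not in new_table:
+            if gid in old_table:
+                # Backfill: copy the pre-trained value into the new table so
+                # later iterations can skip the fallback path.
+                new_table[gid] = old_table[gid].clone()
+            else:
+                # Truly new ID: initialize fresh.
+                new_table[gid] = torch.zeros(embedding_dim)
+        rows.append(new_table[gid])
+    return torch.stack(rows)
+```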
diff --git a/setup.py b/setup.py
index 63d3aa30d..34987b905 100644
--- a/setup.py
+++ b/setup.py
@@ -7,97 +7,114 @@
import argparse
import os
-import random
-import re
+import subprocess
import sys
-from datetime import date
+from pathlib import Path
from typing import List
from setuptools import find_packages, setup
+ROOT_DIR = Path(__file__).parent.resolve()
-def get_version():
- # get version string from version.py
- # TODO: ideally the version.py should be generated when setup is run
- version_file = os.path.join(os.path.dirname(__file__), "version.py")
- version_regex = r"__version__ = ['\"]([^'\"]*)['\"]"
- with open(version_file, "r") as f:
- version = re.search(version_regex, f.read(), re.M).group(1)
- return version
+def _get_version():
+ try:
+ cmd = ["git", "rev-parse", "HEAD"]
+ sha = subprocess.check_output(cmd, cwd=str(ROOT_DIR)).decode("ascii").strip()
+ except Exception:
+ sha = None
-def get_nightly_version():
- today = date.today()
- return f"{today.year}.{today.month}.{today.day}"
+ if "BUILD_VERSION" in os.environ:
+ version = os.environ["BUILD_VERSION"]
+ else:
+ with open(os.path.join(ROOT_DIR, "version.txt"), "r") as f:
+ version = f.readline().strip()
+ if sha is not None and "OFFICIAL_RELEASE" not in os.environ:
+ version += "+" + sha[:7]
+
+ if sha is None:
+ sha = "Unknown"
+ return version, sha
+
+
+def _export_version(version, sha):
+ version_path = ROOT_DIR / "torchrec" / "version.py"
+ with open(version_path, "w") as fileobj:
+ fileobj.write("__version__ = '{}'\n".format(version))
+ fileobj.write("git_version = {}\n".format(repr(sha)))
def parse_args(argv: List[str]) -> argparse.Namespace:
parser = argparse.ArgumentParser(description="torchrec setup")
- parser.add_argument(
- "--package_name",
- type=str,
- default="torchrec",
- help="the name of this output wheel",
- )
return parser.parse_known_args(argv)
def main(argv: List[str]) -> None:
args, unknown = parse_args(argv)
- # Set up package name and version
- name = args.package_name
- is_nightly = "nightly" in name
- is_test = "test" in name
-
with open(
os.path.join(os.path.dirname(__file__), "README.MD"), encoding="utf8"
) as f:
readme = f.read()
with open(
- os.path.join(os.path.dirname(__file__), "requirements.txt"), encoding="utf8"
+ os.path.join(os.path.dirname(__file__), "install-requirements.txt"),
+ encoding="utf8",
) as f:
reqs = f.read()
install_requires = reqs.strip().split("\n")
- version = get_nightly_version() if is_nightly else get_version()
-
- if not is_nightly:
- if "fbgemm-gpu-nightly" in install_requires:
- install_requires.remove("fbgemm-gpu-nightly")
- install_requires.append("fbgemm-gpu")
-
- if is_test:
- version = (f"0.0.{random.randint(0, 1000)}",)
- print(f"-- {name} building version: {version}")
-
- packages = find_packages(exclude=("*tests",))
+ version, sha = _get_version()
+ _export_version(version, sha)
+
+ print(f"-- torchrec building version: {version}")
+
+ packages = find_packages(
+ exclude=(
+ "*tests",
+ "*test",
+ "examples",
+ "*examples.*",
+ "*benchmarks",
+ "*build",
+ "*rfc",
+ )
+ )
sys.argv = [sys.argv[0]] + unknown
setup(
# Metadata
- name=name,
+ name="torchrec",
version=version,
author="TorchRec Team",
author_email="packages@pytorch.org",
- description="Pytorch domain library for recommendation systems",
+ maintainer="PaulZhang12",
+ maintainer_email="paulzhan@meta.com",
+ description="TorchRec: Pytorch library for recommendation systems",
long_description=readme,
long_description_content_type="text/markdown",
url="https://github.com/pytorch/torchrec",
license="BSD-3",
- keywords=["pytorch", "recommendation systems", "sharding"],
- python_requires=">=3.7",
+ keywords=[
+ "pytorch",
+ "recommendation systems",
+ "sharding",
+ "distributed training",
+ ],
+ python_requires=">=3.9",
install_requires=install_requires,
packages=packages,
zip_safe=False,
# PyPI package information.
classifiers=[
- "Development Status :: 4 - Beta",
+ "Development Status :: 5 - Stable",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)
diff --git a/test/cpp/dynamic_embedding/CMakeLists.txt b/test/cpp/dynamic_embedding/CMakeLists.txt
index 056bfd985..53d2c7c02 100644
--- a/test/cpp/dynamic_embedding/CMakeLists.txt
+++ b/test/cpp/dynamic_embedding/CMakeLists.txt
@@ -14,3 +14,8 @@ add_tde_test(bits_op_test bits_op_test.cpp)
add_tde_test(naive_id_transformer_test naive_id_transformer_test.cpp)
add_tde_test(random_bits_generator_test random_bits_generator_test.cpp)
add_tde_test(mixed_lfu_lru_strategy_test mixed_lfu_lru_strategy_test.cpp)
+add_tde_test(notification_test notification_test.cpp)
+
+if (BUILD_REDIS_IO)
+ add_subdirectory(redis)
+endif()
diff --git a/test/cpp/dynamic_embedding/mixed_lfu_lru_strategy_test.cpp b/test/cpp/dynamic_embedding/mixed_lfu_lru_strategy_test.cpp
index f8526e575..456d07e1e 100644
--- a/test/cpp/dynamic_embedding/mixed_lfu_lru_strategy_test.cpp
+++ b/test/cpp/dynamic_embedding/mixed_lfu_lru_strategy_test.cpp
@@ -12,8 +12,8 @@
namespace torchrec {
TEST(TDE, order) {
MixedLFULRUStrategy::Record a;
- a.time_ = 1;
- a.freq_power_ = 31;
+ a.time = 1;
+ a.freq_power = 31;
uint32_t i32 = a.ToUint32();
ASSERT_EQ(0xF8000001, i32);
}
@@ -23,42 +23,41 @@ TEST(TDE, MixedLFULRUStrategy_Evict) {
{
records.emplace_back();
records.back().first = 1;
- records.back().second.time_ = 100;
- records.back().second.freq_power_ = 2;
+ records.back().second.time = 100;
+ records.back().second.freq_power = 2;
}
{
records.emplace_back();
records.back().first = 2;
- records.back().second.time_ = 10;
- records.back().second.freq_power_ = 2;
+ records.back().second.time = 10;
+ records.back().second.freq_power = 2;
}
{
records.emplace_back();
records.back().first = 3;
- records.back().second.time_ = 100;
- records.back().second.freq_power_ = 1;
+ records.back().second.time = 100;
+ records.back().second.freq_power = 1;
}
{
records.emplace_back();
records.back().first = 4;
- records.back().second.time_ = 150;
- records.back().second.freq_power_ = 2;
+ records.back().second.time = 150;
+ records.back().second.freq_power = 2;
}
size_t offset_{0};
- auto ids = MixedLFULRUStrategy::evict(
- [&offset_,
- &records]() -> std::optional {
+ MixedLFULRUStrategy strategy;
+ auto ids = strategy.evict(
+ [&offset_, &records]() -> std::optional<record_t> {
if (offset_ == records.size()) {
return std::nullopt;
}
auto record = records[offset_++];
- MixedLFULRUStrategy::lxu_record_t ext_type =
- *reinterpret_cast(
- &record.second);
- return MixedLFULRUStrategy::transformer_record_t{
- .global_id_ = record.first,
- .cache_id_ = 0,
- .lxu_record_ = ext_type,
+ lxu_record_t ext_type =
+ *reinterpret_cast<lxu_record_t*>(&record.second);
+ return record_t{
+ .global_id = record.first,
+ .cache_id = 0,
+ .lxu_record = ext_type,
};
},
3);
@@ -73,12 +72,12 @@ TEST(TDE, MixedLFULRUStrategy_Transform) {
constexpr static size_t n_iter = 1000000;
MixedLFULRUStrategy strategy;
strategy.update_time(10);
- MixedLFULRUStrategy::lxu_record_t val;
+ lxu_record_t val;
{
val = strategy.update(0, 0, std::nullopt);
auto record = reinterpret_cast(&val);
- ASSERT_EQ(record->freq_power_, 5);
- ASSERT_EQ(record->time_, 10);
+ ASSERT_EQ(record->freq_power, 5);
+ ASSERT_EQ(record->time, 10);
}
uint32_t freq_power_5_cnt = 0;
@@ -87,13 +86,13 @@ TEST(TDE, MixedLFULRUStrategy_Transform) {
for (size_t i = 0; i < n_iter; ++i) {
auto tmp = strategy.update(0, 0, val);
auto record = reinterpret_cast(&tmp);
- ASSERT_EQ(record->time_, 10);
- if (record->freq_power_ == 5) {
+ ASSERT_EQ(record->time, 10);
+ if (record->freq_power == 5) {
++freq_power_5_cnt;
- } else if (record->freq_power_ == 6) {
+ } else if (record->freq_power == 6) {
++freq_power_6_cnt;
} else {
- ASSERT_TRUE(record->freq_power_ == 5 || record->freq_power_ == 6);
+ ASSERT_TRUE(record->freq_power == 5 || record->freq_power == 6);
}
}
diff --git a/test/cpp/dynamic_embedding/naive_id_transformer_test.cpp b/test/cpp/dynamic_embedding/naive_id_transformer_test.cpp
index a037e7b89..f8ba6e82a 100644
--- a/test/cpp/dynamic_embedding/naive_id_transformer_test.cpp
+++ b/test/cpp/dynamic_embedding/naive_id_transformer_test.cpp
@@ -12,8 +12,7 @@
namespace torchrec {
TEST(tde, NaiveThreadedIDTransformer_NoFilter) {
- using Tag = int32_t;
- NaiveIDTransformer> transformer(16);
+ NaiveIDTransformer> transformer(16);
const int64_t global_ids[5] = {100, 101, 100, 102, 101};
int64_t cache_ids[5];
int64_t expected_cache_ids[5] = {0, 1, 0, 2, 1};
@@ -24,8 +23,7 @@ TEST(tde, NaiveThreadedIDTransformer_NoFilter) {
}
TEST(tde, NaiveThreadedIDTransformer_Full) {
- using Tag = int32_t;
- NaiveIDTransformer> transformer(4);
+ NaiveIDTransformer> transformer(4);
const int64_t global_ids[5] = {100, 101, 102, 103, 104};
int64_t cache_ids[5];
int64_t expected_cache_ids[5] = {0, 1, 2, 3, -1};
@@ -37,8 +35,7 @@ TEST(tde, NaiveThreadedIDTransformer_Full) {
}
TEST(tde, NaiveThreadedIDTransformer_Evict) {
- using Tag = int32_t;
- NaiveIDTransformer> transformer(4);
+ NaiveIDTransformer> transformer(4);
const int64_t global_ids[5] = {100, 101, 102, 103, 104};
int64_t cache_ids[5];
@@ -60,8 +57,7 @@ TEST(tde, NaiveThreadedIDTransformer_Evict) {
}
TEST(tde, NaiveThreadedIDTransformer_Iterator) {
- using Tag = int32_t;
- NaiveIDTransformer> transformer(16);
+ NaiveIDTransformer> transformer(16);
const int64_t global_ids[5] = {100, 101, 100, 102, 101};
int64_t cache_ids[5];
int64_t expected_cache_ids[5] = {3, 4, 3, 5, 4};
diff --git a/test/cpp/dynamic_embedding/notification_test.cpp b/test/cpp/dynamic_embedding/notification_test.cpp
new file mode 100644
index 000000000..f916678d8
--- /dev/null
+++ b/test/cpp/dynamic_embedding/notification_test.cpp
@@ -0,0 +1,20 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+
+namespace torchrec {
+TEST(TDE, notification) {
+ Notification notification;
+ std::thread th([&] { notification.done(); });
+ notification.wait();
+ th.join();
+}
+} // namespace torchrec
diff --git a/test/cpp/dynamic_embedding/redis/CMakeLists.txt b/test/cpp/dynamic_embedding/redis/CMakeLists.txt
new file mode 100644
index 000000000..b1c2f0c7b
--- /dev/null
+++ b/test/cpp/dynamic_embedding/redis/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+function(add_redis_test NAME)
+ add_executable(${NAME} ${ARGN})
+ target_link_libraries(${NAME} redis_io tde_cpp_objs gtest gtest_main)
+ add_test(NAME ${NAME} COMMAND ${NAME})
+endfunction()
+
+# TODO: Need to start an empty redis-server
+# on 127.0.0.1:6379 before running *redis*_test.
+add_redis_test(redis_io_test redis_io_test.cpp)
+add_redis_test(url_test url_test.cpp)
diff --git a/test/cpp/dynamic_embedding/redis/redis_io_test.cpp b/test/cpp/dynamic_embedding/redis/redis_io_test.cpp
new file mode 100644
index 000000000..f5c77c257
--- /dev/null
+++ b/test/cpp/dynamic_embedding/redis/redis_io_test.cpp
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+
+namespace torchrec::redis {
+
+TEST(TDE, redis_Option) {
+ auto opt = parse_option(
+ "192.168.3.1:3948/?db=3&&num_threads=2&&timeout=3s&&chunk_size=3000");
+ ASSERT_EQ(opt.host, "192.168.3.1");
+ ASSERT_EQ(opt.port, 3948);
+ ASSERT_EQ(opt.db, 3);
+ ASSERT_EQ(opt.num_io_threads, 2);
+ ASSERT_EQ(opt.chunk_size, 3000);
+ ASSERT_EQ(opt.timeout_ms, 3000);
+ ASSERT_TRUE(opt.prefix.empty());
+}
+
+TEST(TDE, redis_Option_ParseError) {
+ ASSERT_ANY_THROW(
+ parse_option("192.168.3.1:3948/?db=3&&no_opt=3000&&num_threads=2"));
+ ASSERT_ANY_THROW(parse_option("192.168.3.1:3948/?timeout=3d"));
+}
+
+struct FetchContext {
+ Notification* notification_;
+ std::function on_data_;
+};
+
+TEST(TDE, redis_push_fetch) {
+ auto opt = parse_option("127.0.0.1:6379");
+ Redis redis(opt);
+
+ constexpr static int64_t global_ids[] = {1, 3, 4};
+ constexpr static uint32_t os_ids[] = {0};
+ constexpr static float params[] = {1, 2, 3, 4, 5, 9, 8, 1};
+ constexpr static uint64_t offsets[] = {
+ 0 * sizeof(float),
+ 2 * sizeof(float),
+ 4 * sizeof(float),
+ 6 * sizeof(float),
+ 8 * sizeof(float)};
+
+ Notification notification;
+
+ IOPushParameter push{
+ .table_name = "table",
+ .num_global_ids = sizeof(global_ids) / sizeof(global_ids[0]),
+ .global_ids = global_ids,
+ .num_optimizer_states = sizeof(os_ids) / sizeof(os_ids[0]),
+ .optimizer_state_ids = os_ids,
+ .num_offsets = sizeof(offsets) / sizeof(offsets[0]),
+ .offsets = offsets,
+ .data = params,
+ .on_complete_context = ¬ification,
+ .on_push_complete =
+ +[](void* ctx) {
+ auto* notification = reinterpret_cast<Notification*>(ctx);
+ notification->done();
+ },
+ };
+ redis.push(push);
+
+ notification.wait();
+
+ notification.clear();
+
+ FetchContext ctx{
+ .notification_ = ¬ification,
+ .on_data_ =
+ [&](uint32_t offset, uint32_t os_id, void* data, uint32_t len) {
+ ASSERT_EQ(os_id, 0);
+ uint32_t param_len = 2;
+ ASSERT_EQ(len, sizeof(float) * param_len);
+ auto actual =
+ std::span<const float>(reinterpret_cast<const float*>(data), 2);
+
+ auto expect = std::span<const float>(
+ reinterpret_cast<const float*>(&params[offset * param_len]), 2);
+
+ ASSERT_EQ(expect[0], actual[0]);
+ ASSERT_EQ(expect[1], actual[1]);
+ }};
+
+ IOFetchParameter fetch{
+ .table_name = "table",
+ .num_global_ids = sizeof(global_ids) / sizeof(global_ids[0]),
+ .global_ids = global_ids,
+ .num_optimizer_states = sizeof(os_ids) / sizeof(os_ids[0]),
+ .on_complete_context = &ctx,
+ .on_global_id_fetched =
+ +[](void* ctx,
+ uint32_t offset,
+ uint32_t os_id,
+ void* data,
+ uint32_t len) {
+ auto c = reinterpret_cast<FetchContext*>(ctx);
+ c->on_data_(offset, os_id, data, len);
+ },
+ .on_all_fetched =
+ +[](void* ctx) {
+ auto c = reinterpret_cast<FetchContext*>(ctx);
+ c->notification_->done();
+ }};
+ redis.fetch(fetch);
+ notification.wait();
+}
+} // namespace torchrec::redis
diff --git a/test/cpp/dynamic_embedding/redis/url_test.cpp b/test/cpp/dynamic_embedding/redis/url_test.cpp
new file mode 100644
index 000000000..6f37744c7
--- /dev/null
+++ b/test/cpp/dynamic_embedding/redis/url_test.cpp
@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+
+namespace torchrec::url_parser::rules {
+
+TEST(TDE, url) {
+ auto url = parse_url("http://github.com/www.qq.com/?a=b&&c=d");
+ ASSERT_EQ(url.host, "www.qq.com");
+ ASSERT_TRUE(url.param.has_value());
+ ASSERT_EQ("a=b&&c=d", url.param.value());
+}
+
+} // namespace torchrec::url_parser::rules
diff --git a/test_installation.py b/test_installation.py
index cbef42766..48328a389 100644
--- a/test_installation.py
+++ b/test_installation.py
@@ -18,6 +18,7 @@
from torchrec.models.dlrm import DLRM
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
if sys.platform not in ["linux", "linux2"]:
raise EnvironmentError(
@@ -118,7 +119,7 @@ def main(argv: List[str]) -> None:
device=device,
)
optimizer = KeyedOptimizerWrapper(
- dict(model.named_parameters()),
+ dict(in_backward_optimizer_filter(model.named_parameters())),
lambda params: torch.optim.SGD(params, lr=0.01),
)
diff --git a/tools/lint/black_linter.py b/tools/lint/black_linter.py
index cfdc3d4e8..7c9a75f9c 100644
--- a/tools/lint/black_linter.py
+++ b/tools/lint/black_linter.py
@@ -176,11 +176,11 @@ def main() -> None:
logging.basicConfig(
format="<%(threadName)s:%(levelname)s> %(message)s",
- level=logging.NOTSET
- if args.verbose
- else logging.DEBUG
- if len(args.filenames) < 1000
- else logging.INFO,
+ level=(
+ logging.NOTSET
+ if args.verbose
+ else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO
+ ),
stream=sys.stderr,
)
diff --git a/torchrec/__init__.py b/torchrec/__init__.py
index ddead68c9..29258f6f0 100644
--- a/torchrec/__init__.py
+++ b/torchrec/__init__.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import torchrec.distributed # noqa
import torchrec.quant # noqa
from torchrec.fx import tracer # noqa
@@ -25,3 +27,10 @@
KeyedTensor,
)
from torchrec.streamable import Multistreamable, Pipelineable # noqa
+
+try:
+ # pyre-ignore[21]
+ # @manual=//torchrec/fb:version
+ from .version import __version__, github_version # noqa
+except ImportError:
+ pass
diff --git a/torchrec/csrc/dynamic_embedding/CMakeLists.txt b/torchrec/csrc/dynamic_embedding/CMakeLists.txt
index ca6fa20b5..f3bd73ad5 100644
--- a/torchrec/csrc/dynamic_embedding/CMakeLists.txt
+++ b/torchrec/csrc/dynamic_embedding/CMakeLists.txt
@@ -6,12 +6,19 @@
add_library(tde_cpp_objs
OBJECT
+ bind.cpp
+ id_transformer_wrapper.cpp
+ ps.cpp
details/clz_impl.cpp
details/ctz_impl.cpp
details/random_bits_generator.cpp
- details/mixed_lfu_lru_strategy.cpp
details/io_registry.cpp
- details/io.cpp)
+ details/io.cpp
+ details/notification.cpp)
+
+if (BUILD_REDIS_IO)
+ add_subdirectory(details/redis)
+endif()
target_include_directories(tde_cpp_objs PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../../)
target_include_directories(tde_cpp_objs PUBLIC ${TORCH_INCLUDE_DIRS})
diff --git a/torchrec/csrc/dynamic_embedding/bind.cpp b/torchrec/csrc/dynamic_embedding/bind.cpp
new file mode 100644
index 000000000..e2faa9bbd
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/bind.cpp
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+#include
+#include
+#include
+
+namespace torchrec {
+TORCH_LIBRARY(tde, m) {
+ m.def("register_io", [](const std::string& name) {
+ IORegistry::Instance().register_plugin(name.c_str());
+ });
+
+ m.class_("TransformResult")
+ .def_readonly("success", &TransformResult::success)
+ .def_readonly("ids_to_fetch", &TransformResult::ids_to_fetch);
+
+ m.class_("IDTransformer")
+ .def(torch::init())
+ .def("transform", &IDTransformerWrapper::transform)
+ .def("evict", &IDTransformerWrapper::evict)
+ .def("save", &IDTransformerWrapper::save);
+
+ m.class_("LocalShardList")
+ .def(torch::init([]() { return c10::make_intrusive(); }))
+ .def("append", &LocalShardList::emplace_back);
+
+ m.class_("FetchHandle").def("wait", &FetchHandle::wait);
+
+ m.class_("PS")
+ .def(torch::init<
+ std::string,
+ c10::intrusive_ptr,
+ int64_t,
+ int64_t,
+ std::string,
+ int64_t>())
+ .def("fetch", &PS::fetch)
+ .def("evict", &PS::evict);
+}
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/id_transformer.h b/torchrec/csrc/dynamic_embedding/details/id_transformer.h
new file mode 100644
index 000000000..3c826c840
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/details/id_transformer.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include
+#include
+#include
+#include
+
+namespace torchrec {
+
+namespace transform_default {
+inline lxu_record_t no_update(
+ int64_t global_id,
+ int64_t cache_id,
+ std::optional<lxu_record_t> record) {
+ return record.value_or(lxu_record_t{});
+};
+
+inline void no_fetch(int64_t global_id, int64_t cache_id) {}
+} // namespace transform_default
+
+class IDTransformer {
+ public:
+ /**
+ * Transform global ids to cache ids
+ *
+ * @param global_ids Global ID vector
+ * @param cache_ids [out] Cache ID vector
+ * @param update Lambda that updates the eviction strategy's LXU record for
+ * each global-id/cache-id pair.
+ * @param fetch Lambda invoked for global-id/cache-id pairs that do not exist
+ * yet; used by the dynamic embedding parameter server.
+ * @return true if all IDs were transformed; false if eviction is needed.
+ */
+ virtual bool transform(
+ std::span<const int64_t> global_ids,
+ std::span<int64_t> cache_ids,
+ update_t update = transform_default::no_update,
+ fetch_t fetch = transform_default::no_fetch) = 0;
+
+ /**
+ * Evict global ids from the transformer
+ *
+ * @param global_ids Global IDs to evict.
+ */
+ virtual void evict(std::span<const int64_t> global_ids) = 0;
+
+ /**
+ * Create an iterator over the ID transformer. A possible use case is:
+ *
+ * auto iterator = transformer.iterator();
+ * auto record = iterator();
+ * while (record.has_value()) {
+ * // do sth with the record
+ * // ...
+ * // get next record
+ * record = iterator();
+ * }
+ *
+ * @return the iterator created.
+ */
+ virtual iterator_t iterator() const = 0;
+};
+
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/io.cpp b/torchrec/csrc/dynamic_embedding/details/io.cpp
index cf9556d7e..1c205254b 100644
--- a/torchrec/csrc/dynamic_embedding/details/io.cpp
+++ b/torchrec/csrc/dynamic_embedding/details/io.cpp
@@ -8,7 +8,7 @@
#include
-namespace tde::details {
+namespace torchrec {
static constexpr std::string_view k_schema_separator = "://";
@@ -77,7 +77,7 @@ static void on_all_fetched(void* ctx) {
delete c;
}
-void IO::pull(
+void IO::fetch(
const std::string& table_name,
std::span global_ids,
std::span col_ids,
@@ -93,7 +93,7 @@ void IO::pull(
ctx->tensors.resize(
global_ids.size() * std::max(col_ids.size(), static_cast(1)));
- IOPullParameter param{
+ IOFetchParameter param{
.table_name = table_name.c_str(),
.num_cols = static_cast(col_ids.size()),
.num_global_ids = static_cast(global_ids.size()),
@@ -104,7 +104,7 @@ void IO::pull(
.on_all_fetched = on_all_fetched,
};
param.on_complete_context = ctx.release();
- provider_.pull(instance_, param);
+ provider_.fetch(instance_, param);
}
struct PushContext {
@@ -135,7 +135,7 @@ void IO::push(
.col_ids = col_ids.data(),
.global_ids = global_ids.data(),
.num_optimizer_states = static_cast(os_ids.size()),
- .optimizer_stats_ids = os_ids.data(),
+ .optimizer_state_ids = os_ids.data(),
.num_offsets = static_cast(offsets.size()),
.offsets = offsets.data(),
.data = data.data(),
@@ -145,4 +145,4 @@ void IO::push(
provider_.push(instance_, param);
}
-} // namespace tde::details
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/io.h b/torchrec/csrc/dynamic_embedding/details/io.h
index ae2d19195..2f4ec64ac 100644
--- a/torchrec/csrc/dynamic_embedding/details/io.h
+++ b/torchrec/csrc/dynamic_embedding/details/io.h
@@ -12,7 +12,7 @@
#include
#include
-namespace tde::details {
+namespace torchrec {
class IO {
public:
@@ -44,7 +44,7 @@ class IO {
* copied inside, so it is safe to free `col_ids`/`global_ids` before
* `on_fetch_complete`.
*/
- void pull(
+ void fetch(
const std::string& table_name,
std::span global_ids,
std::span col_ids,
@@ -89,4 +89,4 @@ class IO {
void* instance_{};
};
-} // namespace tde::details
+} // namespace torchrec
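IO::fetch above follows a callback-plus-context pattern: the caller allocates a context, hands the provider plain function pointers together with that context, and the provider invokes the per-row callback once per global id and the completion callback exactly once at the end. A simplified, self-contained sketch of the pattern; the FetchContext/on_row/on_all names and the float-row payload are illustrative assumptions, not the real signatures.

#include <cstdint>
#include <future>
#include <iostream>
#include <vector>

struct FetchContext {
  std::vector<std::vector<float>> rows;
  std::promise<void> done;
};

// Per-row callback: copy the provider's buffer into the caller-owned context.
void on_row(void* ctx, uint32_t offset, const float* data, uint32_t len) {
  static_cast<FetchContext*>(ctx)->rows[offset].assign(data, data + len);
}

// Completion callback: fires exactly once after the last row.
void on_all(void* ctx) {
  static_cast<FetchContext*>(ctx)->done.set_value();
}

int main() {
  FetchContext ctx;
  ctx.rows.resize(2);
  auto done = ctx.done.get_future();

  // A real provider would invoke these from its own IO threads.
  std::vector<float> payload{1.f, 2.f, 3.f};
  on_row(&ctx, 0, payload.data(), static_cast<uint32_t>(payload.size()));
  on_row(&ctx, 1, payload.data(), static_cast<uint32_t>(payload.size()));
  on_all(&ctx);

  done.wait();
  std::cout << "fetched " << ctx.rows.size() << " rows\n";
}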
diff --git a/torchrec/csrc/dynamic_embedding/details/io_parameter.h b/torchrec/csrc/dynamic_embedding/details/io_parameter.h
new file mode 100644
index 000000000..ab92db9cd
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/details/io_parameter.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+namespace torchrec {
+
+using GlobalIDFetchCallback = void (*)(
+ void* ctx,
+ uint32_t offset,
+ uint32_t optimizer_state,
+ void* data,
+ uint32_t data_len);
+
+struct IOFetchParameter {
+ const char* table_name;
+ uint32_t num_cols;
+ uint32_t num_global_ids;
+ const int64_t* col_ids;
+ const int64_t* global_ids;
+ uint32_t num_optimizer_states;
+ void* on_complete_context;
+ GlobalIDFetchCallback on_global_id_fetched;
+ void (*on_all_fetched)(void* ctx);
+};
+
+struct IOPushParameter {
+ const char* table_name;
+ uint32_t num_cols;
+ uint32_t num_global_ids;
+ const int64_t* col_ids;
+ const int64_t* global_ids;
+ uint32_t num_optimizer_states;
+ const uint32_t* optimizer_state_ids;
+ uint32_t num_offsets;
+ const uint64_t* offsets;
+ const void* data;
+ void* on_complete_context;
+ void (*on_push_complete)(void* ctx);
+};
+
+} // namespace torchrec
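These structs define the C ABI an IO plugin implements. A minimal in-memory plugin might look like the sketch below; the exported IO_Fetch/IO_Push symbol names match what io_registry.cpp resolves with dlsym, while the IO_Initialize/IO_Finalize names, the include path, and the one-offset-range-per-id layout in IO_Push are assumptions made for illustration.

#include <torchrec/csrc/dynamic_embedding/details/io_parameter.h>

#include <cstdint>
#include <map>
#include <vector>

namespace {
// Trivial in-process "storage" keyed by global id.
struct NullStore {
  std::map<int64_t, std::vector<char>> rows;
};
}  // namespace

extern "C" {

void* IO_Initialize(const char* /*cfg*/) {
  return new NullStore();
}

void IO_Finalize(void* instance) {
  delete static_cast<NullStore*>(instance);
}

void IO_Fetch(void* instance, torchrec::IOFetchParameter param) {
  auto* store = static_cast<NullStore*>(instance);
  for (uint32_t i = 0; i < param.num_global_ids; ++i) {
    auto it = store->rows.find(param.global_ids[i]);
    if (it == store->rows.end()) {
      // Unknown id: report an empty row so the caller can initialize it.
      param.on_global_id_fetched(
          param.on_complete_context, i, /*optimizer_state=*/0, nullptr, 0);
    } else {
      param.on_global_id_fetched(
          param.on_complete_context, i, 0, it->second.data(),
          static_cast<uint32_t>(it->second.size()));
    }
  }
  param.on_all_fetched(param.on_complete_context);
}

void IO_Push(void* instance, torchrec::IOPushParameter param) {
  auto* store = static_cast<NullStore*>(instance);
  const char* data = static_cast<const char*>(param.data);
  // Assumes one offset range per global id and a single optimizer state.
  for (uint32_t i = 0; i < param.num_global_ids && i + 1 < param.num_offsets;
       ++i) {
    store->rows[param.global_ids[i]] = std::vector<char>(
        data + param.offsets[i], data + param.offsets[i + 1]);
  }
}

}  // extern "C"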
diff --git a/torchrec/csrc/dynamic_embedding/details/io_registry.cpp b/torchrec/csrc/dynamic_embedding/details/io_registry.cpp
index 614da6d2d..b3744dedf 100644
--- a/torchrec/csrc/dynamic_embedding/details/io_registry.cpp
+++ b/torchrec/csrc/dynamic_embedding/details/io_registry.cpp
@@ -10,7 +10,7 @@
#include
#include
-namespace tde::details {
+namespace torchrec {
void IORegistry::register_provider(IOProvider provider) {
std::string type = provider.type;
@@ -41,9 +41,9 @@ void IORegistry::register_plugin(const char* filename) {
provider.finalize =
reinterpret_cast(finalize_ptr);
- auto pull_ptr = dlsym(ptr.get(), "IO_Pull");
- TORCH_CHECK(pull_ptr != nullptr, "cannot find IO_Pull symbol");
- provider.pull = reinterpret_cast(pull_ptr);
+ auto fetch_ptr = dlsym(ptr.get(), "IO_Fetch");
+ TORCH_CHECK(fetch_ptr != nullptr, "cannot find IO_Fetch symbol");
+ provider.fetch = reinterpret_cast(fetch_ptr);
auto push_ptr = dlsym(ptr.get(), "IO_Push");
TORCH_CHECK(push_ptr != nullptr, "cannot find IO_Push symbol");
@@ -72,4 +72,4 @@ IORegistry& IORegistry::Instance() {
return instance;
}
-} // namespace tde::details
+} // namespace torchrec
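register_plugin above resolves the provider entry points from a shared object at runtime. A condensed sketch of that dlopen/dlsym idiom, simplified to a single IO_Fetch symbol and a generic function-pointer type standing in for the real signature:

#include <dlfcn.h>

#include <memory>
#include <stdexcept>
#include <string>

// Generic stand-in for the provider function signature; the real registry
// casts to a pointer taking IOFetchParameter by value.
using FetchFn = void (*)(void* instance, void* param);

struct DlCloser {
  void operator()(void* handle) const {
    if (handle != nullptr) {
      dlclose(handle);
    }
  }
};

std::unique_ptr<void, DlCloser> load_plugin(const std::string& path,
                                            FetchFn& fetch_out) {
  std::unique_ptr<void, DlCloser> handle(dlopen(path.c_str(), RTLD_NOW));
  if (!handle) {
    throw std::runtime_error(std::string("dlopen failed: ") + dlerror());
  }
  void* sym = dlsym(handle.get(), "IO_Fetch");
  if (sym == nullptr) {
    throw std::runtime_error("cannot find IO_Fetch symbol");
  }
  fetch_out = reinterpret_cast<FetchFn>(sym);
  return handle;  // the handle must outlive any call through fetch_out
}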
diff --git a/torchrec/csrc/dynamic_embedding/details/io_registry.h b/torchrec/csrc/dynamic_embedding/details/io_registry.h
index af727e84a..3feb39f4a 100644
--- a/torchrec/csrc/dynamic_embedding/details/io_registry.h
+++ b/torchrec/csrc/dynamic_embedding/details/io_registry.h
@@ -10,48 +10,17 @@
#include
#include
#include
+#include
#include
#include
#include
-namespace tde::details {
-
-struct IOPullParameter {
- const char* table_name;
- uint32_t num_cols;
- uint32_t num_global_ids;
- const int64_t* col_ids;
- const int64_t* global_ids;
- uint32_t num_optimizer_states;
- void* on_complete_context;
- void (*on_global_id_fetched)(
- void* ctx,
- uint32_t offset,
- uint32_t optimizer_state,
- void* data,
- uint32_t data_len);
- void (*on_all_fetched)(void* ctx);
-};
-
-struct IOPushParameter {
- const char* table_name;
- uint32_t num_cols;
- uint32_t num_global_ids;
- const int64_t* col_ids;
- const int64_t* global_ids;
- uint32_t num_optimizer_states;
- const uint32_t* optimizer_stats_ids;
- uint32_t num_offsets;
- const uint64_t* offsets;
- const void* data;
- void* on_complete_context;
- void (*on_push_complete)(void* ctx);
-};
+namespace torchrec {
struct IOProvider {
const char* type;
void* (*initialize)(const char* cfg);
- void (*pull)(void* instance, IOPullParameter cfg);
+ void (*fetch)(void* instance, IOFetchParameter cfg);
void (*push)(void* instance, IOPushParameter cfg);
void (*finalize)(void*);
};
@@ -75,4 +44,4 @@ class IORegistry {
std::vector dls_;
};
-} // namespace tde::details
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/lxu_strategy.h b/torchrec/csrc/dynamic_embedding/details/lxu_strategy.h
new file mode 100644
index 000000000..e18107770
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/details/lxu_strategy.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include
+#include
+
+namespace torchrec {
+
+class LXUStrategy {
+ public:
+ LXUStrategy() = default;
+ LXUStrategy(const LXUStrategy&) = delete;
+ LXUStrategy(LXUStrategy&& o) noexcept = default;
+
+ virtual void update_time(uint32_t time) = 0;
+ virtual int64_t time(lxu_record_t record) = 0;
+
+ virtual lxu_record_t update(
+ int64_t global_id,
+ int64_t cache_id,
+ std::optional val) = 0;
+
+ /**
+ * Analyze all ids and return the ids that most need to be evicted.
+ * @param iterator Returns each global_id to ExtValue pair. Returns nullopt
+ * when at the end.
+ * @param num_to_evict Maximum number of ids to return.
+ * @return The global ids selected for eviction.
+ */
+ virtual std::vector evict(
+ iterator_t iterator,
+ uint64_t num_to_evict) = 0;
+};
+
+} // namespace torchrec
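LXUStrategy separates the per-access bookkeeping (update of a per-id record) from the eviction decision (evict over an iterator of all records). A self-contained sketch of a pure-LRU strategy with the same shape, using local stand-ins for lxu_record_t and iterator_t:

#include <algorithm>
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

using Record = uint32_t;  // stand-in for lxu_record_t

struct IdRecord {
  int64_t global_id;
  Record record;
};
using Iterator = std::function<std::optional<IdRecord>()>;

class PureLRU {
 public:
  void update_time(uint32_t time) { time_ = time; }

  Record update(int64_t /*gid*/, int64_t /*cid*/, std::optional<Record>) {
    return time_;  // LRU: the record is simply the last access time
  }

  std::vector<int64_t> evict(Iterator it, uint64_t num_to_evict) {
    std::vector<IdRecord> all;
    for (auto rec = it(); rec.has_value(); rec = it()) {
      all.push_back(*rec);
    }
    std::sort(all.begin(), all.end(),
              [](const IdRecord& a, const IdRecord& b) {
                return a.record < b.record;  // oldest first
              });
    std::vector<int64_t> victims;
    for (uint64_t i = 0; i < num_to_evict && i < all.size(); ++i) {
      victims.push_back(all[i].global_id);
    }
    return victims;
  }

 private:
  uint32_t time_{0};
};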
diff --git a/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.cpp b/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.cpp
deleted file mode 100644
index f6c486726..000000000
--- a/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.cpp
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include
-#include
-#include
-
-namespace torchrec {
-MixedLFULRUStrategy::MixedLFULRUStrategy(uint16_t min_used_freq_power)
- : min_lfu_power_(min_used_freq_power), time_(new std::atomic()) {}
-
-void MixedLFULRUStrategy::update_time(uint32_t time) {
- time_->store(time);
-}
-
-MixedLFULRUStrategy::lxu_record_t MixedLFULRUStrategy::update(
- int64_t global_id,
- int64_t cache_id,
- std::optional val) {
- Record r{};
- r.time_ = time_->load();
-
- if (C10_UNLIKELY(!val.has_value())) {
- r.freq_power_ = min_lfu_power_;
- } else {
- auto freq_power = reinterpret_cast(&val.value())->freq_power_;
- bool should_carry = generator_.is_next_n_bits_all_zero(freq_power);
- if (should_carry) {
- ++freq_power;
- }
- r.freq_power_ = freq_power;
- }
- return *reinterpret_cast(&r);
-}
-
-} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.h b/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.h
index 85f7bbd75..beb015c18 100644
--- a/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.h
+++ b/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.h
@@ -7,7 +7,8 @@
*/
#pragma once
-#include
+#include
+#include
#include
#include
#include
@@ -37,49 +38,56 @@ namespace torchrec {
*
* Use `update` to update extended value when every time global id that used.
*/
-class MixedLFULRUStrategy {
+class MixedLFULRUStrategy : public LXUStrategy {
public:
- using lxu_record_t = uint32_t;
- using transformer_record_t = TransformerRecord;
-
- static constexpr std::string_view type_ = "mixed_lru_lfu";
-
/**
* @param min_used_freq_power min usage is 2^min_used_freq_power. Set this to
* avoid recent values evict too fast.
*/
- explicit MixedLFULRUStrategy(uint16_t min_used_freq_power = 5);
+ explicit MixedLFULRUStrategy(uint16_t min_used_freq_power = 5)
+ : min_lfu_power_(min_used_freq_power),
+ time_(new std::atomic()) {}
MixedLFULRUStrategy(const MixedLFULRUStrategy&) = delete;
MixedLFULRUStrategy(MixedLFULRUStrategy&& o) noexcept = default;
- void update_time(uint32_t time);
- template
- static int64_t time(T record) {
- static_assert(sizeof(T) == sizeof(Record));
- return static_cast(reinterpret_cast(&record)->time_);
+ void update_time(lxu_record_t time) override {
+ time_->store(time);
+ }
+
+ int64_t time(lxu_record_t record) override {
+ return static_cast(reinterpret_cast(&record)->time);
}
- lxu_record_t
- update(int64_t global_id, int64_t cache_id, std::optional val);
+ lxu_record_t update(
+ int64_t global_id,
+ int64_t cache_id,
+ std::optional val) override {
+ Record r{};
+ r.time = time_->load();
+
+ if (!val.has_value()) [[unlikely]] {
+ r.freq_power = min_lfu_power_;
+ } else {
+ auto freq_power = reinterpret_cast(&val.value())->freq_power;
+ bool should_carry = generator_.is_next_n_bits_all_zero(freq_power);
+ if (should_carry) {
+ ++freq_power;
+ }
+ r.freq_power = freq_power;
+ }
+ return *reinterpret_cast(&r);
+ }
struct EvictItem {
- int64_t global_id_;
- lxu_record_t record_;
+ int64_t global_id;
+ lxu_record_t record;
bool operator<(const EvictItem& item) const {
- return record_ < item.record_;
+ return record < item.record;
}
};
- /**
- * Analysis all ids and returns the num_elems that are most need to evict.
- * @param iterator Returns each global_id to ExtValue pair. Returns nullopt
- * when at ends.
- * @param num_to_evict
- * @return
- */
- template
- static std::vector evict(Iterator iterator, uint64_t num_to_evict) {
+ std::vector evict(iterator_t iterator, uint64_t num_to_evict) {
std::priority_queue items;
while (true) {
auto val = iterator();
@@ -87,8 +95,8 @@ class MixedLFULRUStrategy {
break;
}
EvictItem item{
- .global_id_ = val->global_id_,
- .record_ = reinterpret_cast(&val->lxu_record_)->ToUint32(),
+ .global_id = val->global_id,
+ .record = reinterpret_cast(&val->lxu_record)->ToUint32(),
};
if (items.size() == num_to_evict) {
if (!(item < items.top())) {
@@ -105,7 +113,7 @@ class MixedLFULRUStrategy {
result.reserve(items.size());
while (!items.empty()) {
auto item = items.top();
- result.emplace_back(item.global_id_);
+ result.emplace_back(item.global_id);
items.pop();
}
std::reverse(result.begin(), result.end());
@@ -114,17 +122,16 @@ class MixedLFULRUStrategy {
// Record should only be used in unittest or internally.
struct Record {
- uint32_t time_ : 27;
- uint16_t freq_power_ : 5;
+ uint32_t time : 27;
+ uint16_t freq_power : 5;
[[nodiscard]] uint32_t ToUint32() const {
- return time_ | (freq_power_ << (32 - 5));
+ return time | (freq_power << (32 - 5));
}
};
-
- private:
static_assert(sizeof(Record) == sizeof(lxu_record_t));
+ private:
RandomBitsGenerator generator_;
uint16_t min_lfu_power_;
std::unique_ptr> time_;
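The Record bitfield above packs the 5-bit frequency power into the high bits of a 32-bit word and the 27-bit timestamp into the low bits, so ordering packed values compares frequency first and recency second; eviction then takes the smallest values. A small sketch of that packing and the ordering it induces:

#include <cassert>
#include <cstdint>

constexpr uint32_t pack(uint32_t time27, uint32_t freq_power5) {
  // time in bits [0, 27), frequency power in bits [27, 32)
  return (time27 & ((1u << 27) - 1)) | (freq_power5 << (32 - 5));
}

int main() {
  // Rarely used but recent vs. frequently used but old:
  uint32_t rare_recent = pack(/*time=*/1000, /*freq_power=*/1);
  uint32_t frequent_old = pack(/*time=*/10, /*freq_power=*/7);
  assert(rare_recent < frequent_old);  // the rare id is evicted first

  // Same frequency: the older id has the smaller packed value.
  assert(pack(10, 3) < pack(1000, 3));
}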
diff --git a/torchrec/csrc/dynamic_embedding/details/naive_id_transformer.h b/torchrec/csrc/dynamic_embedding/details/naive_id_transformer.h
index 9cf2e694d..8f7718340 100644
--- a/torchrec/csrc/dynamic_embedding/details/naive_id_transformer.h
+++ b/torchrec/csrc/dynamic_embedding/details/naive_id_transformer.h
@@ -9,33 +9,13 @@
#pragma once
#include
#include
+#include
#include
#include
#include
namespace torchrec {
-namespace transform_default {
-
-template
-inline LXURecord no_update(
- std::optional record,
- int64_t global_id,
- int64_t cache_id) {
- return record.value_or(LXURecord{});
-};
-
-inline void no_fetch(int64_t global_id, int64_t cache_id) {}
-
-} // namespace transform_default
-
-template
-struct TransformerRecord {
- int64_t global_id_;
- int64_t cache_id_;
- LXURecord lxu_record_;
-};
-
/**
* NaiveIDTransformer
*
@@ -43,67 +23,27 @@ struct TransformerRecord {
* @tparam LXURecord The extension type used for eviction strategy.
* @tparam Bitmap The bitmap class to record the free cache ids.
*/
-template >
-class NaiveIDTransformer {
+template >
+class NaiveIDTransformer : public IDTransformer {
public:
- using lxu_record_t = LXURecord;
- using record_t = TransformerRecord;
- static constexpr std::string_view type_ = "naive";
-
explicit NaiveIDTransformer(int64_t num_embedding);
- NaiveIDTransformer(const NaiveIDTransformer&) = delete;
- NaiveIDTransformer(NaiveIDTransformer&&) noexcept =
- default;
+ NaiveIDTransformer(const NaiveIDTransformer&) = delete;
+ NaiveIDTransformer(NaiveIDTransformer&&) noexcept = default;
- /**
- * Transform global ids to cache ids
- *
- * @tparam Update Update the eviction strategy tag type. Update LXU Record
- * @tparam Fetch Fetch the not existing global-id/cache-id pair. It is used
- * by dynamic embedding parameter server.
- *
- * @param global_ids Global ID vector
- * @param cache_ids [out] Cache ID vector
- * @param update update lambda. See `Update` doc.
- * @param fetch fetch lambda. See `Fetch` doc.
- * @return true if all transformed, otherwise need eviction.
- */
- template <
- typename Update = decltype(transform_default::no_update),
- typename Fetch = decltype(transform_default::no_fetch)>
bool transform(
std::span global_ids,
std::span cache_ids,
- Update update = transform_default::no_update,
- Fetch fetch = transform_default::no_fetch);
+ update_t update = transform_default::no_update,
+ fetch_t fetch = transform_default::no_fetch) override;
- /**
- * Evict global ids from the transformer
- *
- * @param global_ids Global IDs to evict.
- */
- void evict(std::span global_ids);
+ void evict(std::span global_ids) override;
- /**
- * Create an iterator of the id transformer, a possible usecase is:
- *
- * auto iterator = transformer.iterator();
- * auto record = iterator();
- * while (record.has_value()) {
- * // do sth with the record
- * // ...
- * // get next record
- * auto record = iterator();
- * }
- *
- * @return the iterator created.
- */
- std::function()> iterator() const;
+ iterator_t iterator() const override;
private:
struct CacheValue {
- int64_t cache_id_;
- LXURecord lxu_record_;
+ int64_t cache_id;
+ lxu_record_t lxu_record;
};
ska::flat_hash_map global_id2cache_value_;
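The transformer keeps a hash map from global id to cache slot plus a bitmap of free slots; transform() allocates slots from the bitmap and evict() returns them. A simplified stand-in for the bitmap interface it relies on (next_free_bit / free_bit / full), trading the bit-packing of the real Bitmap for clarity:

#include <cstdint>
#include <vector>

class ToyBitmap {
 public:
  explicit ToyBitmap(int64_t num_bits)
      : used_(static_cast<size_t>(num_bits), false), free_(num_bits) {}

  bool full() const { return free_ == 0; }

  // Returns the index of a currently-free slot and marks it used.
  int64_t next_free_bit() {
    for (int64_t i = 0; i < static_cast<int64_t>(used_.size()); ++i) {
      if (!used_[i]) {
        used_[i] = true;
        --free_;
        return i;
      }
    }
    return -1;  // unreachable if callers check full() first
  }

  void free_bit(int64_t i) {
    if (used_[i]) {
      used_[i] = false;
      ++free_;
    }
  }

 private:
  std::vector<bool> used_;
  int64_t free_;
};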
diff --git a/torchrec/csrc/dynamic_embedding/details/naive_id_transformer_impl.h b/torchrec/csrc/dynamic_embedding/details/naive_id_transformer_impl.h
index cd40f87d7..e0e38b2f7 100644
--- a/torchrec/csrc/dynamic_embedding/details/naive_id_transformer_impl.h
+++ b/torchrec/csrc/dynamic_embedding/details/naive_id_transformer_impl.h
@@ -6,43 +6,40 @@
* LICENSE file in the root directory of this source tree.
*/
-#pragma once
#include
#include
namespace torchrec {
-template
-inline NaiveIDTransformer::NaiveIDTransformer(
- int64_t num_embedding)
+template
+NaiveIDTransformer::NaiveIDTransformer(int64_t num_embedding)
: bitmap_(num_embedding) {
global_id2cache_value_.reserve(num_embedding);
}
-template
-template
-inline bool NaiveIDTransformer::transform(
+template
+bool NaiveIDTransformer::transform(
std::span global_ids,
std::span cache_ids,
- Update update,
- Fetch fetch) {
+ update_t update,
+ fetch_t fetch) {
for (size_t i = 0; i < global_ids.size(); ++i) {
int64_t global_id = global_ids[i];
auto iter = global_id2cache_value_.find(global_id);
// cache_id is in [0, num_embedding)
int64_t cache_id;
if (iter != global_id2cache_value_.end()) {
- cache_id = iter->second.cache_id_;
- iter->second.lxu_record_ =
- update(iter->second.lxu_record_, global_id, cache_id);
+ cache_id = iter->second.cache_id;
+ iter->second.lxu_record =
+ update(global_id, cache_id, iter->second.lxu_record);
} else {
// The transformer is full.
- if (C10_UNLIKELY(bitmap_.full())) {
+ if (bitmap_.full()) [[unlikely]] {
return false;
}
auto stored_cache_id = bitmap_.next_free_bit();
cache_id = stored_cache_id;
- LXURecord record = update(std::nullopt, global_id, cache_id);
+ lxu_record_t record = update(global_id, cache_id, std::nullopt);
global_id2cache_value_.emplace(
global_id, CacheValue{stored_cache_id, record});
fetch(global_id, cache_id);
@@ -52,30 +49,28 @@ inline bool NaiveIDTransformer::transform(
return true;
}
-template
-inline void NaiveIDTransformer::evict(
- std::span global_ids) {
+template
+void NaiveIDTransformer::evict(std::span global_ids) {
for (const int64_t global_id : global_ids) {
auto iter = global_id2cache_value_.find(global_id);
if (iter == global_id2cache_value_.end()) {
continue;
}
- int64_t cache_id = iter->second.cache_id_;
+ int64_t cache_id = iter->second.cache_id;
global_id2cache_value_.erase(iter);
bitmap_.free_bit(cache_id);
}
}
-template
-inline auto NaiveIDTransformer::iterator() const
- -> std::function()> {
+template
+iterator_t NaiveIDTransformer::iterator() const {
auto iter = global_id2cache_value_.begin();
return [iter, this]() mutable -> std::optional {
if (iter != global_id2cache_value_.end()) {
auto record = record_t{
- .global_id_ = iter->first,
- .cache_id_ = iter->second.cache_id_,
- .lxu_record_ = iter->second.lxu_record_,
+ .global_id = iter->first,
+ .cache_id = iter->second.cache_id,
+ .lxu_record = iter->second.lxu_record,
};
iter++;
return record;
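iterator() above returns a generator-style closure: it captures the map iterator and yields one record per call until it returns nullopt. The same pattern in a self-contained form (Entry and make_iterator are illustrative names):

#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <optional>

struct Entry {
  int64_t global_id;
  int64_t cache_id;
};

std::function<std::optional<Entry>()> make_iterator(
    const std::map<int64_t, int64_t>& m) {
  auto it = m.begin();
  return [it, &m]() mutable -> std::optional<Entry> {
    if (it == m.end()) {
      return std::nullopt;
    }
    Entry e{it->first, it->second};
    ++it;
    return e;
  };
}

int main() {
  std::map<int64_t, int64_t> cache{{10, 0}, {20, 1}, {30, 2}};
  auto next = make_iterator(cache);
  for (auto e = next(); e.has_value(); e = next()) {
    std::cout << e->global_id << " -> " << e->cache_id << '\n';
  }
}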
diff --git a/torchrec/csrc/dynamic_embedding/details/notification.cpp b/torchrec/csrc/dynamic_embedding/details/notification.cpp
new file mode 100644
index 000000000..297b05393
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/details/notification.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "notification.h"
+
+namespace torchrec {
+void Notification::done() {
+ {
+ std::lock_guard guard(mu_);
+ set_ = true;
+ }
+ cv_.notify_all();
+}
+void Notification::wait() {
+ std::unique_lock lock(mu_);
+ cv_.wait(lock, [this] { return set_; });
+}
+
+void Notification::clear() {
+ set_ = false;
+}
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/notification.h b/torchrec/csrc/dynamic_embedding/details/notification.h
new file mode 100644
index 000000000..5d6c04f03
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/details/notification.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include
+#include
+#include
+
+namespace torchrec {
+
+/**
+ * Multi-thread notification
+ */
+class Notification : public torch::CustomClassHolder {
+ public:
+ Notification() = default;
+
+ void done();
+ void wait();
+
+ /**
+ * Clear the set status.
+ *
+ * NOTE: Clear is not thread-safe.
+ */
+ void clear();
+
+ private:
+ bool set_{false};
+ std::mutex mu_;
+ std::condition_variable cv_;
+};
+
+} // namespace torchrec
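A typical use of Notification is to block the calling thread until an asynchronous fetch signals completion. The sketch below uses a standalone copy of the class (without the torch::CustomClassHolder base) so it compiles without libtorch:

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

class Notification {
 public:
  void done() {
    {
      std::lock_guard<std::mutex> guard(mu_);
      set_ = true;
    }
    cv_.notify_all();
  }
  void wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return set_; });
  }

 private:
  bool set_{false};
  std::mutex mu_;
  std::condition_variable cv_;
};

int main() {
  Notification fetched;
  std::thread worker([&] {
    // ... fetch parameters from the parameter server ...
    fetched.done();
  });
  fetched.wait();  // returns once the worker has signaled
  worker.join();
  std::cout << "fetch complete\n";
}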
diff --git a/torchrec/csrc/dynamic_embedding/details/random_bits_generator.cpp b/torchrec/csrc/dynamic_embedding/details/random_bits_generator.cpp
index 4b307842d..0fada452d 100644
--- a/torchrec/csrc/dynamic_embedding/details/random_bits_generator.cpp
+++ b/torchrec/csrc/dynamic_embedding/details/random_bits_generator.cpp
@@ -6,17 +6,16 @@
* LICENSE file in the root directory of this source tree.
*/
-#include
#include
#include
namespace torchrec {
bool BitScanner::is_next_n_bits_all_zero(uint16_t& n_bits) {
- if (C10_UNLIKELY((n_bits == 0))) {
+ if ((n_bits == 0)) [[unlikely]] {
return true;
}
- if (C10_UNLIKELY(array_idx_ == size_)) {
+ if (array_idx_ == size_) [[unlikely]] {
return true;
}
@@ -88,7 +87,7 @@ bool RandomBitsGenerator::is_next_n_bits_all_zero(uint16_t n_bits) {
return false;
}
- if (C10_UNLIKELY(n_bits != 0)) {
+ if (n_bits != 0) [[unlikely]] {
return is_next_n_bits_all_zero(n_bits);
} else {
return true;
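is_next_n_bits_all_zero() is what lets MixedLFULRUStrategy increment freq_power only with probability 2^-freq_power, i.e. a Morris-style approximate counter whose value tracks roughly log2 of the access count. A self-contained sketch of that idea, using std::mt19937_64 in place of the BitScanner:

#include <cstdint>
#include <iostream>
#include <random>

bool next_n_bits_all_zero(std::mt19937_64& rng, uint16_t n_bits) {
  if (n_bits == 0) {
    return true;
  }
  if (n_bits >= 64) {
    return false;  // simplification; the real scanner chains words
  }
  uint64_t bits = rng() & ((uint64_t{1} << n_bits) - 1);
  return bits == 0;
}

int main() {
  std::mt19937_64 rng(42);
  uint16_t freq_power = 1;
  const int accesses = 100000;
  for (int i = 0; i < accesses; ++i) {
    if (next_n_bits_all_zero(rng, freq_power)) {
      ++freq_power;  // carry: roughly doubles the represented count
    }
  }
  // freq_power ends up near log2(accesses), i.e. around 16-17 here.
  std::cout << "freq_power after " << accesses << " accesses: " << freq_power
            << '\n';
}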
diff --git a/torchrec/csrc/dynamic_embedding/details/redis/CMakeLists.txt b/torchrec/csrc/dynamic_embedding/details/redis/CMakeLists.txt
new file mode 100644
index 000000000..a9771391c
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/details/redis/CMakeLists.txt
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+FetchContent_Declare(
+ hiredis
+ GIT_REPOSITORY https://github.com/redis/hiredis.git
+ GIT_TAG 06be7ff312a78f69237e5963cc7d24bc84104d3b
+)
+
+FetchContent_GetProperties(hiredis)
+if(NOT hiredis_POPULATED)
+ # Do not include hiredis in install targets
+ FetchContent_Populate(hiredis)
+ set(DISABLE_TESTS ON CACHE BOOL "Disable tests for hiredis")
+ add_subdirectory(
+ ${hiredis_SOURCE_DIR} ${hiredis_BINARY_DIR} EXCLUDE_FROM_ALL)
+endif()
+
+add_library(redis_io SHARED redis_io.cpp)
+target_include_directories(
+ redis_io PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../)
+target_include_directories(redis_io PUBLIC ${TORCH_INCLUDE_DIRS})
+target_compile_options(redis_io PUBLIC -fPIC)
+target_link_libraries(redis_io PUBLIC hiredis::hiredis_static)
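The redis_io.cpp file added below configures itself from a URL-style string; the option names come from its parse_option() function, while the exact scheme prefix and the placement of the '?' are assumptions here, since the url_parser helper is not part of this hunk. A hypothetical configuration string and the fields it might map to:

#include <string>

int main() {
  // username:password@host:port plus '&&'-separated options (hypothetical shape)
  std::string cfg =
      "redis://user:secret@127.0.0.1:6379/"
      "?num_threads=4&&db=1&&prefix=tde&&timeout=200ms"
      "&&heartbeat=10s&&retry_limit=3&&chunk_size=8192";
  // parse_option(cfg) might then yield: host=127.0.0.1, port=6379, db=1,
  // num_io_threads=4, prefix="tde", timeout_ms=200,
  // heart_beat_interval_ms=10000, retry_limit=3, chunk_size=8192.
  return 0;
}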
diff --git a/torchrec/csrc/dynamic_embedding/details/redis/redis_io.cpp b/torchrec/csrc/dynamic_embedding/details/redis/redis_io.cpp
new file mode 100644
index 000000000..71f04fd08
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/details/redis/redis_io.cpp
@@ -0,0 +1,505 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+#include
+
+namespace torchrec::redis {
+
+int parse_integer(std::string_view param_str, std::string_view param_key) {
+ return std::stoi(std::string(param_str.substr(param_key.size())));
+}
+
+uint32_t parse_duration(
+ std::string_view param_str,
+ std::string_view param_key) {
+ auto param_value = param_str.substr(param_key.size());
+ if (param_value.empty()) {
+ throw std::invalid_argument("no value for " + std::string(param_str));
+ }
+ double duration;
+ if (param_value.ends_with("ms")) {
+ duration =
+ std::stod(std::string(param_value.substr(0, param_value.size() - 2)));
+ } else if (param_value.ends_with("s")) {
+ duration =
+ std::stod(std::string(param_value.substr(0, param_value.size() - 1))) *
+ 1000;
+ } else if (param_value.ends_with("m")) {
+ duration =
+ std::stod(std::string(param_value.substr(0, param_value.size() - 1))) *
+ 1000 * 60;
+ } else {
+ throw std::invalid_argument(
+ "no supported time unit (ms, s, m) in " + std::string(param_str));
+ }
+ return static_cast(duration);
+}
+
+Option parse_option(std::string_view config_str) {
+ Option option;
+ url_parser::Url url = url_parser::parse_url(http://github.com/config_str);
+
+ if (url.authority.has_value()) {
+ option.username = std::move(url.authority->username);
+ option.password = std::move(url.authority->password);
+ }
+
+ option.host = std::move(url.host);
+
+ if (url.port.has_value()) {
+ option.port = url.port.value();
+ }
+
+ if (url.param.has_value()) {
+ std::string_view param_str = url.param.value();
+ while (!param_str.empty()) {
+ auto and_pos = param_str.find("&&");
+ std::string_view single_param_str;
+ if (and_pos != std::string_view::npos) {
+ single_param_str = param_str.substr(0, and_pos);
+ param_str = param_str.substr(and_pos + 2);
+ } else {
+ single_param_str = param_str;
+ param_str = "";
+ }
+
+ if (single_param_str.starts_with("num_threads=")) {
+ option.num_io_threads = parse_integer(single_param_str, "num_threads=");
+ } else if (single_param_str.starts_with("db=")) {
+ option.db = parse_integer(single_param_str, "db=");
+ } else if (single_param_str.starts_with("prefix=")) {
+ option.prefix = single_param_str.substr(std::string("prefix=").size());
+ } else if (single_param_str.starts_with("timeout=")) {
+ option.timeout_ms = parse_duration(single_param_str, "timeout=");
+ } else if (single_param_str.starts_with("heartbeat=")) {
+ option.heart_beat_interval_ms =
+ parse_duration(single_param_str, "heartbeat=");
+ } else if (single_param_str.starts_with("retry_limit=")) {
+ option.retry_limit = parse_integer(single_param_str, "retry_limit=");
+ } else if (single_param_str.starts_with("chunk_size=")) {
+ option.chunk_size = parse_integer(single_param_str, "chunk_size=");
+ } else {
+ throw std::invalid_argument(
+ "unknown parameter: " + std::string(single_param_str));
+ }
+ }
+ }
+
+ return option;
+}
+
+Redis::Redis(Option opt) : opt_(std::move(opt)) {
+ TORCH_CHECK(opt_.num_io_threads != 0, "num_io_threads must not be zero");
+ TORCH_CHECK(
+ opt_.heart_beat_interval_ms != 0,
+ "heart beat interval must not be zero.");
+ for (size_t i = 0; i < opt_.num_io_threads; ++i) {
+ start_thread();
+ }
+}
+
+void Redis::start_thread() {
+ auto connection = connect();
+ heartbeat(connection);
+
+ io_threads_.emplace_back(
+ [connection = std::move(connection), this]() mutable {
+ std::chrono::milliseconds heart_beat(opt_.heart_beat_interval_ms);
+ while (true) {
+ std::function todo;
+ bool heartbeat_timeout;
+ {
+ std::unique_lock lock(this->jobs_mutex_);
+ heartbeat_timeout = !jobs_not_empty_.wait_for(
+ lock, heart_beat, [this] { return !jobs_.empty(); });
+ if (!heartbeat_timeout) {
+ todo = std::move(jobs_.front());
+ jobs_.pop_front();
+ }
+ }
+
+ if (heartbeat_timeout) {
+ heartbeat(connection);
+ continue;
+ }
+
+ if (!todo) {
+ break;
+ }
+ todo(connection);
+ }
+ });
+}
+
+void Redis::heartbeat(helper::ContextPtr& connection) {
+ for (uint32_t retry = 0; retry < opt_.retry_limit; ++retry) {
+ try {
+ auto reply = helper::ReplyPtr(reinterpret_cast(
+ redisCommand(connection.get(), "PING")));
+ TORCH_CHECK(
+ reply && reply->type == REDIS_REPLY_STRING,
+ "Ping should return string");
+ auto rsp = std::string_view(reply->str, reply->len);
+ TORCH_CHECK(rsp == "PONG", "ping/pong error");
+ } catch (...) {
+ // reconnect if heart beat error
+ connection = connect();
+ }
+ }
+}
+
+helper::ContextPtr Redis::connect() const {
+ helper::ContextPtr connection;
+ if (opt_.timeout_ms == 0) {
+ connection = helper::ContextPtr(redisConnect(opt_.host.c_str(), opt_.port));
+ } else {
+ struct timeval interval {};
+ interval.tv_sec = opt_.timeout_ms / 1000;
+ interval.tv_usec = opt_.timeout_ms % 1000 * 1000;
+ connection = helper::ContextPtr(
+ redisConnectWithTimeout(opt_.host.c_str(), opt_.port, interval));
+ }
+ TORCH_CHECK(
+ !connection->err,
+ "connect to %s:%d error occurred %s",
+ opt_.host,
+ opt_.port,
+ connection->errstr);
+
+ if (!opt_.password.empty()) {
+ helper::ReplyPtr reply;
+ if (opt_.username.empty()) {
+ reply = helper::ReplyPtr(reinterpret_cast(
+ redisCommand(connection.get(), "AUTH %s", opt_.password.c_str())));
+ } else {
+ reply = helper::ReplyPtr(reinterpret_cast(redisCommand(
+ connection.get(),
+ "AUTH %s %s",
+ opt_.username.c_str(),
+ opt_.password.c_str())));
+ }
+ check_status("auth error", connection, reply);
+ }
+
+ if (opt_.db != 0) {
+ auto reply = helper::ReplyPtr(reinterpret_cast(
+ redisCommand(connection.get(), "SELECT %d", opt_.db)));
+ check_status("select db error", connection, reply);
+ }
+
+ return connection;
+}
+
+Redis::~Redis() {
+ for (uint32_t i = 0; i < opt_.num_io_threads; ++i) {
+ jobs_.emplace_back();
+ }
+ jobs_not_empty_.notify_all();
+ for (auto& th : io_threads_) {
+ th.join();
+ }
+}
+
+static uint32_t CalculateChunkSizeByGlobalIDs(
+ uint32_t chunk_size,
+ uint32_t num_cols,
+ uint32_t num_os) {
+ static constexpr uint32_t low = 1;
+ return std::max(
+ chunk_size / std::max(num_cols, low) / std::max(num_os, low), low);
+}
+
+struct RedisFetchContext {
+ std::atomic num_complete_ids{0};
+ uint32_t chunk_size;
+ std::string table_name;
+ std::vector global_ids;
+ std::vector col_ids;
+ uint32_t num_optimizer_states;
+ void* on_complete_context;
+ void (*on_global_id_fetched)(
+ void* ctx,
+ uint32_t gid_offset,
+ uint32_t optimizer_state,
+ void* data,
+ uint32_t data_len);
+ void (*on_all_fetched)(void* ctx);
+
+ explicit RedisFetchContext(uint32_t chunk_size, IOFetchParameter param)
+ : chunk_size(CalculateChunkSizeByGlobalIDs(
+ chunk_size,
+ param.num_cols,
+ param.num_optimizer_states)),
+ table_name(param.table_name),
+ global_ids(param.global_ids, param.global_ids + param.num_global_ids),
+ num_optimizer_states(param.num_optimizer_states),
+ on_complete_context(param.on_complete_context),
+ on_global_id_fetched(param.on_global_id_fetched),
+ on_all_fetched(param.on_all_fetched) {
+ if (param.num_cols == 0) {
+ col_ids.emplace_back(-1);
+ } else {
+ col_ids =
+ std::vector(param.col_ids, param.col_ids + param.num_cols);
+ }
+ }
+};
+
+void Redis::fetch(IOFetchParameter param) {
+ auto* fetch_param = new RedisFetchContext(opt_.chunk_size, param);
+ {
+ std::lock_guard guard(this->jobs_mutex_);
+ for (uint32_t i = 0; i < param.num_global_ids;
+ i += fetch_param->chunk_size) {
+ jobs_.emplace_back(
+ [i, fetch_param, this](helper::ContextPtr& connection) {
+ do_fetch(i, fetch_param, connection);
+ });
+ }
+ }
+ jobs_not_empty_.notify_all();
+}
+
+void Redis::do_fetch(
+ uint32_t gid_offset,
+ void* fetch_param_void,
+ helper::ContextPtr& connection) const {
+ auto& fetch_param = *reinterpret_cast(fetch_param_void);
+
+ uint32_t end = std::min(
+ gid_offset + fetch_param.chunk_size,
+ static_cast(fetch_param.global_ids.size()));
+
+ auto loop = [&](auto&& callback) {
+ for (uint32_t i = gid_offset; i < end; ++i) {
+ int64_t gid = fetch_param.global_ids[i];
+ for (uint32_t j = 0; j < fetch_param.col_ids.size(); ++j) {
+ auto& col_id = fetch_param.col_ids[j];
+ for (uint32_t os_id = 0; os_id < fetch_param.num_optimizer_states;
+ ++os_id) {
+ callback(i * fetch_param.col_ids.size() + j, gid, col_id, os_id);
+ }
+ }
+ }
+ };
+
+ loop([&](uint32_t offset, int64_t gid, uint32_t col_id, uint32_t os_id) {
+ redisAppendCommand(
+ connection.get(),
+ "GET %s_table_%s_gid_%d_cid_%d_osid_%d",
+ opt_.prefix.c_str(),
+ fetch_param.table_name.c_str(),
+ gid,
+ col_id,
+ os_id);
+ });
+
+ void* reply;
+ loop([&](uint32_t offset, int64_t gid, uint32_t col_id, uint32_t os_id) {
+ int status = redisGetReply(connection.get(), &reply);
+ TORCH_CHECK(
+ status != REDIS_ERR,
+ "get reply error: %s, from redis %s, %d",
+ connection->errstr,
+ opt_.host,
+ opt_.port);
+ auto reply_ptr = helper::ReplyPtr(reinterpret_cast(reply));
+
+ if (reply_ptr->type == REDIS_REPLY_NIL) {
+ fetch_param.on_global_id_fetched(
+ fetch_param.on_complete_context, offset, os_id, nullptr, 0);
+ } else {
+ fetch_param.on_global_id_fetched(
+ fetch_param.on_complete_context,
+ offset,
+ os_id,
+ reply_ptr->str,
+ reply_ptr->len);
+ }
+ });
+
+ uint32_t n = end - gid_offset;
+ uint32_t target = fetch_param.global_ids.size();
+
+ if (fetch_param.num_complete_ids.fetch_add(n) + n ==
+ target) { // last fetch complete
+ fetch_param.on_all_fetched(fetch_param.on_complete_context);
+ delete &fetch_param;
+ }
+}
+
+struct RedisPushContext {
+ std::atomic num_complete_ids{0};
+ uint32_t chunk_size;
+ std::string table_name;
+ std::span global_ids;
+ std::vector col_ids;
+ std::span os_ids;
+ std::span