diff --git a/.github/scripts/filter.py b/.github/scripts/filter.py
new file mode 100644
index 000000000..6285391c1
--- /dev/null
+++ b/.github/scripts/filter.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+
+
+def main():
+ """
+    Filters out certain CUDA versions from the build matrix that
+    determines which nightly builds are run. This ensures TorchRec is
+    always consistent in version compatibility with FBGEMM.
+ """
+
+ full_matrix_string = os.environ["MAT"]
+ full_matrix = json.loads(full_matrix_string)
+ """
+    Matrix contents can be found in the log output of a GitHub
+    "Build Linux Wheels" workflow run. A typical example is:
+ {
+ "include": [
+ {
+ "python_version": "3.9",
+ "gpu_arch_type": "cpu",
+ "gpu_arch_version": "",
+ "desired_cuda": "cpu",
+ "container_image": "pytorch/manylinux2_28-builder:cpu",
+ "package_type": "manywheel",
+ "build_name": "manywheel-py3_9-cpu",
+ "validation_runner": "linux.2xlarge",
+ "installation": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu",
+ "channel": "nightly",
+ "upload_to_base_bucket": "no",
+ "stable_version": "2.7.1",
+ "use_split_build": false
+ },
+ {
+ "python_version": "3.9",
+ "gpu_arch_type": "cuda",
+ "gpu_arch_version": "12.6",
+ "desired_cuda": "cu126",
+ "container_image": "pytorch/manylinux2_28-builder:cuda12.6",
+ "package_type": "manywheel",
+ "build_name": "manywheel-py3_9-cuda12_6",
+ "validation_runner": "linux.g5.4xlarge.nvidia.gpu",
+ "installation": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126",
+ "channel": "nightly",
+ "upload_to_base_bucket": "no",
+ "stable_version": "2.7.1",
+ "use_split_build": false
+ }
+ ]
+ }
+ """
+
+ new_matrix_entries = []
+
+ for entry in full_matrix["include"]:
+ if entry["desired_cuda"] == "cu129":
+ continue
+ new_matrix_entries.append(entry)
+
+ new_matrix = {"include": new_matrix_entries}
+ print(json.dumps(new_matrix))
+
+
+if __name__ == "__main__":
+ main()
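
For reference, a minimal sketch of how this filter behaves when invoked the way the `filter-matrix` job below invokes it. The sample matrix here is illustrative only, not a real Nova-generated matrix:

```bash
# Hypothetical local run; MAT normally comes from the generate-matrix job output.
export MAT='{"include": [{"desired_cuda": "cu126"}, {"desired_cuda": "cu129"}]}'
python .github/scripts/filter.py
# prints: {"include": [{"desired_cuda": "cu126"}]}
```
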
diff --git a/.github/scripts/install_libs.sh b/.github/scripts/install_libs.sh
new file mode 100644
index 000000000..193aaf46d
--- /dev/null
+++ b/.github/scripts/install_libs.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+echo "CU_VERSION: ${CU_VERSION}"
+echo "CHANNEL: ${CHANNEL}"
+echo "CONDA_ENV: ${CONDA_ENV}"
+
+if [[ $CU_VERSION = cu* ]]; then
+ # Setting LD_LIBRARY_PATH fixes the runtime error with fbgemm_gpu not
+ # being able to locate libnvrtc.so
+ echo "[NOVA] Setting LD_LIBRARY_PATH ..."
+ conda env config vars set -p ${CONDA_ENV} \
+ LD_LIBRARY_PATH="/usr/local/lib:${CUDA_HOME}/lib64:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}"
+else
+ echo "[NOVA] Setting LD_LIBRARY_PATH ..."
+ conda env config vars set -p ${CONDA_ENV} \
+ LD_LIBRARY_PATH="/usr/local/lib:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}"
+fi
+
+if [ "$CHANNEL" = "nightly" ]; then
+ ${CONDA_RUN} pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/"$CU_VERSION"
+elif [ "$CHANNEL" = "test" ]; then
+ ${CONDA_RUN} pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/test/"$CU_VERSION"
+fi
+
+
+${CONDA_RUN} pip install importlib-metadata click PyYAML
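
A rough sketch of invoking this post-script by hand, assuming a conda environment already exists at `./env`; in the Nova build workflow, `CU_VERSION`, `CHANNEL`, `CONDA_ENV`, `CONDA_RUN`, and `CUDA_HOME` are all provided by the runner environment:

```bash
# Illustrative manual run (values are assumptions, not what Nova sets).
CU_VERSION=cu126 CHANNEL=nightly CONDA_ENV=./env CONDA_RUN="conda run -p ./env" \
  bash .github/scripts/install_libs.sh
# Confirm the env-scoped LD_LIBRARY_PATH was recorded (it is applied on activation):
conda run -p ./env python -c "import os; print(os.environ.get('LD_LIBRARY_PATH'))"
```
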
diff --git a/.github/scripts/tests_to_skip.txt b/.github/scripts/tests_to_skip.txt
new file mode 100644
index 000000000..7c0e95f00
--- /dev/null
+++ b/.github/scripts/tests_to_skip.txt
@@ -0,0 +1 @@
+_disabled_in_oss_compatibility
diff --git a/.github/scripts/validate_binaries.sh b/.github/scripts/validate_binaries.sh
new file mode 100755
index 000000000..c36ff6f6b
--- /dev/null
+++ b/.github/scripts/validate_binaries.sh
@@ -0,0 +1,171 @@
+#!/usr/bin/env bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+export PYTORCH_CUDA_PKG=""
+export CONDA_ENV="build_binary"
+
+if [[ ${MATRIX_PYTHON_VERSION} = '3.13t' ]]; then
+ # use conda-forge to install python3.13t
+ conda create -y -n "${CONDA_ENV}" python="3.13" python-freethreading -c conda-forge
+ conda run -n "${CONDA_ENV}" python -c "import sys; print(f'python GIL enabled: {sys._is_gil_enabled()}')"
+else
+ conda create -y -n "${CONDA_ENV}" python="${MATRIX_PYTHON_VERSION}"
+fi
+
+conda run -n "${CONDA_ENV}" python --version
+
+# Install pytorch, torchrec and fbgemm as per
+# installation instructions on following page
+# https://github.com/pytorch/torchrec#installations
+
+
+# figure out CUDA VERSION
+if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then
+ if [[ ${MATRIX_GPU_ARCH_VERSION} = '11.8' ]]; then
+ export CUDA_VERSION="cu118"
+ elif [[ ${MATRIX_GPU_ARCH_VERSION} = '12.1' ]]; then
+ export CUDA_VERSION="cu121"
+ elif [[ ${MATRIX_GPU_ARCH_VERSION} = '12.6' ]]; then
+ export CUDA_VERSION="cu126"
+ elif [[ ${MATRIX_GPU_ARCH_VERSION} = '12.8' ]]; then
+ export CUDA_VERSION="cu128"
+ else
+ export CUDA_VERSION="cu126"
+ fi
+else
+ export CUDA_VERSION="cpu"
+fi
+
+# figure out URL
+if [[ ${MATRIX_CHANNEL} = 'nightly' ]]; then
+ export PYTORCH_URL="https://download.pytorch.org/whl/nightly/$%7BCUDA_VERSION%7D"
+elif [[ ${MATRIX_CHANNEL} = 'test' ]]; then
+ export PYTORCH_URL="https://download.pytorch.org/whl/test/$%7BCUDA_VERSION%7D"
+elif [[ ${MATRIX_CHANNEL} = 'release' ]]; then
+ export PYTORCH_URL="https://download.pytorch.org/whl/$%7BCUDA_VERSION%7D"
+fi
+
+
+echo "CU_VERSION: ${CUDA_VERSION}"
+echo "MATRIX_CHANNEL: ${MATRIX_CHANNEL}"
+echo "CONDA_ENV: ${CONDA_ENV}"
+
+# shellcheck disable=SC2155
+export CONDA_PREFIX=$(conda run -n "${CONDA_ENV}" printenv CONDA_PREFIX)
+
+
+# Set LD_LIBRARY_PATH to fix the runtime error with fbgemm_gpu not
+# being able to locate libnvrtc.so
+# NOTE: The order of the entries in LD_LIBRARY_PATH matters
+echo "[NOVA] Setting LD_LIBRARY_PATH ..."
+conda env config vars set -n ${CONDA_ENV} \
+ LD_LIBRARY_PATH="${CONDA_PREFIX}/lib:/usr/local/lib:/usr/lib64:${LD_LIBRARY_PATH}"
+
+
+# install pytorch
+# switch back to conda once torch nightly is fixed
+# if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then
+# export PYTORCH_CUDA_PKG="pytorch-cuda=${MATRIX_GPU_ARCH_VERSION}"
+# fi
+
+conda run -n "${CONDA_ENV}" pip install torch --index-url "$PYTORCH_URL"
+
+# install fbgemm
+conda run -n "${CONDA_ENV}" pip install fbgemm-gpu --index-url "$PYTORCH_URL"
+
+# install tensordict from pypi
+conda run -n "${CONDA_ENV}" pip install tensordict==0.8.1
+
+# install torchrec
+conda run -n "${CONDA_ENV}" pip install torchrec --index-url "$PYTORCH_URL"
+
+# install other requirements
+conda run -n "${CONDA_ENV}" pip install -r requirements.txt
+
+# Run small import test
+conda run -n "${CONDA_ENV}" python -c "import torch; import fbgemm_gpu; import torchrec"
+
+# check directory
+ls -R
+
+# check if cuda available
+conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.cuda.is_available())"
+
+# check cuda version
+conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.version.cuda)"
+
+# Finally run smoke test
+# python 3.11 needs torchx-nightly
+conda run -n "${CONDA_ENV}" pip install torchx-nightly iopath
+if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then
+ conda run -n "${CONDA_ENV}" torchx run -s local_cwd dist.ddp -j 1 --gpu 2 --script test_installation.py
+else
+ conda run -n "${CONDA_ENV}" torchx run -s local_cwd dist.ddp -j 1 --script test_installation.py -- --cpu_only
+fi
+
+
+# redo for pypi release
+
+if [[ ${MATRIX_CHANNEL} != 'release' ]]; then
+ exit 0
+else
+ # Check version matches only for release binaries
+ torchrec_version=$(conda run -n "${CONDA_ENV}" pip show torchrec | grep Version | cut -d' ' -f2)
+ fbgemm_version=$(conda run -n "${CONDA_ENV}" pip show fbgemm_gpu | grep Version | cut -d' ' -f2)
+
+ if [ "$torchrec_version" != "$fbgemm_version" ]; then
+ echo "Error: TorchRec package version does not match FBGEMM package version"
+ exit 1
+ fi
+fi
+
+if [[ ${MATRIX_PYTHON_VERSION} = '3.13t' ]]; then
+ exit 0 # fbgemm-gpu can't support python=3.13t in PYPI
+ # use conda-forge to install python3.13t
+ conda create -y -n "${CONDA_ENV}" python="3.13" python-freethreading -c conda-forge
+ conda run -n "${CONDA_ENV}" python -c "import sys; print(f'python GIL enabled: {sys._is_gil_enabled()}')"
+else
+ conda create -y -n "${CONDA_ENV}" python="${MATRIX_PYTHON_VERSION}"
+fi
+
+
+conda run -n "${CONDA_ENV}" python --version
+
+# we only have one cuda version for pypi build
+if [[ ${MATRIX_GPU_ARCH_VERSION} != '12.6' ]]; then
+ exit 0
+fi
+
+echo "checking pypi release"
+conda run -n "${CONDA_ENV}" pip install torch
+conda run -n "${CONDA_ENV}" pip install fbgemm-gpu
+conda run -n "${CONDA_ENV}" pip install torchrec
+
+# Check version matching again for PyPI
+torchrec_version=$(conda run -n "${CONDA_ENV}" pip show torchrec | grep Version | cut -d' ' -f2)
+fbgemm_version=$(conda run -n "${CONDA_ENV}" pip show fbgemm_gpu | grep Version | cut -d' ' -f2)
+
+if [ "$torchrec_version" != "$fbgemm_version" ]; then
+ echo "Error: TorchRec package version does not match FBGEMM package version"
+ exit 1
+fi
+
+# check directory
+ls -R
+
+# check if cuda available
+conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.cuda.is_available())"
+
+# check cuda version
+conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.version.cuda)"
+
+# python 3.11 needs torchx-nightly
+conda run -n "${CONDA_ENV}" pip install torchx-nightly iopath
+
+# Finally run smoke test
+conda run -n "${CONDA_ENV}" torchx run -s local_cwd dist.ddp -j 1 --gpu 2 --script test_installation.py
diff --git a/.github/workflows/build-wheels-linux.yml b/.github/workflows/build-wheels-linux.yml
new file mode 100644
index 000000000..a258658b3
--- /dev/null
+++ b/.github/workflows/build-wheels-linux.yml
@@ -0,0 +1,64 @@
+name: Build Linux Wheels
+
+on:
+ pull_request:
+ push:
+ branches:
+ - nightly
+ - main
+ - release/*
+ tags:
+ # Release candidate tags look like: v1.11.0-rc1
+ - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
+ workflow_dispatch:
+
+permissions:
+ id-token: write
+ contents: read
+
+jobs:
+ generate-matrix:
+ uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
+ with:
+ package-type: wheel
+ os: linux
+ test-infra-repository: pytorch/test-infra
+ test-infra-ref: main
+ with-rocm: false
+ filter-matrix:
+ needs: generate-matrix
+ runs-on: linux.24_04.4x
+ outputs:
+ matrix: ${{ steps.filter.outputs.matrix }}
+ steps:
+ - uses: actions/setup-python@v4
+ - name: Checkout torchrec repository
+ uses: actions/checkout@v4
+ with:
+ repository: pytorch/torchrec
+ - name: Filter Generated Built Matrix
+ id: filter
+ env:
+ MAT: ${{ needs.generate-matrix.outputs.matrix }}
+ run: |
+ set -ex
+ pwd
+ ls
+ MATRIX_BLOB="$(python .github/scripts/filter.py)"
+ echo "${MATRIX_BLOB}"
+ echo "matrix=${MATRIX_BLOB}" >> "${GITHUB_OUTPUT}"
+ build:
+ needs: filter-matrix
+ name: pytorch/torchrec
+ uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
+ with:
+ repository: pytorch/torchrec
+ ref: ""
+ test-infra-repository: pytorch/test-infra
+ test-infra-ref: main
+ build-matrix: ${{ needs.filter-matrix.outputs.matrix }}
+ pre-script: ""
+ post-script: .github/scripts/install_libs.sh
+ package-name: torchrec
+ smoke-test-script: ""
+ trigger-event: ${{ github.event_name }}
diff --git a/.github/workflows/build_dynamic_embedding_wheels.yml b/.github/workflows/build_dynamic_embedding_wheels.yml
index cdf4d5f76..837d7ec52 100644
--- a/.github/workflows/build_dynamic_embedding_wheels.yml
+++ b/.github/workflows/build_dynamic_embedding_wheels.yml
@@ -3,6 +3,8 @@ name: Build Dynamic Embedding Wheels
on:
workflow_dispatch:
pull_request:
+ paths:
+ - "contrib/*"
branches:
- main
push:
@@ -20,27 +22,52 @@ jobs:
fail-fast: false
matrix:
os: [ ubuntu-latest ]
- pyver: [ cp37, cp38, cp39, cp310 ]
- cuver: [ "11.6", "11.3"]
+ pyver: [ cp39, cp310, cp311, cp312 ]
+ cuver: [ "12.1", "12.4"]
steps:
- - uses: actions/checkout@v3
+ -
+ name: Check disk space
+ run: df . -h
+
+ - name: Remove unnecessary files
+ run: |
+ sudo rm -rf /usr/share/dotnet
+ sudo rm -rf "$AGENT_TOOLSDIRECTORY"
+
+ -
+ name: Check disk space
+ run: df . -h
+
+ - uses: actions/checkout@v4
with:
submodules: recursive
- - uses: pypa/cibuildwheel@v2.8.0
- with:
+ - uses: pypa/cibuildwheel@v2.20.0
+ with:
package-dir: contrib/dynamic_embedding
env:
CIBW_BEFORE_BUILD: "env CUDA_VERSION=${{ matrix.cuver }} contrib/dynamic_embedding/tools/before_linux_build.sh"
CIBW_BUILD: "${{ matrix.pyver }}-manylinux_x86_64"
CIBW_REPAIR_WHEEL_COMMAND: "env CUDA_VERSION=${{ matrix.cuver }} contrib/dynamic_embedding/tools/repair_wheel.sh {wheel} {dest_dir}"
+ CIBW_MANYLINUX_X86_64_IMAGE: "manylinux_2_28"
- name: Verify clean directory
run: git diff --exit-code
shell: bash
- name: Upload wheels
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
+ name: artifact-${{ matrix.os }}-${{ matrix.pyver }}-cu${{ matrix.cuver }}
path: wheelhouse/*.whl
+
+ merge:
+ runs-on: ubuntu-latest
+ needs: build_wheels
+ steps:
+ - name: Merge Artifacts
+ uses: actions/upload-artifact/merge@v4
+ with:
+ name: artifact
+ pattern: artifact-*
diff --git a/.github/workflows/cpp_unittest_ci_cpu.yml b/.github/workflows/cpp_unittest_ci_cpu.yml
new file mode 100644
index 000000000..ad9e6e13f
--- /dev/null
+++ b/.github/workflows/cpp_unittest_ci_cpu.yml
@@ -0,0 +1,63 @@
+# This workflow builds TorchRec's C++ components and runs their unit tests on CPU with a variety of Python versions
+# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: CPU Unit Test C++ CI
+
+on:
+ push:
+ paths-ignore:
+ - "docs/*"
+ - "third_party/*"
+ - .gitignore
+ - "*.md"
+ pull_request:
+ paths-ignore:
+ - "docs/*"
+ - "third_party/*"
+ - .gitignore
+ - "*.md"
+
+jobs:
+ build_test:
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - os: linux.2xlarge
+ python-version: 3.9
+ python-tag: "py39"
+ - os: linux.2xlarge
+ python-version: '3.10'
+ python-tag: "py310"
+ - os: linux.2xlarge
+ python-version: '3.11'
+ python-tag: "py311"
+ - os: linux.2xlarge
+ python-version: '3.12'
+ python-tag: "py312"
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
+ with:
+ runner: ${{ matrix.os }}
+ timeout: 15
+ script: |
+ ldd --version
+ conda create -y --name build_binary python=${{ matrix.python-version }}
+ conda info
+ python --version
+ echo "Starting C++ Tests"
+ conda install -n build_binary -y gxx_linux-64
+ conda run -n build_binary \
+ x86_64-conda-linux-gnu-g++ --version
+ conda install -n build_binary -c anaconda redis -y
+ conda run -n build_binary redis-server --daemonize yes
+ mkdir cpp-build
+ cd cpp-build
+ conda run -n build_binary cmake \
+ -DBUILD_TEST=ON \
+ -DBUILD_REDIS_IO=ON \
+ -DCMAKE_PREFIX_PATH=/opt/conda/envs/build_binary/lib/python${{ matrix.python-version }}/site-packages/torch/share/cmake ..
+ conda run -n build_binary make -j
+ conda run -n build_binary ctest -V .
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index c69be055e..547fd0d77 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -5,25 +5,32 @@ on:
branches:
- main
workflow_dispatch:
+ pull_request:
+
jobs:
build_docs_job:
runs-on: ${{ matrix.os }}
+ permissions:
+ # Grant write permission here so that the doc can be pushed to gh-pages branch
+ contents: write
strategy:
matrix:
include:
- - os: linux.2xlarge
- python-version: 3.7
+ - os: linux.24_04.4x
+ python-version: 3.9
+ python-tag: "py39"
steps:
- name: Check ldd --version
run: ldd --version
- name: Checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@v4
# Update references
- name: Update pip
run: |
- sudo yum update -y
- sudo yum -y install git python3-pip
- sudo pip3 install --upgrade pip
+ sudo apt-get update
+ sudo apt-get -y install python3-pip
+ sudo apt upgrade python3-pip
+ pip --version
- name: Setup conda
run: |
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
@@ -45,17 +52,24 @@ jobs:
- name: Install gcc
shell: bash
run: |
- sudo yum group install -y "Development Tools"
+ sudo apt-get install build-essential
- name: setup Path
run: |
echo /usr/local/bin >> $GITHUB_PATH
- name: Install PyTorch
shell: bash
run: |
- conda run -n build_binary python -m pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
+ conda run -n build_binary pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
+ - name: Install fbgemm
+ run: |
+ conda run -n build_binary pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu
+ - name: Install torchmetrics
+ run: |
+ conda run -n build_binary pip install torchmetrics==1.0.3
- name: Install TorchRec
run: |
- conda run -n build_binary python -m pip install --pre torchrec_nightly_cpu -f https://download.pytorch.org/whl/nightly/torchrec_nightly_cpu/index.html
+ conda run -n build_binary pip install -r requirements.txt
+ conda run -n build_binary python setup.py bdist_wheel --python-tag=${{ matrix.python-tag }}
- name: Test fbgemm_gpu and torchrec installation
shell: bash
run: |
@@ -69,11 +83,42 @@ jobs:
cd ./docs
conda run -n build_binary make html
cd ..
+ - name: Upload Built-Docs
+ uses: actions/upload-artifact@v4
+ with:
+ name: Built-Docs
+ path: docs/build/html/
- name: Get output time
run: echo "The time was ${{ steps.build.outputs.time }}"
- name: Deploy
+ if: github.ref == 'refs/heads/main'
uses: JamesIves/github-pages-deploy-action@releases/v3
with:
ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
BRANCH: gh-pages # The branch the action should deploy to.
FOLDER: docs/build/html # The folder the action should deploy.
+
+ doc-preview:
+ runs-on: [linux.2xlarge]
+ needs: build_docs_job
+ if: ${{ github.event_name == 'pull_request' }}
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+ - name: Download artifact
+ uses: actions/download-artifact@v4
+ with:
+ name: Built-Docs
+ path: docs
+ - name: Add no-index tag
+ run: |
+          find docs -name "*.html" -print0 | xargs -0 sed -i '/<head>/a \ \ <meta name="robots" content="noindex">';
+ - name: Upload docs preview
+ uses: seemethere/upload-artifact-s3@v5
+ if: ${{ github.event_name == 'pull_request' }}
+ with:
+ retention-days: 14
+ s3-bucket: doc-previews
+ if-no-files-found: error
+ path: docs
+ s3-prefix: pytorch/torchrec/${{ github.event.pull_request.number }}
diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml
deleted file mode 100644
index be58565f1..000000000
--- a/.github/workflows/nightly_build.yml
+++ /dev/null
@@ -1,207 +0,0 @@
-# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
-# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
-
-name: Push Binary Nightly
-
-on:
- workflow_call:
- secrets:
- AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID:
- required: true
- AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY:
- required: true
- PYPI_TOKEN:
- required: false
- # run every day at 11:15am
- schedule:
- - cron: '15 11 * * *'
- # or manually trigger it
- workflow_dispatch:
-
-jobs:
- # build on cpu hosts and upload to GHA
- build_on_cpu:
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- include:
- - os: linux.2xlarge
- python-version: 3.7
- python-tag: "py37"
- cuda-tag: "cu11"
- - os: linux.2xlarge
- python-version: 3.8
- python-tag: "py38"
- cuda-tag: "cu11"
- - os: linux.2xlarge
- python-version: 3.9
- python-tag: "py39"
- cuda-tag: "cu11"
- - os: linux.2xlarge
- python-version: '3.10'
- python-tag: "py310"
- cuda-tag: "cu11"
- steps:
- # Checkout the repository to the GitHub Actions runner
- - name: Check ldd --version
- run: ldd --version
- - name: Checkout
- uses: actions/checkout@v2
- - name: Update pip
- run: |
- sudo yum update -y
- sudo yum -y install git python3-pip
- sudo pip3 install --upgrade pip
- - name: Setup conda
- run: |
- wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
- bash ~/miniconda.sh -b -p $HOME/miniconda -u
- - name: setup Path
- run: |
- echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH
- echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH
- - name: create conda env
- run: |
- conda create --name build_binary python=${{ matrix.python-version }}
- conda info
- - name: check python version no Conda
- run: |
- python --version
- - name: check python version
- run: |
- conda run -n build_binary python --version
- - name: Install C/C++ compilers
- run: |
- sudo yum install -y gcc gcc-c++
- - name: Install PyTorch and CUDA
- shell: bash
- run: |
- conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
- - name: Install Dependencies
- shell: bash
- run: |
- conda run -n build_binary python -m pip install -r requirements.txt
- - name: Test Installation of dependencies
- run: |
- conda run -n build_binary python -c "import torch.distributed"
- echo "torch.distributed succeeded"
- conda run -n build_binary python -c "import skbuild"
- echo "skbuild succeeded"
- conda run -n build_binary python -c "import numpy"
- echo "numpy succeeded"
- # for the conda run with quotes, we have to use "\" and double quotes
- # here is the issue: https://github.com/conda/conda/issues/10972
- - name: Build TorchRec Nightly
- run: |
- rm -r dist || true
- conda run -n build_binary \
- python setup.py bdist_wheel \
- --package_name torchrec-nightly \
- --python-tag=${{ matrix.python-tag }}
- - name: Upload wheel as GHA artifact
- uses: actions/upload-artifact@v2
- with:
- name: torchrec_nightly_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl
- path: dist/torchrec_nightly-*.whl
-
- # download from GHA, test on gpu and push to pypi
- test_on_gpu:
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [linux.4xlarge.nvidia.gpu]
- python-version: [3.7, 3.8, 3.9]
- cuda-tag: ["cu11"]
- needs: build_on_cpu
- # the glibc version should match the version of the one we used to build the binary
- # for this case, it's 2.26
- steps:
- - name: Check ldd --version
- run: ldd --version
- - name: check cpu info
- shell: bash
- run: |
- cat /proc/cpuinfo
- - name: check distribution info
- shell: bash
- run: |
- cat /proc/version
- - name: Display EC2 information
- shell: bash
- run: |
- set -euo pipefail
- function get_ec2_metadata() {
- # Pulled from instance metadata endpoint for EC2
- # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
- category=$1
- curl -fsSL "http://169.254.169.254/latest/meta-data/$%7Bcategory%7D"
- }
- echo "ami-id: $(get_ec2_metadata ami-id)"
- echo "instance-id: $(get_ec2_metadata instance-id)"
- echo "instance-type: $(get_ec2_metadata instance-type)"
- - name: check gpu info
- shell: bash
- run: |
- sudo yum install lshw -y
- sudo lshw -C display
- # Checkout the repository to the GitHub Actions runner
- - name: Checkout
- uses: actions/checkout@v2
- - name: Update pip
- run: |
- sudo yum update -y
- sudo yum -y install git python3-pip
- sudo pip3 install --upgrade pip
- - name: Setup conda
- run: |
- wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
- bash ~/miniconda.sh -b -p $HOME/miniconda
- - name: setup Path
- run: |
- echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH
- echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH
- - name: create conda env
- run: |
- conda create --name build_binary python=${{ matrix.python-version }}
- conda info
- - name: check python version no Conda
- run: |
- python --version
- - name: check python version
- run: |
- conda run -n build_binary python --version
- - name: Install C/C++ compilers
- run: |
- sudo yum install -y gcc gcc-c++
- - name: Install PyTorch and CUDA
- shell: bash
- run: |
- conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
- # download wheel from GHA
- - name: Download wheel
- uses: actions/download-artifact@v2
- with:
- name: torchrec_nightly_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl
- - name: Display structure of downloaded files
- run: ls -R
- - name: Install TorchRec Nightly
- run: |
- rm -r dist || true
- conda run -n build_binary python -m pip install *.whl
- - name: Test torchrec installation
- shell: bash
- run: |
- conda run -n build_binary \
- python -c "import torchrec"
- - name: Push TorchRec Binary to PYPI
- env:
- PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
- run: |
- conda run -n build_binary python -m pip install twine
- conda run -n build_binary \
- python -m twine upload \
- --username __token__ \
- --password "$PYPI_TOKEN" \
- --skip-existing \
- torchrec_nightly-*.whl \
- --verbose
diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
index eff5ce19b..74a5af78c 100644
--- a/.github/workflows/pre-commit.yaml
+++ b/.github/workflows/pre-commit.yaml
@@ -10,15 +10,15 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Setup Python
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v5
with:
- python-version: 3.8
+ python-version: 3.9
architecture: x64
packages: |
- ufmt==1.3.2
- black==22.3.0
- usort==1.0.2
+ ufmt==2.5.1
+ black==24.2.0
+ usort==1.0.8
- name: Checkout Torchrec
- uses: actions/checkout@v2
+ uses: actions/checkout@v4
- name: Run pre-commit
- uses: pre-commit/action@v2.0.3
+ uses: pre-commit/action@v3.0.1
diff --git a/.github/workflows/pyre.yml b/.github/workflows/pyre.yml
new file mode 100644
index 000000000..fd773c787
--- /dev/null
+++ b/.github/workflows/pyre.yml
@@ -0,0 +1,27 @@
+name: Pyre Check
+
+on:
+ push:
+ branches: [main]
+ pull_request:
+
+jobs:
+ pyre-check:
+ runs-on: ubuntu-22.04
+ defaults:
+ run:
+ shell: bash -el {0}
+ steps:
+ - uses: conda-incubator/setup-miniconda@v2
+ with:
+ python-version: 3.9
+ - name: Checkout Torchrec
+ uses: actions/checkout@v4
+ - name: Install dependencies
+ run: >
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu &&
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu &&
+ pip install -r requirements.txt &&
+ pip install pyre-check-nightly==$(cat .pyre_configuration | grep version | awk '{print $2}' | sed 's/\"//g')
+ - name: Pyre check
+ run: pyre check
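
The `pyre-check-nightly` pin above is read out of the `.pyre_configuration` file added later in this PR. For the committed configuration, the command substitution expands roughly as follows:

```bash
grep version .pyre_configuration | awk '{print $2}' | sed 's/"//g'
# -> 0.0.101729681899
# so the install step effectively runs:
#   pip install pyre-check-nightly==0.0.101729681899
```
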
diff --git a/.github/workflows/release_build.yml b/.github/workflows/release_build.yml
index 58c7c55fb..7420e5be1 100644
--- a/.github/workflows/release_build.yml
+++ b/.github/workflows/release_build.yml
@@ -6,14 +6,11 @@ name: Push Binary Release
on:
workflow_call:
secrets:
- AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID:
- required: true
- AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY:
- required: true
PYPI_TOKEN:
required: false
workflow_dispatch:
+
jobs:
# build on cpu hosts and upload to GHA
@@ -22,189 +19,209 @@ jobs:
strategy:
matrix:
include:
- - os: linux.2xlarge
- python-version: 3.7
- python-tag: "py37"
- cuda-tag: "cu11"
- - os: linux.2xlarge
- python-version: 3.8
- python-tag: "py38"
- cuda-tag: "cu11"
- - os: linux.2xlarge
- python-version: 3.9
- python-tag: "py39"
- cuda-tag: "cu11"
- - os: linux.2xlarge
- python-version: '3.10'
- python-tag: "py310"
- cuda-tag: "cu11"
+ - os: linux.2xlarge
+ python-version: 3.9
+ python-tag: "py39"
+ cuda-tag: "cu126"
+ - os: linux.2xlarge
+ python-version: '3.10'
+ python-tag: "py310"
+ cuda-tag: "cu126"
+ - os: linux.2xlarge
+ python-version: '3.11'
+ python-tag: "py311"
+ cuda-tag: "cu126"
+ - os: linux.2xlarge
+ python-version: '3.12'
+ python-tag: "py312"
+ cuda-tag: "cu126"
+ - os: linux.2xlarge
+ python-version: '3.13'
+ python-tag: "py313"
+ cuda-tag: "cu126"
steps:
- # Checkout the repository to the GitHub Actions runner
- - name: Check ldd --version
- run: ldd --version
- - name: Checkout
- uses: actions/checkout@v2
- - name: Update pip
- run: |
- sudo yum update -y
- sudo yum -y install git python3-pip
- sudo pip3 install --upgrade pip
- - name: Setup conda
- run: |
- wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
- bash ~/miniconda.sh -b -p $HOME/miniconda -u
- - name: setup Path
- run: |
- echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH
- echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH
- - name: create conda env
- run: |
- conda create --name build_binary python=${{ matrix.python-version }}
- conda info
- - name: check python version no Conda
- run: |
- python --version
- - name: check python version
- run: |
- conda run -n build_binary python --version
- - name: Install C/C++ compilers
- run: |
- sudo yum install -y gcc gcc-c++
- - name: Install PyTorch and CUDA
- shell: bash
- run: |
- conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-test -c nvidia
- - name: Install Dependencies
- shell: bash
- run: |
- conda run -n build_binary python -m pip install -r requirements.txt
- - name: Test Installation of dependencies
- run: |
- conda run -n build_binary python -c "import torch.distributed"
- echo "torch.distributed succeeded"
- conda run -n build_binary python -c "import skbuild"
- echo "skbuild succeeded"
- conda run -n build_binary python -c "import numpy"
- echo "numpy succeeded"
- # for the conda run with quotes, we have to use "\" and double quotes
- # here is the issue: https://github.com/conda/conda/issues/10972
- - name: Build TorchRec
- run: |
- rm -r dist || true
- conda run -n build_binary \
- python setup.py bdist_wheel \
- --package_name torchrec \
- --python-tag=${{ matrix.python-tag }}
- - name: Upload wheel as GHA artifact
- uses: actions/upload-artifact@v2
- with:
- name: torchrec_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl
- path: dist/torchrec-*.whl
+ # Checkout the repository to the GitHub Actions runner
+ - name: Check ldd --version
+ run: ldd --version
+ - name: Checkout
+ uses: actions/checkout@v4
+ - name: Update pip
+ run: |
+ sudo yum update -y
+ sudo yum -y install git python3-pip
+ - name: Setup conda
+ run: |
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
+ bash ~/miniconda.sh -b -p $HOME/miniconda -u
+ - name: setup Path
+ run: |
+ echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH
+ echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH
+ - name: create conda env
+ run: |
+ conda create --name build_binary python=${{ matrix.python-version }}
+ conda info
+ - name: check python version no Conda
+ run: |
+ python --version
+ - name: check python version
+ run: |
+ conda run -n build_binary python --version
+ - name: Install C/C++ compilers
+ run: |
+ sudo yum install -y gcc gcc-c++
+ - name: Install PyTorch and CUDA
+ shell: bash
+ run: |
+ conda run -n build_binary pip install torch
+ - name: Install fbgemm
+ shell: bash
+ run: |
+ conda run -n build_binary pip install numpy
+ conda run -n build_binary pip install fbgemm-gpu
+ - name: Install Dependencies
+ shell: bash
+ run: |
+ conda run -n build_binary python -m pip install -r requirements.txt
+ - name: Test Installation of dependencies
+ run: |
+ conda run -n build_binary python -c "import torch.distributed"
+ echo "torch.distributed succeeded"
+ conda run -n build_binary python -c "import skbuild"
+ echo "skbuild succeeded"
+ conda run -n build_binary python -c "import numpy"
+ echo "numpy succeeded"
+ # for the conda run with quotes, we have to use "\" and double quotes
+ # here is the issue: https://github.com/conda/conda/issues/10972
+ - name: Build TorchRec
+ env:
+ OFFICIAL_RELEASE: 1
+ run: |
+ rm -r dist || true
+ conda run -n build_binary \
+ python setup.py bdist_wheel \
+ --python-tag=${{ matrix.python-tag }}
+ - name: Upload wheel as GHA artifact
+ uses: actions/upload-artifact@v4
+ with:
+ name: torchrec_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl
+ path: dist/torchrec-*.whl
- # download from GHA, test on gpu and push to pypi
- test_on_gpu:
+ # download from GHA, sanity check on gpu and push to pypi
+ sanity_check_on_gpu_and_push:
runs-on: ${{ matrix.os }}
strategy:
matrix:
- os: [linux.4xlarge.nvidia.gpu]
- python-version: [3.7, 3.8, 3.9]
- cuda-tag: ["cu11"]
+ os:
+ - linux.g5.12xlarge.nvidia.gpu
+ python-version:
+ - 3.9
+ - "3.10"
+ - "3.11"
+ - "3.12"
+ - "3.13"
+ cuda-tag:
+ - "cu126"
needs: build_on_cpu
# the glibc version should match the version of the one we used to build the binary
# for this case, it's 2.26
steps:
- - name: Check ldd --version
- run: ldd --version
- - name: check cpu info
- shell: bash
- run: |
- cat /proc/cpuinfo
- - name: check distribution info
- shell: bash
- run: |
- cat /proc/version
- - name: Display EC2 information
- shell: bash
- run: |
- set -euo pipefail
- function get_ec2_metadata() {
- # Pulled from instance metadata endpoint for EC2
- # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
- category=$1
- curl -fsSL "http://169.254.169.254/latest/meta-data/$%7Bcategory%7D"
- }
- echo "ami-id: $(get_ec2_metadata ami-id)"
- echo "instance-id: $(get_ec2_metadata instance-id)"
- echo "instance-type: $(get_ec2_metadata instance-type)"
- - name: check gpu info
- shell: bash
- run: |
- sudo yum install lshw -y
- sudo lshw -C display
- # Checkout the repository to the GitHub Actions runner
- - name: Checkout
- uses: actions/checkout@v2
- - name: Update pip
- run: |
- sudo yum update -y
- sudo yum -y install git python3-pip
- sudo pip3 install --upgrade pip
- - name: Setup conda
- run: |
- wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
- bash ~/miniconda.sh -b -p $HOME/miniconda
- - name: setup Path
- run: |
- echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH
- echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH
- - name: create conda env
- run: |
- conda create --name build_binary python=${{ matrix.python-version }}
- conda info
- - name: check python version no Conda
- run: |
- python --version
- - name: check python version
- run: |
- conda run -n build_binary python --version
- - name: Install C/C++ compilers
- run: |
- sudo yum install -y gcc gcc-c++
- - name: Install PyTorch and CUDA
- shell: bash
- run: |
- conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-test -c nvidia
- # download wheel from GHA
- - name: Download wheel
- uses: actions/download-artifact@v2
- with:
- name: torchrec_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl
- - name: Display structure of downloaded files
- run: ls -R
- - name: Install TorchRec
- run: |
- rm -r dist || true
- conda run -n build_binary python -m pip install *.whl
- - name: Test fbgemm_gpu and torchrec installation
- shell: bash
- run: |
- conda run -n build_binary \
- python -c "import torchrec"
- - name: Test with pytest
- run: |
- conda run -n build_binary \
- python -m pip install pytest
- conda run -n build_binary \
- python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors
- # Push to Pypi
- - name: Push TorchRec Binary to PYPI
- env:
+ - name: Check ldd --version
+ run: ldd --version
+ - name: check cpu info
+ shell: bash
+ run: |
+ cat /proc/cpuinfo
+ - name: check distribution info
+ shell: bash
+ run: |
+ cat /proc/version
+ - name: Display EC2 information
+ shell: bash
+ run: |
+ set -euo pipefail
+ function get_ec2_metadata() {
+ # Pulled from instance metadata endpoint for EC2
+ # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
+ category=$1
+ curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/$%7Bcategory%7D"
+ }
+ echo "ami-id: $(get_ec2_metadata ami-id)"
+ echo "instance-id: $(get_ec2_metadata instance-id)"
+ echo "instance-type: $(get_ec2_metadata instance-type)"
+ - name: check gpu info
+ shell: bash
+ run: |
+ sudo yum install lshw -y
+ sudo lshw -C display
+ # Checkout the repository to the GitHub Actions runner
+ - name: Checkout
+ uses: actions/checkout@v4
+ - name: Update pip
+ run: |
+ sudo yum update -y
+ sudo yum -y install git python3-pip
+ - name: Setup conda
+ run: |
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
+ bash ~/miniconda.sh -b -p $HOME/miniconda
+ - name: setup Path
+ run: |
+ echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH
+ echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH
+ - name: create conda env
+ run: |
+ conda create --name build_binary python=${{ matrix.python-version }}
+ conda info
+ - name: check python version no Conda
+ run: |
+ python --version
+ - name: check python version
+ run: |
+ conda run -n build_binary python --version
+ - name: Install C/C++ compilers
+ run: |
+ sudo yum install -y gcc gcc-c++
+ - name: Install PyTorch and CUDA
+ shell: bash
+ run: |
+ conda run -n build_binary pip install torch
+ # download wheel from GHA
+ - name: Install fbgemm
+ shell: bash
+ run: |
+ conda run -n build_binary pip install numpy
+ conda run -n build_binary pip install fbgemm-gpu
+ - name: Install torchmetrics
+ shell: bash
+ run: |
+ conda run -n build_binary pip install torchmetrics==1.0.3
+ - name: Download wheel
+ uses: actions/download-artifact@v4
+ with:
+ name: torchrec_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl
+ - name: Display structure of downloaded files
+ run: ls -R
+ - name: Install TorchRec
+ run: |
+ rm -r dist || true
+ conda run -n build_binary python -m pip install *.whl
+ - name: Test fbgemm_gpu and torchrec installation
+ shell: bash
+ run: |
+ conda run -n build_binary \
+ python -c "import fbgemm_gpu"
+ conda run -n build_binary \
+ python -c "import torchrec"
+ # Push to Pypi
+ - name: Push TorchRec Binary to PYPI
+ env:
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
- run: |
- conda run -n build_binary python -m pip install twine
- conda run -n build_binary \
- python -m twine upload \
- --username __token__ \
- --password "$PYPI_TOKEN" \
- --skip-existing \
- torchrec-*.whl
+ run: |
+ conda run -n build_binary python -m pip install twine
+ conda run -n build_binary \
+ python -m twine upload \
+ --username __token__ \
+ --password "$PYPI_TOKEN" \
+ --skip-existing \
+ torchrec-*.whl
diff --git a/.github/workflows/unittest_ci.yml b/.github/workflows/unittest_ci.yml
index 82cbf5621..5fdce0115 100644
--- a/.github/workflows/unittest_ci.yml
+++ b/.github/workflows/unittest_ci.yml
@@ -4,188 +4,119 @@
name: Unit Test CI
on:
- # TODO: re-enable when GPU unit tests are working
- # push:
- # paths-ignore:
- # - "docs/*"
- # - "third_party/*"
- # - .gitignore
- # - "*.md"
- # pull_request:
- # paths-ignore:
- # - "docs/*"
- # - "third_party/*"
- # - .gitignore
- # - "*.md"
+ push:
+ branches:
+ - nightly
+ - main
+ paths-ignore:
+ - "docs/*"
+ - "third_party/*"
+ - .gitignore
+ - "*.md"
+ pull_request:
+ paths-ignore:
+ - "docs/*"
+ - "third_party/*"
+ - .gitignore
+ - "*.md"
workflow_dispatch:
jobs:
- # build on cpu hosts and upload to GHA
- build_on_cpu:
- runs-on: ${{ matrix.os }}
+ build_test:
strategy:
+ fail-fast: false
matrix:
- include:
- - os: linux.2xlarge
- # ideally we run on 3.8 and 3.9 as well, however we are limited in resources.
- python-version: 3.7
- python-tag: "py37"
- cuda-tag: "cu11"
- steps:
- # Checkout the repository to the GitHub Actions runner
- - name: Check ldd --version
- run: ldd --version
- - name: Checkout
- uses: actions/checkout@v2
- - name: Update pip
- run: |
- sudo yum update -y
- sudo yum -y install git python3-pip
- sudo pip3 install --upgrade pip
- - name: Setup conda
- run: |
- wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
- bash ~/miniconda.sh -b -p $HOME/miniconda -u
- - name: setup Path
- run: |
- echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH
- echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH
- - name: create conda env
- run: |
- conda create --name build_binary python=${{ matrix.python-version }}
+ cuda-tag: ["cu118", "cu126", "cu128"]
+ os:
+ - linux.g5.12xlarge.nvidia.gpu
+ python:
+ - version: "3.9"
+ tag: "py39"
+ - version: "3.10"
+ tag: "py310"
+ - version: "3.11"
+ tag: "py311"
+ - version: "3.12"
+ tag: "py312"
+ - version: "3.13"
+ tag: "py313"
+ is_pr:
+ - ${{ github.event_name == 'pull_request' }}
+ exclude:
+ - is_pr: true
+ cuda-tag: "cu118"
+ - is_pr: true
+ cuda-tag: "cu126"
+ - is_pr: true
+ cuda-tag: "cu128"
+ python:
+ version: "3.9"
+ - is_pr: true
+ cuda-tag: "cu128"
+ python:
+ version: "3.10"
+ - is_pr: true
+ cuda-tag: "cu128"
+ python:
+ version: "3.11"
+ - is_pr: true
+ cuda-tag: "cu128"
+ python:
+ version: "3.12"
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
+ with:
+ runner: ${{ matrix.os }}
+ timeout: 60
+ script: |
+ ldd --version
+ conda create -y --name build_binary python=${{ matrix.python.version }}
conda info
- - name: check python version no Conda
- run: |
python --version
- - name: check python version
- run: |
conda run -n build_binary python --version
- - name: Install C/C++ compilers
- run: |
- sudo yum install -y gcc gcc-c++
- - name: Install PyTorch and CUDA
- shell: bash
- run: |
- conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
- - name: Install Dependencies
- shell: bash
- run: |
- conda run -n build_binary python -m pip install -r requirements.txt
- - name: Test Installation of dependencies
- run: |
- conda run -n build_binary python -c "import torch.distributed"
- echo "torch.distributed succeeded"
- conda run -n build_binary python -c "import skbuild"
- echo "skbuild succeeded"
- conda run -n build_binary python -c "import numpy"
- echo "numpy succeeded"
- # for the conda run with quotes, we have to use "\" and double quotes
- # here is the issue: https://github.com/conda/conda/issues/10972
- - name: Build TorchRec Binary
- run: |
+ conda run -n build_binary \
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/${{ matrix.cuda-tag }}
+ conda run -n build_binary \
+ python -c "import torch"
+ echo "torch succeeded"
+ conda run -n build_binary \
+ python -c "import torch.distributed"
+ conda run -n build_binary \
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/${{ matrix.cuda-tag }}
+ conda run -n build_binary \
+ python -c "import fbgemm_gpu"
+ echo "fbgemm_gpu succeeded"
+ conda run -n build_binary \
+ pip install -r requirements.txt
conda run -n build_binary \
python setup.py bdist_wheel \
- --package_name torchrec-test \
- --python-tag=${{ matrix.python-tag }}
- - name: Upload wheel as GHA artifact
- uses: actions/upload-artifact@v2
- with:
- name: torchrec-test_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl
- path: dist/torchrec-test-*.whl
-
- # download from GHA, test on gpu
- test_on_gpu:
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [linux.4xlarge.nvidia.gpu]
- python-version: [3.7]
- cuda-tag: ["cu11"]
- needs: build_on_cpu
- # the glibc version should match the version of the one we used to build the binary
- # for this case, it's 2.26
- steps:
- - name: Check ldd --version
- # Run unit tests
- run: ldd --version
- - name: check cpu info
- shell: bash
- run: |
- cat /proc/cpuinfo
- - name: check distribution info
- shell: bash
- run: |
- cat /proc/version
- - name: Display EC2 information
- shell: bash
- run: |
- set -euo pipefail
- function get_ec2_metadata() {
- # Pulled from instance metadata endpoint for EC2
- # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
- category=$1
- curl -fsSL "http://169.254.169.254/latest/meta-data/$%7Bcategory%7D"
- }
- echo "ami-id: $(get_ec2_metadata ami-id)"
- echo "instance-id: $(get_ec2_metadata instance-id)"
- echo "instance-type: $(get_ec2_metadata instance-type)"
- - name: check gpu info
- shell: bash
- run: |
- sudo yum install lshw -y
- sudo lshw -C display
- # Checkout the repository to the GitHub Actions runner
- - name: Checkout
- uses: actions/checkout@v2
- - name: Update pip
- run: |
- sudo yum update -y
- sudo yum -y install git python3-pip
- sudo pip3 install --upgrade pip
- - name: Setup conda
- run: |
- wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
- bash ~/miniconda.sh -b -p $HOME/miniconda
- - name: setup Path
- run: |
- echo "/home/ec2-user/miniconda/bin" >> $GITHUB_PATH
- echo "CONDA=/home/ec2-user/miniconda" >> $GITHUB_PATH
- - name: create conda env
- run: |
- conda create --name build_binary python=${{ matrix.python-version }}
- conda info
- - name: check python version no Conda
- run: |
- python --version
- - name: check python version
- run: |
- conda run -n build_binary python --version
- - name: Install C/C++ compilers
- run: |
- sudo yum install -y gcc gcc-c++
- - name: Install PyTorch and CUDA
- shell: bash
- run: |
- conda install -n build_binary -y pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
- # download wheel from GHA
- - name: Download wheel
- uses: actions/download-artifact@v2
- with:
- name: torchrec-test_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl
- - name: Display structure of downloaded files
- run: ls -R
- - name: Install TorchRec GPU
- run: |
- rm -r dist || true
- conda run -n build_binary python -m pip install dist/*.whl
- - name: Test torchrec installation
- shell: bash
- run: |
+ --python-tag=${{ matrix.python.tag }}
conda run -n build_binary \
python -c "import torchrec"
- - name: Test with pytest
- run: |
+ echo "torch.distributed succeeded"
conda run -n build_binary \
- python -m pip install pytest
+ python -c "import numpy"
+ echo "numpy succeeded"
+ conda install -n build_binary -y pytest
+ # Read the list of tests to skip from a file, ignoring empty lines and comments
+ skip_expression=$(awk '!/^($|#)/ {printf " and not %s", $0}' ./.github/scripts/tests_to_skip.txt)
+ # Check if skip_expression is effectively empty
+ if [ -z "$skip_expression" ]; then
+ skip_expression=""
+ else
+ skip_expression=${skip_expression:5} # Remove the leading " and "
+ fi
conda run -n build_binary \
- python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors
+ python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors \
+ --ignore=torchrec/distributed/tests/test_comm.py --ignore=torchrec/distributed/tests/test_infer_shardings.py \
+ --ignore=torchrec/distributed/tests/test_keyed_jagged_tensor_pool.py --ignore=torchrec/distributed/tests/test_pt2_multiprocess.py \
+ --ignore=torchrec/distributed/tests/test_pt2.py --ignore=torchrec/distributed/tests/test_quant_model_parallel.py \
+ --ignore=torchrec/distributed/tests/test_quant_pruning.py --ignore=torchrec/distributed/tests/test_quant_sequence_model_parallel.py \
+ --ignore-glob='torchrec/metrics/*' --ignore-glob='torchrec/distributed/tests/test_model_parallel_gloo*' \
+ --ignore-glob='torchrec/inference/inference_legacy/tests*' --ignore-glob='*test_model_parallel_nccl*' \
+ --ignore=torchrec/distributed/tests/test_cache_prefetch.py --ignore=torchrec/distributed/tests/test_fp_embeddingbag_single_rank.py \
+ --ignore=torchrec/distributed/tests/test_infer_utils.py --ignore=torchrec/distributed/tests/test_fx_jit.py --ignore-glob=**/test_utils/ \
+ --ignore-glob='*test_train_pipeline*' --ignore=torchrec/distributed/tests/test_model_parallel_hierarchical.py \
+ -k "$skip_expression"
diff --git a/.github/workflows/unittest_ci_cpu.yml b/.github/workflows/unittest_ci_cpu.yml
index e69c4f72f..7289fa509 100644
--- a/.github/workflows/unittest_ci_cpu.yml
+++ b/.github/workflows/unittest_ci_cpu.yml
@@ -20,70 +20,78 @@ on:
jobs:
build_test:
strategy:
- matrix:
- include:
- - os: linux.2xlarge
- python-version: 3.7
- python-tag: "py37"
- - os: linux.2xlarge
- python-version: 3.8
- python-tag: "py38"
- - os: linux.2xlarge
- python-version: 3.9
- python-tag: "py39"
- - os: linux.2xlarge
- python-version: '3.10'
- python-tag: "py310"
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ fail-fast: false
+ matrix:
+ os:
+ - linux.2xlarge
+ python:
+ - version: "3.9"
+ tag: "py39"
+ - version: "3.10"
+ tag: "py310"
+ - version: "3.11"
+ tag: "py311"
+ - version: "3.12"
+ tag: "py312"
+ - version: "3.13"
+ tag: "py313"
+ is_pr:
+ - ${{ github.event_name == 'pull_request' }}
+ exclude:
+ - is_pr: true
+ python:
+ version: "3.10"
+ - is_pr: true
+ python:
+ version: "3.11"
+ - is_pr: true
+ python:
+ version: "3.12"
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
with:
runner: ${{ matrix.os }}
- timeout: 30
+ timeout: 15
script: |
ldd --version
- conda create -y --name build_binary python=${{ matrix.python-version }}
+ conda create -y --name build_binary python=${{ matrix.python.version }}
conda info
python --version
conda run -n build_binary python --version
- conda install -n build_binary \
- --yes \
- -c pytorch-nightly \
- "pytorch-nightly"::pytorch[build="*cpu*"]
conda run -n build_binary \
- pip install -r requirements.txt
- conda run -n build_binary \
- pip uninstall fbgemm_gpu-nightly -y
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
conda run -n build_binary \
- pip install fbgemm-gpu-nightly-cpu
+ python -c "import torch"
+ echo "torch succeeded"
conda run -n build_binary \
python -c "import torch.distributed"
- echo "torch.distributed succeeded"
conda run -n build_binary \
- python -c "import skbuild"
- echo "skbuild succeeded"
- conda run -n build_binary \
- python -c "import numpy"
- echo "numpy succeeded"
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu
conda run -n build_binary \
python -c "import fbgemm_gpu"
echo "fbgemm_gpu succeeded"
-
+ conda run -n build_binary \
+ pip install -r requirements.txt
conda run -n build_binary \
python setup.py bdist_wheel \
- --package_name torchrec-test-cpu \
- --python-tag=${{ matrix.python-tag }}
+ --python-tag=${{ matrix.python.tag }}
conda run -n build_binary \
python -c "import torchrec"
- conda install -n build_binary -y pytest
+ echo "torch.distributed succeeded"
conda run -n build_binary \
- python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors -k 'not test_sharding_gloo_cw'
- echo "Starting C++ Tests"
- conda install -n build_binary -y gxx_linux-64
+ python -c "import numpy"
+ echo "numpy succeeded"
+ conda install -n build_binary -y pytest
+ # Read the list of tests to skip from a file, ignoring empty lines and comments
+ skip_expression=$(awk '!/^($|#)/ {printf " and not %s", $0}' ./.github/scripts/tests_to_skip.txt)
+ # Check if skip_expression is effectively empty
+ if [ -z "$skip_expression" ]; then
+ skip_expression=""
+ else
+ skip_expression=${skip_expression:5} # Remove the leading " and "
+ fi
conda run -n build_binary \
- x86_64-conda-linux-gnu-g++ --version
- mkdir cpp-build
- cd cpp-build
- conda run -n build_binary cmake \
- -DBUILD_TEST=ON \
- -DCMAKE_PREFIX_PATH=/opt/conda/envs/build_binary/lib/python${{ matrix.python-version }}/site-packages/torch/share/cmake ..
- conda run -n build_binary make -j
- conda run -n build_binary ctest -v .
+ python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors \
+ --ignore-glob=**/test_utils/ -k "$skip_expression"
diff --git a/.github/workflows/validate-binaries.yml b/.github/workflows/validate-binaries.yml
new file mode 100644
index 000000000..075aeabdf
--- /dev/null
+++ b/.github/workflows/validate-binaries.yml
@@ -0,0 +1,42 @@
+name: Validate binaries
+
+on:
+ workflow_call:
+ inputs:
+ channel:
+ description: "Channel to use (nightly, release)"
+ required: false
+ type: string
+ default: release
+ ref:
+ description: 'Reference to checkout, defaults to empty'
+ default: ""
+ required: false
+ type: string
+ workflow_dispatch:
+ inputs:
+ channel:
+ description: "Channel to use (nightly, release, test, pypi)"
+ required: true
+ type: choice
+ options:
+ - release
+ - nightly
+ - test
+ ref:
+ description: 'Reference to checkout, defaults to empty'
+ default: ""
+ required: false
+ type: string
+
+jobs:
+ validate-binaries:
+ uses: pytorch/test-infra/.github/workflows/validate-domain-library.yml@main
+ with:
+ package_type: "wheel"
+ os: "linux"
+ channel: ${{ inputs.channel }}
+ repository: "pytorch/torchrec"
+ smoke_test: "source ./.github/scripts/validate_binaries.sh"
+ with_cuda: enable
+ with_rocm: false
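
Besides being called from `validate-nightly-binaries.yml`, this workflow can be triggered manually. A possible invocation with the GitHub CLI, assuming the workflow file is on the default branch:

```bash
gh workflow run "Validate binaries" -f channel=nightly
```
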
diff --git a/.github/workflows/validate-nightly-binaries.yml b/.github/workflows/validate-nightly-binaries.yml
new file mode 100644
index 000000000..0cc067912
--- /dev/null
+++ b/.github/workflows/validate-nightly-binaries.yml
@@ -0,0 +1,26 @@
+# Scheduled validation of the nightly binaries
+name: validate-nightly-binaries
+
+on:
+ schedule:
+ # At 5:30 pm UTC (7:30 am PDT)
+ - cron: "30 17 * * *"
+ # Have the ability to trigger this job manually through the API
+ workflow_dispatch:
+ push:
+ branches:
+ - main
+ paths:
+ - '.github/workflows/validate-nightly-binaries.yml'
+ - '.github/workflows/validate-binaries.yml'
+      - '.github/scripts/validate_binaries.sh'
+ pull_request:
+ paths:
+ - '.github/workflows/validate-nightly-binaries.yml'
+ - '.github/workflows/validate-binaries.yml'
+      - '.github/scripts/validate_binaries.sh'
+jobs:
+ nightly:
+ uses: ./.github/workflows/validate-binaries.yml
+ with:
+ channel: nightly
diff --git a/.lintrunner.toml b/.lintrunner.toml
index 23a76193a..8fd1bbc1f 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -11,7 +11,7 @@ init_command = [
'python3',
'tools/lint/pip_init.py',
'--dry-run={{DRYRUN}}',
- 'black==22.3.0',
+ 'black==24.2.0',
]
is_formatter = true
@@ -28,6 +28,6 @@ init_command = [
'python3',
'tools/lint/pip_init.py',
'--dry-run={{DRYRUN}}',
- 'usort==1.0.2',
+ 'usort==1.0.8',
]
is_formatter = true
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ea648c6ef..705694ea9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.0.1
+ rev: v4.5.0
hooks:
- id: check-toml
- id: check-yaml
@@ -8,9 +8,9 @@ repos:
- id: end-of-file-fixer
- repo: https://github.com/omnilib/ufmt
- rev: v1.3.2
+ rev: v2.5.1
hooks:
- id: ufmt
additional_dependencies:
- - black == 22.3.0
- - usort == 1.0.2
+ - black == 24.2.0
+ - usort == 1.0.8.post1
diff --git a/.pyre_configuration b/.pyre_configuration
new file mode 100644
index 000000000..4249f9c02
--- /dev/null
+++ b/.pyre_configuration
@@ -0,0 +1,17 @@
+{
+ "exclude": [
+ ".*/pyre-check/stubs/.*",
+ ".*/torchrec/datasets*",
+ ".*/torchrec/models*",
+ ".*/torchrec/inference/client.py"
+ ],
+ "site_package_search_strategy": "all",
+ "source_directories": [
+ {
+ "import_root": ".",
+ "source": "torchrec"
+ }
+ ],
+ "strict": true,
+ "version": "0.0.101729681899"
+}
diff --git a/README.MD b/README.MD
index 69c40d6ca..28915b79d 100644
--- a/README.MD
+++ b/README.MD
@@ -1,85 +1,98 @@
-# TorchRec (Beta Release)
-[Docs](https://pytorch.org/torchrec/)
+# TorchRec
-TorchRec is a PyTorch domain library built to provide common sparsity & parallelism primitives needed for large-scale recommender systems (RecSys). It allows authors to train models with large embedding tables sharded across many GPUs.
+**TorchRec** is a PyTorch domain library built to provide common sparsity and parallelism primitives needed for large-scale recommender systems (RecSys). TorchRec allows training and inference of models with large embedding tables sharded across many GPUs and **powers many production RecSys models at Meta**.
-## TorchRec contains:
+## External Presence
+TorchRec has been used to accelerate advancements in recommendation systems. Some examples:
+* [Latest version of Meta's DLRM (Deep Learning Recommendation Model)](https://github.com/facebookresearch/dlrm) is built using TorchRec
+* [Disaggregated Multi-Tower: Topology-aware Modeling Technique for Efficient Large-Scale Recommendation](https://arxiv.org/abs/2403.00877) paper
+* [The Algorithm ML](https://github.com/twitter/the-algorithm-ml) from Twitter
+* [Training Recommendation Models with Databricks](https://docs.databricks.com/en/machine-learning/train-recommender-models.html)
+* [Toward 100TB model with Embedding Offloading Paper](https://dl.acm.org/doi/10.1145/3640457.3688037)
+
+
+## Introduction
+
+To begin learning about TorchRec, check out:
+* Our complete [TorchRec Tutorial](https://pytorch.org/tutorials/intermediate/torchrec_intro_tutorial.html)
+* The [TorchRec documentation](https://pytorch.org/torchrec/) for an overview of TorchRec and API references
+
+
+### TorchRec Features
- Parallelism primitives that enable easy authoring of large, performant multi-device/multi-node models using hybrid data-parallelism/model-parallelism.
-- The TorchRec sharder can shard embedding tables with different sharding strategies including data-parallel, table-wise, row-wise, table-wise-row-wise, and column-wise sharding.
-- The TorchRec planner can automatically generate optimized sharding plans for models.
-- Pipelined training overlaps dataloading device transfer (copy to GPU), inter-device communications (input_dist), and computation (forward, backward) for increased performance.
-- Optimized kernels for RecSys powered by FBGEMM.
-- Quantization support for reduced precision training and inference.
+- Sharders to shard embedding tables with different strategies including data-parallel, table-wise, row-wise, table-wise-row-wise, column-wise, and table-wise-column-wise sharding.
+- Planner that can automatically generate optimized sharding plans for models.
+- Pipelined training overlapping dataloading device transfer (copy to GPU), inter-device communications (input_dist), and computation (forward, backward) for increased performance.
+- Optimized kernels for RecSys powered by [FBGEMM](https://github.com/pytorch/FBGEMM/tree/main).
+- Quantization support for reduced precision training and inference, along with optimizing a TorchRec model for C++ inference.
- Common modules for RecSys.
-- Production-proven model architectures for RecSys.
- RecSys datasets (criteo click logs and movielens)
- Examples of end-to-end training such the dlrm event prediction model trained on criteo click logs dataset.
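+
+As a quick illustration of the core API (the table and feature values below are toy examples; see the tutorial linked above for a complete walkthrough):
+
+```
+import torch
+import torchrec
+
+ebc = torchrec.EmbeddingBagCollection(
+    device="cpu",
+    tables=[
+        torchrec.EmbeddingBagConfig(
+            name="product_table",
+            embedding_dim=16,
+            num_embeddings=4096,
+            feature_names=["product"],
+        )
+    ],
+)
+
+# One example (batch size 1) that interacted with products 101 and 202
+features = torchrec.KeyedJaggedTensor.from_lengths_sync(
+    keys=["product"],
+    values=torch.tensor([101, 202]),
+    lengths=torch.tensor([2]),
+)
+print(ebc(features).to_dict()["product"].shape)  # torch.Size([1, 16])
+```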
-# Installation
-Torchrec requires Python >= 3.7 and CUDA >= 11.0 (CUDA is highly recommended for performance but not required). The example below shows how to install with CUDA 11.6. This setup assumes you have conda installed.
+## Installation
-## Binaries
+Check out the [Getting Started](https://pytorch.org/torchrec/setup-torchrec.html) section in the documentation for recommended ways to set up TorchRec.
-Experimental binary on Linux for Python 3.7, 3.8 and 3.9 can be installed via pip wheels
+### From Source
-### Installations
-```
-TO use the library without cuda, use the *-cpu fbgemm installations. However, this will be much slower than the CUDA variant.
+**Generally, there isn't a need to build from source**. For most use cases, follow the section above to set up TorchRec. However, to build from source and to get the latest changes, do the following:
-Nightly
+1. Install PyTorch. See the [PyTorch documentation](https://pytorch.org/get-started/locally/).
+ ```
+ CUDA 12.8
-conda install pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
-pip install torchrec_nightly
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/cu128
-Stable
+ CUDA 12.6
-conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
-pip install torchrec
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/cu126
-If you have no CUDA device:
+ CUDA 11.8
-Nightly
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/cu118
-pip uninstall fbgemm-gpu-nightly -y
-pip install fbgemm-gpu-nightly-cpu
+ CPU
-Stable
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
+ ```
-pip uninstall fbgemm-gpu -y
-pip install fbgemm-gpu-cpu
+2. Clone TorchRec.
+ ```
+ git clone --recursive https://github.com/pytorch/torchrec
+ cd torchrec
+ ```
-```
+3. Install FBGEMM.
+ ```
+ CUDA 12.8
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu128
-### Colab example: introduction + install
-See our colab notebook for an introduction to torchrec which includes runnable installation.
- - [Tutorial Source](https://github.com/pytorch/torchrec/blob/main/Torchrec_Introduction.ipynb)
- - Open in [Google Colab](https://colab.research.google.com/github/pytorch/torchrec/blob/main/Torchrec_Introduction.ipynb)
+ CUDA 12.6
-## From Source
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu126
-We are currently iterating on the setup experience. For now, we provide manual instructions on how to build from source. The example below shows how to install with CUDA 11.3. This setup assumes you have conda installed.
+ CUDA 11.8
-1. Install pytorch. See [pytorch documentation](https://pytorch.org/get-started/locally/)
- ```
- conda install pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu118
+
+ CPU
+
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu
```
-2. Install Requirements
+4. Install other requirements.
```
pip install -r requirements.txt
```
-3. Download and install TorchRec.
+5. Install TorchRec.
```
- git clone --recursive https://github.com/pytorch/torchrec
-
- cd torchrec
python setup.py install develop
```
-4. Test the installation.
+6. Test the installation.
```
GPU mode
@@ -93,5 +106,32 @@ We are currently iterating on the setup experience. For now, we provide manual i
5. If you want to run a more complex example, please take a look at the torchrec [DLRM example](https://github.com/facebookresearch/dlrm/blob/main/torchrec_dlrm/dlrm_main.py).
+## Contributing
+
+See [CONTRIBUTING.md](https://github.com/pytorch/torchrec/blob/main/CONTRIBUTING.md) for details about contributing to TorchRec!
+
+## Citation
+
+If you're using TorchRec, please use the following BibTeX entry to cite this work:
+```
+@inproceedings{10.1145/3523227.3547387,
+author = {Ivchenko, Dmytro and Van Der Staay, Dennis and Taylor, Colin and Liu, Xing and Feng, Will and Kindi, Rahul and Sudarshan, Anirudh and Sefati, Shahin},
+title = {TorchRec: a PyTorch Domain Library for Recommendation Systems},
+year = {2022},
+isbn = {9781450392785},
+publisher = {Association for Computing Machinery},
+address = {New York, NY, USA},
+url = {https://doi.org/10.1145/3523227.3547387},
+doi = {10.1145/3523227.3547387},
+abstract = {Recommendation Systems (RecSys) comprise a large footprint of production-deployed AI today. The neural network-based recommender systems differ from deep learning models in other domains in using high-cardinality categorical sparse features that require large embedding tables to be trained. In this talk we introduce TorchRec, a PyTorch domain library for Recommendation Systems. This new library provides common sparsity and parallelism primitives, enabling researchers to build state-of-the-art personalization models and deploy them in production. In this talk we cover the building blocks of the TorchRec library including modeling primitives such as embedding bags and jagged tensors, optimized recommender system kernels powered by FBGEMM, a flexible sharder that supports a veriety of strategies for partitioning embedding tables, a planner that automatically generates optimized and performant sharding plans, support for GPU inference and common modeling modules for building recommender system models. TorchRec library is currently used to train large-scale recommender models at Meta. We will present how TorchRec helped Meta’s recommender system platform to transition from CPU asynchronous training to accelerator-based full-sync training.},
+booktitle = {Proceedings of the 16th ACM Conference on Recommender Systems},
+pages = {482–483},
+numpages = {2},
+keywords = {information retrieval, recommender systems},
+location = {Seattle, WA, USA},
+series = {RecSys '22}
+}
+```
+
## License
TorchRec is BSD licensed, as found in the [LICENSE](LICENSE) file.
diff --git a/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb b/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb
new file mode 100644
index 000000000..015d216e4
--- /dev/null
+++ b/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb
@@ -0,0 +1,2972 @@
+{
+ "metadata": {
+ "custom": {
+ "cells": [],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "background_execution": "on",
+ "collapsed_sections": [],
+ "machine_shape": "hm",
+ "name": "Torchrec Introduction.ipynb",
+ "provenance": []
+ },
+ "fileHeader": "",
+ "fileUid": "c9a29462-2509-4adb-a539-0318cf56bb00",
+ "interpreter": {
+ "hash": "d4204deb07d30e7517ec64733b2d65f24aff851b061e21418071854b06459363"
+ },
+ "isAdHoc": false,
+ "kernelspec": {
+ "display_name": "Python 3.7.13 ('torchrec': conda)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ },
+ "indentAmount": 2,
+ "last_server_session_id": "e11f329f-b395-4702-9b33-449716ea422e",
+ "last_kernel_id": "b6fe1a08-1d4d-40cd-afe6-8352c4e42d25",
+ "last_base_url": "https://bento.edge.x2p.facebook.net/",
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "last_msg_id": "c02547e3-e4c072dc430f066c4d18479a_594",
+ "captumWidgetMessage": [],
+ "outputWidgetContext": [],
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4"
+ },
+ "accelerator": "GPU",
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hBgIy9eYYx35",
+ "originalKey": "4766a371-bf6e-4342-98fb-16dde5255d73",
+ "outputsInitialized": false,
+ "language": "markdown",
+ "showInput": false
+ },
+ "source": [
+ "## **Open Source Installation** (For Reference)\n",
+ "Requirements:\n",
+ "- python >= 3.9\n",
+ "\n",
+ "We highly recommend CUDA when using TorchRec. If using CUDA:\n",
+ "- cuda >= 11.8\n",
+ "\n",
+ "Installing TorchRec will also install [FBGEMM](https://github.com/pytorch/fbgemm), a collection of CUDA kernels and GPU enabled operations to run"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "sFYvP95xaAER",
+ "originalKey": "27d22c43-9299-46ec-94f2-28a880546fe3",
+ "outputsInitialized": true,
+ "language": "python",
+ "customOutput": null,
+ "executionStartTime": 1726000131275,
+ "executionStopTime": 1726000131459,
+ "serverExecutionDuration": 2.2683702409267,
+ "requestMsgId": "27d22c43-9299-46ec-94f2-28a880546fe3"
+ },
+ "source": [
+ "# Install stable versions for best reliability\n",
+ "\n",
+ "!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U\n",
+ "!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121\n",
+ "!pip3 install torchmetrics==1.0.3\n",
+ "!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121"
+ ],
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-4DFtQNDYao1",
+ "originalKey": "07e2a5ae-9ca2-45d7-af10-84d8e09ce91e",
+ "outputsInitialized": false,
+ "language": "markdown",
+ "showInput": false
+ },
+ "source": [
+ "# Intro to TorchRec\n",
+ "\n",
+ "### Embeddings\n",
+ "When building recommendation systems, categorical features typically have massive cardinalities, posts, users, ads, etc.\n",
+ "\n",
+ "In order to represent these entities and model these relationships, **embeddings** are used. In machine learning, **embeddings are a vectors of real numbers in a high-dimensional space used to represent meaning in complex data like words, images, or users**.\n",
+ "\n",
+ "\n",
+ "### Embeddings in RecSys\n",
+ "\n",
+ "Now you might wonder, how are these embeddings generated in the first place? Well, embeddings are represented as individual rows in an **Embedding Table**, also referred to as embedding weights. The reason for this is that embeddings/embedding table weights are trained just like all of the other weights of the model via gradient descent!\n",
+ "\n",
+ "Embedding tables are simply a large matrix for storing embeddings, with two dimensions (B, N), where\n",
+ "* B is the number of embeddings stored by the table\n",
+ "* N is the number of dimensions per embedding (N-dimensional embedding).\n",
+ "\n",
+ "\n",
+ "The inputs to embedding tables represent embedding lookups to retrieve the embedding for a specific index/row. In recommendation systems, such as those used in Meta, unique IDs are not only used for specific users, but also across entites like posts and ads to serve as lookup indices to respective embedding tables!\n",
+ "\n",
+ "Embeddings are trained in RecSys through the following process:\n",
+ "1. **Input/lookup indices are fed into the model, as unique IDs**. IDs are hashed to the total size of the embedding table to prevent issues when the ID > # of rows\n",
+ "2. Embeddings are then retrieved and **pooled, such as taking the sum or mean of the embeddings**. This is required as there can be a variable # of embeddings per example while the model expects consistent shapes.\n",
+ "3. The **embeddings are used in conjunction with the rest of the model to produce a prediction**, such as [Click-Through Rate (CTR)](https://support.google.com/google-ads/answer/2615875?hl=en) for an Ad.\n",
+ "4. The loss is calculated with the prediction and the label for an example, and **all weights of the model are updated through gradient descent and backpropogation, including the embedding weights** that were associated with the example.\n",
+ "\n",
+ "These embeddings are crucial for representing categorical features, such as users, posts, and ads, in order to capture relationships and make good recommendations. Meta AI's [Deep learning recommendation model](https://arxiv.org/abs/1906.00091) (DLRM) paper talks more about the technical details of using embedding tables in RecSys.\n",
+ "\n",
+ "This tutorial will introduce the concept of embeddings, showcase TorchRec specific modules/datatypes, and depict how distributed training works with TorchRec."
+ ]
+ },
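+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For illustration, here is a minimal sketch of steps 1-4 in plain PyTorch. The table size, IDs, and tiny dense head below are made-up toy values for this sketch only:\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "num_rows, dim = 100, 8\n",
+ "table = torch.nn.EmbeddingBag(num_rows, dim, mode='sum')\n",
+ "dense_head = torch.nn.Linear(dim, 1)  # stand-in for the rest of the model\n",
+ "\n",
+ "# 1. Raw IDs are hashed (here, a simple modulo) into the table's bounds\n",
+ "ids = torch.tensor([1007, 42, 987654]) % num_rows\n",
+ "\n",
+ "# 2. Lookup + pooling: all three IDs belong to one example, so offsets = [0]\n",
+ "pooled = table(ids, offsets=torch.tensor([0]))\n",
+ "\n",
+ "# 3. The pooled embedding feeds the rest of the model to produce a prediction\n",
+ "pred = torch.sigmoid(dense_head(pooled))\n",
+ "\n",
+ "# 4. Loss + backpropagation update all weights, including the embedding rows used\n",
+ "loss = torch.nn.functional.binary_cross_entropy(pred, torch.ones_like(pred))\n",
+ "loss.backward()\n",
+ "```"
+ ]
+ },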
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "48b50971-aeab-4754-8cff-986496689f43",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000131464,
+ "executionStopTime": 1726000133971,
+ "serverExecutionDuration": 2349.9959111214,
+ "requestMsgId": "48b50971-aeab-4754-8cff-986496689f43",
+ "customOutput": null,
+ "outputsInitialized": true,
+ "output": {
+ "id": "1534047040582458"
+ },
+ "id": "AbeT4W9xcso9"
+ },
+ "source": [
+ "import torch"
+ ],
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "4b510f99-840d-4986-b635-33c21af48cf4",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "bjuDdEqocso-"
+ },
+ "source": [
+ "## Embeddings in PyTorch\n",
+ "[`torch.nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html): Embedding table where forward pass returns the embeddings themselves as is.\n",
+ "\n",
+ "[`torch.nn.EmbeddingBag`](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html): Embedding table where forward pass returns embeddings that are then pooled, i.e. sum or mean. Otherwise known as **Pooled Embeddings**\n",
+ "\n",
+ "In this section, we will go over a very brief introduction with doing embedding lookups through passing in indices into the table. Check out the links for each for more sophisticated use cases and experiments!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "06ebfce4-bc22-4f5a-97d7-7a8f5d8ac375",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000133982,
+ "executionStopTime": 1726000134201,
+ "serverExecutionDuration": 31.60185739398,
+ "requestMsgId": "06ebfce4-bc22-4f5a-97d7-7a8f5d8ac375",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "933119035309629"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1X5C_Dnccso-",
+ "outputId": "616cc153-67ee-4dd6-b1ab-ee6ff6f44709"
+ },
+ "source": [
+ "num_embeddings, embedding_dim = 10, 4\n",
+ "\n",
+ "# Initialize our embedding table\n",
+ "weights = torch.rand(num_embeddings, embedding_dim)\n",
+ "print(\"Weights:\", weights)"
+ ],
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Weights: tensor([[0.3446, 0.3614, 0.8938, 0.8157],\n",
+ " [0.1830, 0.0326, 0.8241, 0.2995],\n",
+ " [0.7328, 0.0531, 0.9528, 0.0592],\n",
+ " [0.7800, 0.1797, 0.0167, 0.7401],\n",
+ " [0.4837, 0.2052, 0.3360, 0.9656],\n",
+ " [0.7887, 0.3066, 0.0956, 0.3344],\n",
+ " [0.5904, 0.8541, 0.5963, 0.2800],\n",
+ " [0.5751, 0.4341, 0.6218, 0.4101],\n",
+ " [0.6881, 0.5363, 0.4747, 0.2301],\n",
+ " [0.6088, 0.1060, 0.1100, 0.7290]])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "b2f21375-8d36-487f-b0c3-ff8a5df950a4",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000134203,
+ "executionStopTime": 1726000134366,
+ "serverExecutionDuration": 8.956927806139,
+ "requestMsgId": "b2f21375-8d36-487f-b0c3-ff8a5df950a4",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "831419729143778"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "bxszzeGdcso-",
+ "outputId": "88f21deb-c2f1-4894-975e-fdac41436b36"
+ },
+ "source": [
+ "# Pass in pre generated weights just for example, typically weights are randomly initialized\n",
+ "embedding_collection = torch.nn.Embedding(\n",
+ " num_embeddings, embedding_dim, _weight=weights\n",
+ ")\n",
+ "embedding_bag_collection = torch.nn.EmbeddingBag(\n",
+ " num_embeddings, embedding_dim, _weight=weights\n",
+ ")\n",
+ "\n",
+ "# Print out the tables, we should see the same weights as above\n",
+ "print(\"Embedding Collection Table: \", embedding_collection.weight)\n",
+ "print(\"Embedding Bag Collection Table: \", embedding_bag_collection.weight)\n",
+ "\n",
+ "# Lookup rows (ids for embedding ids) from the embedding tables\n",
+ "# 2D tensor with shape (batch_size, ids for each batch)\n",
+ "ids = torch.tensor([[1, 3]])\n",
+ "print(\"Input row IDS: \", ids)"
+ ],
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Embedding Collection Table: Parameter containing:\n",
+ "tensor([[0.3446, 0.3614, 0.8938, 0.8157],\n",
+ " [0.1830, 0.0326, 0.8241, 0.2995],\n",
+ " [0.7328, 0.0531, 0.9528, 0.0592],\n",
+ " [0.7800, 0.1797, 0.0167, 0.7401],\n",
+ " [0.4837, 0.2052, 0.3360, 0.9656],\n",
+ " [0.7887, 0.3066, 0.0956, 0.3344],\n",
+ " [0.5904, 0.8541, 0.5963, 0.2800],\n",
+ " [0.5751, 0.4341, 0.6218, 0.4101],\n",
+ " [0.6881, 0.5363, 0.4747, 0.2301],\n",
+ " [0.6088, 0.1060, 0.1100, 0.7290]], requires_grad=True)\n",
+ "Embedding Bag Collection Table: Parameter containing:\n",
+ "tensor([[0.3446, 0.3614, 0.8938, 0.8157],\n",
+ " [0.1830, 0.0326, 0.8241, 0.2995],\n",
+ " [0.7328, 0.0531, 0.9528, 0.0592],\n",
+ " [0.7800, 0.1797, 0.0167, 0.7401],\n",
+ " [0.4837, 0.2052, 0.3360, 0.9656],\n",
+ " [0.7887, 0.3066, 0.0956, 0.3344],\n",
+ " [0.5904, 0.8541, 0.5963, 0.2800],\n",
+ " [0.5751, 0.4341, 0.6218, 0.4101],\n",
+ " [0.6881, 0.5363, 0.4747, 0.2301],\n",
+ " [0.6088, 0.1060, 0.1100, 0.7290]], requires_grad=True)\n",
+ "Input row IDS: tensor([[1, 3]])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "cb5c5906-e9a6-4315-b860-b263e08989be",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000134369,
+ "executionStopTime": 1726000134545,
+ "serverExecutionDuration": 5.9817284345627,
+ "requestMsgId": "cb5c5906-e9a6-4315-b860-b263e08989be",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "2201664893536578"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "xkedJeTOcso_",
+ "outputId": "46215f2b-03ad-421b-f873-78b2be0df4d4"
+ },
+ "source": [
+ "embeddings = embedding_collection(ids)\n",
+ "\n",
+ "# Print out the embedding lookups\n",
+ "# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above\n",
+ "print(\"Embedding Collection Results: \")\n",
+ "print(embeddings)\n",
+ "print(\"Shape: \", embeddings.shape)"
+ ],
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Embedding Collection Results: \n",
+ "tensor([[[0.1830, 0.0326, 0.8241, 0.2995],\n",
+ " [0.7800, 0.1797, 0.0167, 0.7401]]], grad_fn=)\n",
+ "Shape: torch.Size([1, 2, 4])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "a8e90b32-7c30-41f2-a5b9-bedf2b196e7f",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000134547,
+ "executionStopTime": 1726000134718,
+ "serverExecutionDuration": 7.8675262629986,
+ "requestMsgId": "a8e90b32-7c30-41f2-a5b9-bedf2b196e7f",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "6449977515126116"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "PmtJkxLccso_",
+ "outputId": "c33109b4-205e-492d-e0ca-94d2887ec6e7"
+ },
+ "source": [
+ "# nn.EmbeddingBag default pooling is mean, so should be mean of batch dimension of values above\n",
+ "pooled_embeddings = embedding_bag_collection(ids)\n",
+ "\n",
+ "print(\"Embedding Bag Collection Results: \")\n",
+ "print(pooled_embeddings)\n",
+ "print(\"Shape: \", pooled_embeddings.shape)\n",
+ "\n",
+ "# nn.EmbeddingBag is the same as nn.Embedding but just with pooling (mean, sum, etc.)\n",
+ "# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection\n",
+ "print(\"Mean: \", torch.mean(embedding_collection(ids), dim=1))"
+ ],
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Embedding Bag Collection Results: \n",
+ "tensor([[0.4815, 0.1062, 0.4204, 0.5198]], grad_fn=)\n",
+ "Shape: torch.Size([1, 4])\n",
+ "Mean: tensor([[0.4815, 0.1062, 0.4204, 0.5198]], grad_fn=)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "4643305e-2770-40cf-afc6-e64cd3f51063",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "SuCV1cJ8cso_"
+ },
+ "source": [
+ "Congratulations! Now you have a basic understanding on how to use embedding tables --- one of the foundations of modern recommendation systems! These tables represent entities and their relationships. For example, the relationship between a given user and the pages & posts they have liked."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "7dfcffeb-c7c0-4d74-9dba-569c1d882898",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "QIuAYSZ5cso_"
+ },
+ "source": [
+ "# TorchRec\n",
+ "\n",
+ "Now you know how to use embedding tables, one of the foundations of modern recommendation systems! These tables represent entities and relationships, such as users, pages, posts, etc. Given that these entities are always increasing, a **hash** function is typically applied to make sure the ids are within the bounds of a certain embedding table. However, in order to represent a vast amount of entities and reduce hash collisions, these tables can become quite massive (think about # of ads for example). In fact, these tables can become so massive that they won't be able to fit on 1 GPU, even with 80G of memory!\n",
+ "\n",
+ "In order to train models with massive embedding tables, sharding these tables across GPUs is required, which then introduces a whole new set of problems/opportunities in parallelism and optimization. Luckily, we have the TorchRec library that has encountered, consolidated, and addressed many of these concerns. TorchRec serves as a **library that provides primitives for large scale distributed embeddings**.\n",
+ "\n",
+ "From here on out, we will explore the major features of the TorchRec library. We will start with torch.nn.Embedding and will extend that to custom TorchRec modules, explore distributed training environment with generating a sharding plan for embeddings, look at inherent TorchRec optimizations, and extend the model to be ready for inference in C++. Below is a quick outline of what the journey will consist of - buckle in!\n",
+ "\n",
+ "1. TorchRec Modules and DataTypes\n",
+ "2. Distributed Training, Sharding, and Optimizations\n",
+ "3. Inference\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "8395ed9c-8336-4686-8e73-cb815b808f2a",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000134724,
+ "executionStopTime": 1726000139238,
+ "serverExecutionDuration": 4317.9145939648,
+ "requestMsgId": "8395ed9c-8336-4686-8e73-cb815b808f2a",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "id": "5vzmNV0IcspA"
+ },
+ "source": [
+ "import torchrec"
+ ],
+ "execution_count": 7,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "0c95b385-e07a-43e1-aaeb-31f66deb5b35",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "42PwMZnNcspA"
+ },
+ "source": [
+ "## TorchRec Modules and Datatypes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZdSUWBRxoP8R",
+ "originalKey": "309c4d38-8f19-46d9-a8bb-7d3d1c166e84",
+ "outputsInitialized": false,
+ "language": "markdown",
+ "showInput": false
+ },
+ "source": [
+ "### From EmbeddingBag to EmbeddingBagCollection\n",
+ "We have already explored [`torch.nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) and [`torch.nn.EmbeddingBag`](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html).\n",
+ "\n",
+ "TorchRec extends these modules by creating collections of embeddings, in other words modules that can have multiple embedding tables, with [`EmbeddingCollection`](https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingCollection) and [`EmbeddingBagCollection`](https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection). We will use `EmbeddingBagCollection` to represent a group of EmbeddingBags.\n",
+ "\n",
+ "Here, we create an EmbeddingBagCollection (EBC) with two embedding bags, 1 representing **products** and 1 representing **users**. Each table, `product_table` and `user_table`, is represented by 64 dimension embedding of size 4096."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Iz_GZDp_oQ19",
+ "originalKey": "219c4ee9-c4f1-43ff-9d1c-b15b16a1dc8e",
+ "outputsInitialized": true,
+ "language": "python",
+ "customOutput": null,
+ "executionStartTime": 1726000139247,
+ "executionStopTime": 1726000139433,
+ "serverExecutionDuration": 13.643965125084,
+ "requestMsgId": "219c4ee9-c4f1-43ff-9d1c-b15b16a1dc8e",
+ "output": {
+ "id": "1615870128957785"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "cec92f4f-d9eb-464e-8e22-26fd910bf8c1"
+ },
+ "source": [
+ "ebc = torchrec.EmbeddingBagCollection(\n",
+ " device=\"cpu\",\n",
+ " tables=[\n",
+ " torchrec.EmbeddingBagConfig(\n",
+ " name=\"product_table\",\n",
+ " embedding_dim=64,\n",
+ " num_embeddings=4096,\n",
+ " feature_names=[\"product\"],\n",
+ " pooling=torchrec.PoolingType.SUM,\n",
+ " ),\n",
+ " torchrec.EmbeddingBagConfig(\n",
+ " name=\"user_table\",\n",
+ " embedding_dim=64,\n",
+ " num_embeddings=4096,\n",
+ " feature_names=[\"user\"],\n",
+ " pooling=torchrec.PoolingType.SUM,\n",
+ " )\n",
+ " ]\n",
+ ")\n",
+ "print(ebc.embedding_bags)"
+ ],
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "ModuleDict(\n",
+ " (product_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ " (user_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ ")\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "c587a298-4d38-4a69-89a2-5d5c4a26cc2c",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "xjcA0Di1cspA"
+ },
+ "source": [
+ "Let’s inspect the forward method for EmbeddingBagcollection and the module’s inputs and outputs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "c9d2717b-b753-4e0b-97bd-1596123d081d",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000139437,
+ "executionStopTime": 1726000139616,
+ "serverExecutionDuration": 6.011176854372,
+ "requestMsgId": "c9d2717b-b753-4e0b-97bd-1596123d081d",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "398959426640405"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "UuIrEWupcspA",
+ "outputId": "300c86c2-82c1-4657-fa1a-6a319eb40177"
+ },
+ "source": [
+ "import inspect\n",
+ "\n",
+ "# Let's look at the EmbeddingBagCollection forward method\n",
+ "# What is a KeyedJaggedTensor and KeyedTensor?\n",
+ "print(inspect.getsource(ebc.forward))"
+ ],
+ "execution_count": 9,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ " def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " features (KeyedJaggedTensor): KJT of form [F X B X L].\n",
+ "\n",
+ " Returns:\n",
+ " KeyedTensor\n",
+ " \"\"\"\n",
+ " if is_non_strict_exporting() and not torch.jit.is_scripting():\n",
+ " return self._non_strict_exporting_forward(features)\n",
+ " flat_feature_names: List[str] = []\n",
+ " for names in self._feature_names:\n",
+ " flat_feature_names.extend(names)\n",
+ " inverse_indices = reorder_inverse_indices(\n",
+ " inverse_indices=features.inverse_indices_or_none(),\n",
+ " feature_names=flat_feature_names,\n",
+ " )\n",
+ " pooled_embeddings: List[torch.Tensor] = []\n",
+ " feature_dict = features.to_dict()\n",
+ " for i, embedding_bag in enumerate(self.embedding_bags.values()):\n",
+ " for feature_name in self._feature_names[i]:\n",
+ " f = feature_dict[feature_name]\n",
+ " res = embedding_bag(\n",
+ " input=f.values(),\n",
+ " offsets=f.offsets(),\n",
+ " per_sample_weights=f.weights() if self._is_weighted else None,\n",
+ " ).float()\n",
+ " pooled_embeddings.append(res)\n",
+ " return KeyedTensor(\n",
+ " keys=self._embedding_names,\n",
+ " values=process_pooled_embeddings(\n",
+ " pooled_embeddings=pooled_embeddings,\n",
+ " inverse_indices=inverse_indices,\n",
+ " ),\n",
+ " length_per_key=self._lengths_per_embedding,\n",
+ " )\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "d6b9bfc2-544d-499f-ad61-d7471b819f8a",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "C_UAtHsMcspA"
+ },
+ "source": [
+ "### TorchRec Input/Output Data Types\n",
+ "TorchRec has distinct data types for input and output of its modules: `JaggedTensor`, `KeyedJaggedTensor`, and `KeyedTensor`. Now you might ask, why create new datatypes to represent sparse features? To answer that question, we must understand how sparse features are represented in code.\n",
+ "\n",
+ "Sparse features are otherwise known as `id_list_feature` and `id_score_list_feature`, and are the **IDs** that will be used as indices to an embedding table to retrieve the embedding for that ID. To give a very simple example, imagine a single sparse feature being Ads that a user interacted with. The input itself would be a set of Ad IDs that a user interacted with, and the embeddings retrieved would be a semantic representation of those Ads. The tricky part of representing these features in code is that in each input example, **the number of IDs is variable**. 1 day a user might have interacted with only 1 ad while the next day they interact with 3.\n",
+ "\n",
+ "A simple representation is shown below, where we have a `lengths` tensor denoting how many indices are in an example for a batch and a `values` tensor containing the indices themselves.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "13225ead-a798-4db2-8de6-1c13a758d676",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000139620,
+ "executionStopTime": 1726000139790,
+ "serverExecutionDuration": 3.692839294672,
+ "requestMsgId": "13225ead-a798-4db2-8de6-1c13a758d676",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "id": "RB77aL08cspA"
+ },
+ "source": [
+ "# Batch Size 2\n",
+ "# 1 ID in example 1, 2 IDs in example 2\n",
+ "id_list_feature_lengths = torch.tensor([1, 2])\n",
+ "\n",
+ "# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2\n",
+ "id_list_feature_values = torch.tensor([5, 7, 1])"
+ ],
+ "execution_count": 10,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "65d31fca-7b7f-4c0f-9ca2-56e07243a5c0",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "aKmgGqdNcspA"
+ },
+ "source": [
+ "Let’s look at the offsets as well as what is contained in each Batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "9510cebd-1875-461e-9243-53928632abfa",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000139794,
+ "executionStopTime": 1726000139966,
+ "serverExecutionDuration": 6.6289491951466,
+ "requestMsgId": "9510cebd-1875-461e-9243-53928632abfa",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "869913611744322"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "t5T5S8_mcspB",
+ "outputId": "87e78b11-7497-4387-c0bc-b4d277ba8ab3"
+ },
+ "source": [
+ "# Lengths can be converted to offsets for easy indexing of values\n",
+ "id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)\n",
+ "\n",
+ "print(\"Offsets: \", id_list_feature_offsets)\n",
+ "print(\"First Batch: \", id_list_feature_values[: id_list_feature_offsets[0]])\n",
+ "print(\n",
+ " \"Second Batch: \",\n",
+ " id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],\n",
+ ")"
+ ],
+ "execution_count": 11,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Offsets: tensor([1, 3])\n",
+ "First Batch: tensor([5])\n",
+ "Second Batch: tensor([7, 1])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "4bc3fac5-16b9-4f63-b841-9b26ee0ccfc0",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000139968,
+ "executionStopTime": 1726000140161,
+ "serverExecutionDuration": 7.3191449046135,
+ "requestMsgId": "4bc3fac5-16b9-4f63-b841-9b26ee0ccfc0",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1254783359215069"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "2OOK2BBecspB",
+ "outputId": "b27c1547-fbb7-47c4-efb6-48aff3300d1a"
+ },
+ "source": [
+ "from torchrec import JaggedTensor\n",
+ "\n",
+ "# JaggedTensor is just a wrapper around lengths/offsets and values tensors!\n",
+ "jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)\n",
+ "\n",
+ "# Automatically compute offsets from lengths\n",
+ "print(\"Offsets: \", jt.offsets())\n",
+ "\n",
+ "# Convert to list of values\n",
+ "print(\"List of Values: \", jt.to_dense())\n",
+ "\n",
+ "# __str__ representation\n",
+ "print(jt)"
+ ],
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Offsets: tensor([0, 1, 3])\n",
+ "List of Values: [tensor([5]), tensor([7, 1])]\n",
+ "JaggedTensor({\n",
+ " [[5], [7, 1]]\n",
+ "})\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "ad069058-2329-4ab9-bee8-60775ead4c33",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000140165,
+ "executionStopTime": 1726000140355,
+ "serverExecutionDuration": 10.361641645432,
+ "requestMsgId": "ad069058-2329-4ab9-bee8-60775ead4c33",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "530006499497328"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "fs10Fxu2cspB",
+ "outputId": "3bc754c2-ac30-4c01-b0c8-7f98eedb9c52"
+ },
+ "source": [
+ "from torchrec import KeyedJaggedTensor\n",
+ "\n",
+ "# JaggedTensor represents IDs for 1 feature, but we have multiple features in an EmbeddingBagCollection\n",
+ "# That's where KeyedJaggedTensor comes in! KeyedJaggedTensor is just multiple JaggedTensors for multiple id_list_feature_offsets\n",
+ "# From before, we have our two features \"product\" and \"user\". Let's create JaggedTensors for both!\n",
+ "\n",
+ "product_jt = JaggedTensor(\n",
+ " values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])\n",
+ ")\n",
+ "user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))\n",
+ "\n",
+ "# Q1: How many batches are there, and which values are in the first batch for product_jt and user_jt?\n",
+ "kjt = KeyedJaggedTensor.from_jt_dict({\"product\": product_jt, \"user\": user_jt})\n",
+ "\n",
+ "# Look at our feature keys for the KeyedJaggedTensor\n",
+ "print(\"Keys: \", kjt.keys())\n",
+ "\n",
+ "# Look at the overall lengths for the KeyedJaggedTensor\n",
+ "print(\"Lengths: \", kjt.lengths())\n",
+ "\n",
+ "# Look at all values for KeyedJaggedTensor\n",
+ "print(\"Values: \", kjt.values())\n",
+ "\n",
+ "# Can convert KJT to dictionary representation\n",
+ "print(\"to_dict: \", kjt.to_dict())\n",
+ "\n",
+ "# KeyedJaggedTensor(KJT) string representation\n",
+ "print(kjt)\n",
+ "\n",
+ "# Q2: What are the offsets for the KeyedJaggedTensor?"
+ ],
+ "execution_count": 13,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Keys: ['product', 'user']\n",
+ "Lengths: tensor([3, 1, 2, 2])\n",
+ "Values: tensor([1, 2, 1, 5, 2, 3, 4, 1])\n",
+ "to_dict: {'product': , 'user': }\n",
+ "KeyedJaggedTensor({\n",
+ " \"product\": [[1, 2, 1], [5]],\n",
+ " \"user\": [[2, 3], [4, 1]]\n",
+ "})\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "b13fdf10-45a7-4e57-b50e-cc18547a715b",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000140357,
+ "executionStopTime": 1726000140549,
+ "serverExecutionDuration": 17.695877701044,
+ "requestMsgId": "b13fdf10-45a7-4e57-b50e-cc18547a715b",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "496557126663787"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "JeLwyCNRcspB",
+ "outputId": "0723906d-0aba-4d48-e9a5-7ac618d711c5"
+ },
+ "source": [
+ "# Now we can run a forward pass on our ebc from before\n",
+ "result = ebc(kjt)\n",
+ "result"
+ ],
+ "execution_count": 14,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "execution_count": 14
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "57a01464-de39-4bfb-8355-83cd97e519c0",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000140552,
+ "executionStopTime": 1726000140732,
+ "serverExecutionDuration": 6.0368701815605,
+ "requestMsgId": "57a01464-de39-4bfb-8355-83cd97e519c0",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1457290878317732"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "R2K4v2vqcspB",
+ "outputId": "c507e47e-32bc-4440-8c8a-2b3ea334467c"
+ },
+ "source": [
+ "# Result is a KeyedTensor, which contains a list of the feature names and the embedding results\n",
+ "print(result.keys())\n",
+ "\n",
+ "# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined\n",
+ "# 128 for dimension of embedding. If you look at where we initialized the EmbeddingBagCollection, we have two tables \"product\" and \"user\" of dimension 64 each\n",
+ "# meaning emebddings for both features are of size 64. 64 + 64 = 128\n",
+ "print(result.values().shape)\n",
+ "\n",
+ "# Nice to_dict method to determine the embeddings that belong to each feature\n",
+ "result_dict = result.to_dict()\n",
+ "for key, embedding in result_dict.items():\n",
+ " print(key, embedding.shape)"
+ ],
+ "execution_count": 15,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "['product', 'user']\n",
+ "torch.Size([2, 128])\n",
+ "product torch.Size([2, 64])\n",
+ "user torch.Size([2, 64])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "d0fc8635-dac3-444b-978b-421b5d77b70c",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "EE-YYDv7cspB"
+ },
+ "source": [
+ "Congrats! Give yourself a pat on the back for making it this far."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "70816a78-7671-411c-814f-d2c98c3a912c",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "djLHn0CIcspB"
+ },
+ "source": [
+ "## Distributed Training and Sharding\n",
+ "Now that we have a grasp on TorchRec modules and data types, it's time to take it to the next level.\n",
+ "\n",
+ "Remember, TorchRec's main purpose is to provide primitives for distributed embeddings. So far, we've only worked with embedding tables on 1 device. This has been possible given how small the embedding tables have been, but in a production setting this isn't generally the case. Embedding tables often get massive, where 1 table can't fit on a single GPU, creating the requirement for multiple devices and a distributed environment\n",
+ "\n",
+ "In this section, we will explore setting up a distributed environment, exactly how actual production training is done, and explore sharding embedding tables, all with Torchrec.\n",
+ "\n",
+ "**This section will also only use 1 gpu, though it will be treated in a distributed fashion. This is only a limitation for training, as training has a process per gpu. Inference does not run into this requirement**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "4-v17rxkopQw",
+ "originalKey": "df0d09f0-5e8e-46bf-a086-dd991c8be0b4",
+ "outputsInitialized": true,
+ "language": "python",
+ "customOutput": null,
+ "executionStartTime": 1726000140740,
+ "executionStopTime": 1726000142256,
+ "serverExecutionDuration": 1350.0418178737,
+ "requestMsgId": "df0d09f0-5e8e-46bf-a086-dd991c8be0b4",
+ "output": {
+ "id": "1195358511578142"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "efa5689f-b794-41a6-8c06-86856e1e698a"
+ },
+ "source": [
+ "# Here we set up our torch distributed environment\n",
+ "# WARNING: You can only call this cell once, calling it again will cause an error\n",
+ "# as you can only initialize the process group once\n",
+ "\n",
+ "import os\n",
+ "\n",
+ "import torch.distributed as dist\n",
+ "\n",
+ "# Set up environment variables for distributed training\n",
+ "# RANK is which GPU we are on, default 0\n",
+ "os.environ[\"RANK\"] = \"0\"\n",
+ "# How many devices in our \"world\", since Bento can only handle 1 process, 1 GPU\n",
+ "os.environ[\"WORLD_SIZE\"] = \"1\"\n",
+ "# Localhost as we are training locally\n",
+ "os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
+ "# Port for distributed training\n",
+ "os.environ[\"MASTER_PORT\"] = \"29500\"\n",
+ "\n",
+ "# Note - you will need a V100 or A100 to run tutorial as!\n",
+ "# nccl backend is for GPUs, gloo is for CPUs\n",
+ "dist.init_process_group(backend=\"gloo\")\n",
+ "\n",
+ "print(f\"Distributed environment initialized: {dist}\")"
+ ],
+ "execution_count": 16,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Distributed environment initialized: \n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "480e69dc-3e9d-4e86-b73c-950e18afb0f5",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "hQOjNci3cspB"
+ },
+ "source": [
+ "### Distributed Embeddings\n",
+ "\n",
+ "We have already worked with the main TorchRec module: `EmbeddingBagCollection`. We have examined how it works along with how data is represented in TorchRec. However, we have not yet explored one of the main parts of TorchRec, which is **distributed embeddings**.\n",
+ "\n",
+ "GPUs are the most popular choice for ML workloads by far today, as they are able to do magnitudes more floating point operations/s ([FLOPs](https://en.wikipedia.org/wiki/FLOPS)) than CPU. However, GPUs come with the limitation of scarce fast memory (HBM which is analgous to RAM for CPU), typically ~10s of GBs.\n",
+ "\n",
+ "A RecSys model can contain embedding tables that far exceed the memory limit for 1 GPU, hence the need for distribution of the embedding tables across multiple GPUs, otherwise known as **model parallel**. On the other hand, **data parallel** is where the entire model is replicated on each GPU, which each GPU taking in a distinct batch of data for training, syncing gradients on the backwards pass.\n",
+ "\n",
+ "Parts of the model that **require less compute but more memory (embeddings) are distributed with model parallel** while parts that **require more compute and less memory (dense layers, MLP, etc.) are distributed with data parallel**.\n",
+ "\n",
+ "\n",
+ "### Sharding\n",
+ "In order to distribute an embedding table, we split up the embedding table into parts and place those parts onto different devices, also known as “sharding”.\n",
+ "\n",
+ "There are many ways to shard embedding tables. The most common ways are:\n",
+ "* Table-Wise: the table is placed entirely onto one device\n",
+ "* Column-Wise: columns of embedding tables are sharded\n",
+ "* Row-Wise: rows of embedding tables are sharded\n",
+ "\n",
+ "\n",
+ "### Sharded Modules\n",
+ "While all of this seems like a lot to deal with and implement, you're in luck. **TorchRec provides all the primitives for easy distributed training/inference**! In fact, TorchRec modules have two corresponding classes for working with any TorchRec module in a distributed environment:\n",
+ "1. The module sharder: This class exposes a `shard` API that handles sharding a TorchRec Module, producing a sharded module.\n",
+ " * For `EmbeddingBagCollection`, the sharder is [`EmbeddingBagCollectionSharder`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder)\n",
+ "2. Sharded module: This class is a sharded variant of a TorchRec module. It has the same input/output as a the regular TorchRec module, but much more optimized and works in a distributed environment.\n",
+ " * For `EmbeddingBagCollection`, the sharded variant is [`ShardedEmbeddingBagCollection`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection)\n",
+ "\n",
+ "Every TorchRec module has an unsharded and sharded variant.\n",
+ "* The unsharded version is meant to be prototyped and experimented with\n",
+ "* The sharded version is meant to be used in a distributed environment for distributed training/inference.\n",
+ "\n",
+ "The sharded versions of TorchRec modules, for example EmbeddingBagCollection, will handle everything that is needed for Model Parallelism, such as communication between GPUs for distributing embeddings to the correct GPUs.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "eb2a064d-0b67-4cba-a199-c99573c7e6cd",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000142261,
+ "executionStopTime": 1726000142430,
+ "serverExecutionDuration": 8.3460621535778,
+ "requestMsgId": "eb2a064d-0b67-4cba-a199-c99573c7e6cd",
+ "customOutput": null,
+ "outputsInitialized": true,
+ "output": {
+ "id": "791089056311464"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "FX65VcQ6cspB",
+ "outputId": "1aa3bc52-569b-46fb-8a94-cd1873e987ca"
+ },
+ "source": [
+ "# Refresher of our EmbeddingBagCollection module\n",
+ "ebc"
+ ],
+ "execution_count": 17,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "EmbeddingBagCollection(\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ " (user_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 17
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "1442636d-8617-4785-b40c-8544374253b6",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "outputsInitialized": true,
+ "executionStartTime": 1726000142433,
+ "executionStopTime": 1726000142681,
+ "serverExecutionDuration": 4.4135116040707,
+ "requestMsgId": "1442636d-8617-4785-b40c-8544374253b6",
+ "customOutput": null,
+ "output": {
+ "id": "502189589096046"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1hSzTg4pcspC",
+ "outputId": "d7d86592-4fdc-4d0b-f2ba-a40a80af1fcf"
+ },
+ "source": [
+ "from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder\n",
+ "from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology\n",
+ "from torchrec.distributed.types import ShardingEnv\n",
+ "\n",
+ "# Corresponding sharder for EmbeddingBagCollection module\n",
+ "sharder = EmbeddingBagCollectionSharder()\n",
+ "\n",
+ "# ProcessGroup from torch.distributed initialized 2 cells above\n",
+ "pg = dist.GroupMember.WORLD\n",
+ "assert pg is not None, \"Process group is not initialized\"\n",
+ "\n",
+ "print(f\"Process Group: {pg}\")"
+ ],
+ "execution_count": 18,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Process Group: \n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "29cc17eb-9e2f-480b-aed2-60b15024fbf7",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "qU7A980qcspC"
+ },
+ "source": [
+ "### Planner\n",
+ "\n",
+ "Before we can show how sharding works, we must know about the **planner**, which helps us determine the best sharding configuration.\n",
+ "\n",
+ "Given a number of embedding tables and a number of ranks, there are many different sharding configurations that are possible. For example, given 2 embedding tables and 2 GPUs, you can:\n",
+ "* Place 1 table on each GPU\n",
+ "* Place both tables on a single GPU and no tables on the other\n",
+ "* Place certain rows/columns on each GPU\n",
+ "\n",
+ "Given all of these possibilities, we typically want a sharding configuration that is optimal for performance.\n",
+ "\n",
+ "That is where the planner comes in. The planner is able to determine given the # of embedding tables and the # of GPUs, what is the optimal configuration. Turns out, this is incredibly difficult to do manually, with tons of factors that engineers have to consider to ensure an optimal sharding plan. Luckily, TorchRec provides an auto planner when the planner is used. The TorchRec planner:\n",
+ "* assesses memory constraints of hardware,\n",
+ "* estimates compute based on memory fetches as embedding lookups,\n",
+ "* addresses data specific factors,\n",
+ "* considers other hardware specifics like bandwidth to generate an optimal sharding plan.\n",
+ "\n",
+ "In order to take into consideration all these variables, The TorchRec planner can take in [various amounts of data for embedding tables, constraints, hardware information, and topology](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/planner/planners.py#L147-L155) to aid in generating the optimal sharding plan for a model, which is routinely provided across stacks\n",
+ "\n",
+ "\n",
+ "To learn more about sharding, see our [sharding tutorial](https://pytorch.org/tutorials/advanced/sharding.html)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "64936203-2e59-4bc3-8d76-1b652b7891c2",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000142687,
+ "executionStopTime": 1726000143033,
+ "serverExecutionDuration": 145.92137932777,
+ "requestMsgId": "64936203-2e59-4bc3-8d76-1b652b7891c2",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1247084956198777"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "PQeXnuAGcspC",
+ "outputId": "e6d366b7-f064-4e38-a477-fc2e0d4b9736"
+ },
+ "source": [
+ "# In our case, 1 GPU and compute on CUDA device\n",
+ "planner = EmbeddingShardingPlanner(\n",
+ " topology=Topology(\n",
+ " world_size=1,\n",
+ " compute_device=\"cuda\",\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "# Run planner to get plan for sharding\n",
+ "plan = planner.collective_plan(ebc, [sharder], pg)\n",
+ "\n",
+ "print(f\"Sharding Plan generated: {plan}\")"
+ ],
+ "execution_count": 19,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Sharding Plan generated: module: \n",
+ "\n",
+ " param | sharding type | compute kernel | ranks\n",
+ "------------- | ------------- | -------------- | -----\n",
+ "product_table | table_wise | fused | [0] \n",
+ "user_table | table_wise | fused | [0] \n",
+ "\n",
+ " param | shard offsets | shard sizes | placement \n",
+ "------------- | ------------- | ----------- | -------------\n",
+ "product_table | [0, 0] | [4096, 64] | rank:0/cuda:0\n",
+ "user_table | [0, 0] | [4096, 64] | rank:0/cuda:0\n"
+ ]
+ }
+ ]
+ },
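+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Side note: the run above used the planner's defaults. The planner can also be steered with per-table constraints. The snippet below is only a rough sketch of that idea; the table name and sharding type are just examples, and `ParameterConstraints` has many more options:\n",
+ "\n",
+ "```python\n",
+ "from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology\n",
+ "from torchrec.distributed.planner.types import ParameterConstraints\n",
+ "\n",
+ "# Example constraint: only consider table-wise sharding for product_table\n",
+ "constrained_planner = EmbeddingShardingPlanner(\n",
+ "    topology=Topology(world_size=1, compute_device=\"cuda\"),\n",
+ "    constraints={\"product_table\": ParameterConstraints(sharding_types=[\"table_wise\"])},\n",
+ ")\n",
+ "```"
+ ]
+ },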
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "bbbbbf60-5691-4357-9943-4d7f8b2b1d5c",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "2TTLj_0PcspC"
+ },
+ "source": [
+ "### Planner Result\n",
+ "As you can see, when running the planner there is quite a bit of output above. We can see a ton of stats being calculated along with where our tables end up being placed.\n",
+ "\n",
+ "The result of running the planner is a static plan, which can be reused for sharding! This allows sharding to be static for production models instead of determining a new sharding plan everytime. Below, we use the sharding plan to finally generate our `ShardedEmbeddingBagCollection.`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "533be12d-a3c5-4c9e-9351-7770251c8fa5",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000143037,
+ "executionStopTime": 1726000143259,
+ "serverExecutionDuration": 5.2368640899658,
+ "requestMsgId": "533be12d-a3c5-4c9e-9351-7770251c8fa5",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "901470115170971"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "JIci5Gz6cspC",
+ "outputId": "32d3cb73-80b6-4646-9c1d-3cff5a498f86"
+ },
+ "source": [
+ "# The static plan that was generated\n",
+ "plan"
+ ],
+ "execution_count": 20,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "ShardingPlan(plan={'': {'product_table': ParameterSharding(sharding_type='table_wise', compute_kernel='fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)]), cache_params=None, enforce_hbm=None, stochastic_rounding=None, bounds_check_mode=None, output_dtype=None), 'user_table': ParameterSharding(sharding_type='table_wise', compute_kernel='fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)]), cache_params=None, enforce_hbm=None, stochastic_rounding=None, bounds_check_mode=None, output_dtype=None)}})"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 20
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "5dcbfda0-0abb-4a51-ba8f-a6a4023f0e2f",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000143262,
+ "executionStopTime": 1726000147680,
+ "serverExecutionDuration": 4229.5375689864,
+ "requestMsgId": "5dcbfda0-0abb-4a51-ba8f-a6a4023f0e2f",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1231077634880712"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "2__Do2tqcspC",
+ "outputId": "7c19830a-9a11-4fbe-ddf3-8c3cb8a6c3b3"
+ },
+ "source": [
+ "env = ShardingEnv.from_process_group(pg)\n",
+ "\n",
+ "# Shard the EmbeddingBagCollection module using the EmbeddingBagCollectionSharder\n",
+ "sharded_ebc = sharder.shard(ebc, plan.plan[\"\"], env, torch.device(\"cuda\"))\n",
+ "\n",
+ "print(f\"Sharded EBC Module: {sharded_ebc}\")"
+ ],
+ "execution_count": 21,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Sharded EBC Module: ShardedEmbeddingBagCollection(\n",
+ " (lookups): \n",
+ " GroupedPooledEmbeddingsLookup(\n",
+ " (_emb_modules): ModuleList(\n",
+ " (0): BatchedFusedEmbeddingBag(\n",
+ " (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (_output_dists): \n",
+ " TwPooledEmbeddingDist()\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): Module()\n",
+ " (user_table): Module()\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [],
+ "metadata": {
+ "id": "ErXXbYzJmVzI"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "3ba44a6d-a6f7-4da2-83a6-e8ac974c64ac",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "QBLpkKYIcspC"
+ },
+ "source": [
+ "#### Awaitable\n",
+ "Remember that TorchRec is a highly optimized library for distributed embeddings. A concept that TorchRec introduces to enable higher performance for training on GPU is a [`LazyAwaitable`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.types.LazyAwaitable). You will see `LazyAwaitable` types as outputs of various sharded TorchRec modules. All a `LazyAwaitable` does is delay calculating some result as long as possible, and it does it by acting like an async type."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "e450dc00-bd30-4bc2-8c71-4c01979b0948",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000147687,
+ "executionStopTime": 1726000147874,
+ "serverExecutionDuration": 9.098757058382,
+ "requestMsgId": "e450dc00-bd30-4bc2-8c71-4c01979b0948",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1236006950908310"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "rwYzKwyNcspC",
+ "outputId": "a8979bb2-6fac-4db6-b997-c31962014e24"
+ },
+ "source": [
+ "from typing import List\n",
+ "\n",
+ "from torchrec.distributed.types import LazyAwaitable\n",
+ "\n",
+ "\n",
+ "# Demonstrate a LazyAwaitable type\n",
+ "class ExampleAwaitable(LazyAwaitable[torch.Tensor]):\n",
+ " def __init__(self, size: List[int]) -> None:\n",
+ " super().__init__()\n",
+ " self._size = size\n",
+ "\n",
+ " def _wait_impl(self) -> torch.Tensor:\n",
+ " return torch.ones(self._size)\n",
+ "\n",
+ "\n",
+ "awaitable = ExampleAwaitable([3, 2])\n",
+ "awaitable.wait()"
+ ],
+ "execution_count": 22,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "tensor([[1., 1.],\n",
+ " [1., 1.],\n",
+ " [1., 1.]])"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 22
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "c958c791-a62c-423a-9a95-1e6ae4e8fbd9",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000147878,
+ "executionStopTime": 1726000154861,
+ "serverExecutionDuration": 6806.3651248813,
+ "requestMsgId": "c958c791-a62c-423a-9a95-1e6ae4e8fbd9",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1255627342282843"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "cs41RfzGcspC",
+ "outputId": "ba7cfdbe-59c4-48c2-a767-d6f5f5cbb915"
+ },
+ "source": [
+ "kjt = kjt.to(\"cuda\")\n",
+ "output = sharded_ebc(kjt)\n",
+ "# The output of our sharded EmbeddingBagCollection module is a an Awaitable?\n",
+ "print(output)"
+ ],
+ "execution_count": 23,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "6f2957f2-2e7e-47e4-9237-f0b6c8b0da94",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000154865,
+ "executionStopTime": 1726000155069,
+ "serverExecutionDuration": 6.0432851314545,
+ "requestMsgId": "6f2957f2-2e7e-47e4-9237-f0b6c8b0da94",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1057638405967561"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "_1sdt75rcspG",
+ "outputId": "365cf6cb-adae-4b41-f68d-1a7acf5fe332"
+ },
+ "source": [
+ "kt = output.wait()\n",
+ "# Now we have out KeyedTensor after calling .wait()\n",
+ "# If you are confused as to why we have a KeyedTensor output,\n",
+ "# give yourself a refresher on the unsharded EmbeddingBagCollection module\n",
+ "print(type(kt))\n",
+ "\n",
+ "print(kt.keys())\n",
+ "\n",
+ "print(kt.values().shape)\n",
+ "\n",
+ "# Same output format as unsharded EmbeddingBagCollection\n",
+ "result_dict = kt.to_dict()\n",
+ "for key, embedding in result_dict.items():\n",
+ " print(key, embedding.shape)"
+ ],
+ "execution_count": 24,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n",
+ "['product', 'user']\n",
+ "torch.Size([2, 128])\n",
+ "product torch.Size([2, 64])\n",
+ "user torch.Size([2, 64])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "4c464de0-20ef-4ef2-89e2-5d58ca224660",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "bEgB987CcspG"
+ },
+ "source": [
+ "### Anatomy of Sharded TorchRec modules\n",
+ "\n",
+ "We have now successfully sharded an EmbeddingBagCollection given a sharding plan that we generated! The sharded module has common APIs from TorchRec which abstract away distributed communication/compute amongst multiple GPUs. In fact, these APIs are highly optimized for performance in training and inference. **Below are the three common APIs for distributed training/inference** that are provided by TorchRec:\n",
+ "\n",
+ "1. **input_dist**: Handles distributing inputs from GPU to GPU\n",
+ "\n",
+ "2. **lookups**: Does the actual embedding lookup in an optimized, batched manner using FBGEMM TBE (more on this later)\n",
+ "\n",
+ "3. **output_dist**: Handles distributing outputs from GPU to GPU\n",
+ "\n",
+ "The distribution of inputs/outputs is done through [NCCL Collectives](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/overview.html), namely [All-to-Alls](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#all-to-all), which is where all GPUs send/receive data to and from one another. TorchRec interfaces with PyTorch distributed for collectives and provides clean abstractions to the end users, removing the concern for the lower level details.\n",
+ "\n",
+ "\n",
+ "The backwards pass does all of these collectives but in the reverse order for distribution of gradients. input_dist, lookup, and output_dist all depend on the sharding scheme. Since we sharded in a table-wise fashion, these APIs are modules that are constructed by [TwPooledEmbeddingSharding](https://pytorch.org/torchrec/torchrec.distributed.sharding.html#torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "03e6e163-af3a-4443-a5a8-3f877fc401d2",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000155075,
+ "executionStopTime": 1726000155253,
+ "serverExecutionDuration": 5.8192722499371,
+ "requestMsgId": "03e6e163-af3a-4443-a5a8-3f877fc401d2",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1042737524520351"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "O2ptES89cspG",
+ "outputId": "2b801648-6501-4463-d743-4887da340974"
+ },
+ "source": [
+ "sharded_ebc"
+ ],
+ "execution_count": 25,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "ShardedEmbeddingBagCollection(\n",
+ " (lookups): \n",
+ " GroupedPooledEmbeddingsLookup(\n",
+ " (_emb_modules): ModuleList(\n",
+ " (0): BatchedFusedEmbeddingBag(\n",
+ " (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (_input_dists): \n",
+ " TwSparseFeaturesDist(\n",
+ " (_dist): KJTAllToAll()\n",
+ " )\n",
+ " (_output_dists): \n",
+ " TwPooledEmbeddingDist(\n",
+ " (_dist): PooledEmbeddingsAllToAll()\n",
+ " )\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): Module()\n",
+ " (user_table): Module()\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 25
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "c2a34340-d5fd-4dc8-9b7e-3a761a0c5f82",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000155256,
+ "executionStopTime": 1726000155442,
+ "serverExecutionDuration": 5.3565315902233,
+ "requestMsgId": "c2a34340-d5fd-4dc8-9b7e-3a761a0c5f82",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1063399165221115"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "PHjJt3BQcspG",
+ "outputId": "aaedadd8-da43-4225-d5f8-a7a43fd0250a"
+ },
+ "source": [
+ "# Distribute input KJTs to all other GPUs and receive KJTs\n",
+ "sharded_ebc._input_dists"
+ ],
+ "execution_count": 26,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[TwSparseFeaturesDist(\n",
+ " (_dist): KJTAllToAll()\n",
+ " )]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 26
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "88abe892-1ed1-4806-84ad-35f43247a772",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000155445,
+ "executionStopTime": 1726000155695,
+ "serverExecutionDuration": 5.3521953523159,
+ "requestMsgId": "88abe892-1ed1-4806-84ad-35f43247a772",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1513800839249249"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "jrEXMc7TcspG",
+ "outputId": "81ab40a5-135b-494c-f2bc-91be16a338cc"
+ },
+ "source": [
+ "# Distribute output embeddingts to all other GPUs and receive embeddings\n",
+ "sharded_ebc._output_dists"
+ ],
+ "execution_count": 27,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[TwPooledEmbeddingDist(\n",
+ " (_dist): PooledEmbeddingsAllToAll()\n",
+ " )]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 27
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "2eaf16f1-ac14-4f7a-b443-e707ff85c3f0",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "C2jfo5ilcspH"
+ },
+ "source": [
+ "### Optimizing Embedding Lookups\n",
+ "\n",
+ "In performing lookups for a collection of embedding tables, a trivial solution would be to iterate through all the `nn.EmbeddingBags` and do a lookup per table. This is exactly what the standard, unsharded TorchRec's `EmbeddingBagCollection` does. However, while this solution is simple, it is extremely slow.\n",
+ "\n",
+ "[FBGEMM](https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu) is a library that provides GPU operators (otherewise known as kernels) that are very optimized. One of these operators is known as **Table Batched Embedding** (TBE), provides two major optimizations:\n",
+ "\n",
+ "* Table batching, which allows you to look up multiple embeddings with one kernel call.\n",
+ "* Optimizer Fusion, which allows the module to update itself given the canonical pytorch optimizers and arguments.\n",
+ "\n",
+ "The `ShardedEmbeddingBagCollection` uses the FBGEMM TBE as the lookup instead of traditional `nn.EmbeddingBags` for optimized embedding lookups."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "801c50b9-e1a2-465a-9fa3-3cd87d676ed4",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000155699,
+ "executionStopTime": 1726000155879,
+ "serverExecutionDuration": 5.0756596028805,
+ "requestMsgId": "801c50b9-e1a2-465a-9fa3-3cd87d676ed4",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "911093750838903"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1GoHWI6OcspH",
+ "outputId": "cd67815b-00bd-4a30-89cf-7b5d9c7051e9"
+ },
+ "source": [
+ "sharded_ebc._lookups"
+ ],
+ "execution_count": 28,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[GroupedPooledEmbeddingsLookup(\n",
+ " (_emb_modules): ModuleList(\n",
+ " (0): BatchedFusedEmbeddingBag(\n",
+ " (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n",
+ " )\n",
+ " )\n",
+ " )]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 28
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "f2b31d78-81a9-426f-b017-ca8404383939",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "1zcbZX1lcspH"
+ },
+ "source": [
+ "### DistributedModelParallel\n",
+ "\n",
+ "We have now explored sharding a single EmbeddingBagCollection! We were able to take the `EmbeddingBagCollectionSharder` and use the unsharded `EmbeddingBagCollection` to generate a `ShardedEmbeddingBagCollection` module. This workflow is fine, but typically when doing model parallel, [`DistributedModelParallel`](https://pytorch.org/torchrec/model-parallel-api-reference.html#model-parallel) (DMP) is used as the standard interface. When wrapping your model (in our case `ebc`), with DMP, the following will occur:\n",
+ "\n",
+ "1. Decide how to shard the model. DMP will collect the available ‘sharders’ and come up with a ‘plan’ of the optimal way to shard the embedding table(s) (i.e, the EmbeddingBagCollection)\n",
+ "2. Actually shard the model. This includes allocating memory for each embedding table on the appropriate device(s).\n",
+ "\n",
+ "DMP takes in everything that we've just experimented with, like a static sharding plan, a list of sharders, etc. However, it also has some nice defaults to seamlessly shard a TorchRec model. In this toy example, since we have two EmbeddingTables and one GPU, TorchRec will place both on the single GPU.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "e0e198e1-db2a-46b0-91f0-51a5ff80abbb",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000155883,
+ "executionStopTime": 1726000156073,
+ "serverExecutionDuration": 7.8761726617813,
+ "requestMsgId": "e0e198e1-db2a-46b0-91f0-51a5ff80abbb",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1207953610328397"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ypVUDwpzcspH",
+ "outputId": "26a7a957-c231-459a-dfc8-f0c1cd6f697e"
+ },
+ "source": [
+ "ebc"
+ ],
+ "execution_count": 29,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "EmbeddingBagCollection(\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ " (user_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 29
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "73fec38d-947a-49d5-a2ba-61e3828b7117",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000156075,
+ "executionStopTime": 1726000156438,
+ "serverExecutionDuration": 165.43522849679,
+ "requestMsgId": "73fec38d-947a-49d5-a2ba-61e3828b7117",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1838328716783594"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "5EdlyWAycspH",
+ "outputId": "05e90aa2-cb83-4ddf-9da5-6aa31d6da278"
+ },
+ "source": [
+ "model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device(\"cuda\"))"
+ ],
+ "execution_count": 30,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "f8d87a4e-6a7a-4a02-92f9-9baa794266af",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000156441,
+ "executionStopTime": 1726000156665,
+ "serverExecutionDuration": 6.8417005240917,
+ "requestMsgId": "f8d87a4e-6a7a-4a02-92f9-9baa794266af",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1059040285804352"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "b5NgRErjcspH",
+ "outputId": "8f4de40c-3a3b-43e5-d645-814bf03dab0b"
+ },
+ "source": [
+ "out = model(kjt)\n",
+ "out.wait()"
+ ],
+ "execution_count": 31,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "execution_count": 31
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "e7e02648-dee7-4b3a-8953-47e8b8771c3b",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000156669,
+ "executionStopTime": 1726000156885,
+ "serverExecutionDuration": 5.4804161190987,
+ "requestMsgId": "e7e02648-dee7-4b3a-8953-47e8b8771c3b",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "3346626825643095"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "VJrSysjjcspH",
+ "outputId": "2920a3d0-dd96-43ea-ab0d-627de00d1e42"
+ },
+ "source": [
+ "model"
+ ],
+ "execution_count": 32,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "DistributedModelParallel(\n",
+ " (_dmp_wrapped_module): ShardedEmbeddingBagCollection(\n",
+ " (lookups): \n",
+ " GroupedPooledEmbeddingsLookup(\n",
+ " (_emb_modules): ModuleList(\n",
+ " (0): BatchedFusedEmbeddingBag(\n",
+ " (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (_input_dists): \n",
+ " TwSparseFeaturesDist(\n",
+ " (_dist): KJTAllToAll()\n",
+ " )\n",
+ " (_output_dists): \n",
+ " TwPooledEmbeddingDist(\n",
+ " (_dist): PooledEmbeddingsAllToAll()\n",
+ " )\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): Module()\n",
+ " (user_table): Module()\n",
+ " )\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 32
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "4b6171d5-ae60-4cc8-a47a-f01236c02e6c",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "BLM673eTcspH"
+ },
+ "source": [
+ "### Sharding Best Practices\n",
+ "\n",
+ "Currently, our configuration is only sharding on 1 GPU (or rank), which is trivial: just place all the tables on 1 GPUs memory. However, in real production use cases, embedding tables are **typically sharded on hundreds of GPUs**, with different sharding methods such as table-wise, row-wise, and column-wise. It is incredibly important to determine a proper sharding configuration (to prevent out of memory issues) while keeping it balanced not only in terms of memory but also compute for optimal performance."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Adding in the Optimizer\n",
+ "\n",
+ "Remember that TorchRec modules are hyperoptimized for large scale distributed training. An important optimization is in regards to the optimizer. **TorchRec modules provide a seamless API to fuse the backwards pass and optimize step in training, providing a significant optimization in performance and decreasing the memory used, alongside granularity in assigning distinct optimizers to distinct model parameters.**\n",
+ "\n",
+ "#### Optimizer Classes\n",
+ "\n",
+ "TorchRec uses `CombinedOptimizer`, which contains a collection of `KeyedOptimizers`. A `CombinedOptimizer` effectively makes it easy to handle multiple optimizers for various sub groups in the model. A `KeyedOptimizer` extends the `torch.optim.Optimizer` and is initialized through a dictionary of parameters exposes the parameters. Each `TBE` module in a `EmbeddingBagCollection` will have it's own `KeyedOptimizer` which combines into one `CombinedOptimizer`.\n",
+ "\n",
+ "#### Fused optimizer in TorchRec\n",
+ "\n",
+ "Using `DistributedModelParallel`, the **optimizer is fused, which means that the optimizer update is done in the backward**. This is an optimization in TorchRec and FBGEMM, where the optimizer embedding gradients are not materialized and applied directly to the parameters. This brings significant memory savings as embedding gradients are typically size of the parameters themselves.\n",
+ "\n",
+ "You can, however, choose to make the optimizer `dense` which does not apply this optimization and let's you inspect the embedding gradients or apply computations to it as you wish. A dense optimizer in this case would be your [canonical PyTorch model training loop with optimizer.](https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html)\n",
+ "\n",
+ "Once the optimizer is created through `DistributedModelParallel`, you still need to manage an optimizer for the other parameters not associated with TorchRec embedding modules. To find the other parameters, use`in_backward_optimizer_filter(model.named_parameters())`.\n",
+ "\n",
+ "Apply an optimizer to those parameters as you would a normal Torch optimizer and combine this and the `model.fused_optimizer` into one `CombinedOptimizer` that you can use in your training loop to `zero_grad` and `step` through.\n",
+ "\n",
+ "#### Let's add an optimizer to our EmbeddingBagCollection\n",
+ "We will do this in two ways, which are equivalent, but give you options depending on your preferences:\n",
+ "1. Passing optimizer kwargs through fused parameters (fused_params) in sharder\n",
+ "2. Through `apply_optimizer_in_backward`\n",
+ "Note: `apply_optimizer_in_backward` converts the optimizer parameters to `fused_params` to pass to the `TBE` in the `EmbeddingBagCollection`/`EmbeddingCollection`."
+ ],
+ "metadata": {
+ "id": "zFhggkUCmd7f"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Approach 1: passing optimizer kwargs through fused parameters\n",
+ "from torchrec.optim.optimizers import in_backward_optimizer_filter\n",
+ "from fbgemm_gpu.split_embedding_configs import EmbOptimType\n",
+ "\n",
+ "\n",
+ "# We initialize the sharder with\n",
+ "fused_params = {\n",
+ " \"optimizer\": EmbOptimType.EXACT_ROWWISE_ADAGRAD,\n",
+ " \"learning_rate\": 0.02,\n",
+ " \"eps\": 0.002,\n",
+ "}\n",
+ "\n",
+ "# Init sharder with fused_params\n",
+ "sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)\n",
+ "\n",
+ "# We'll use same plan and unsharded EBC as before but this time with our new sharder\n",
+ "sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[\"\"], env, torch.device(\"cuda\"))\n",
+ "\n",
+ "# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correclty.\n",
+ "# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied\n",
+ "print(f\"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}\")\n",
+ "print(f\"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}\")\n",
+ "\n",
+ "print(f\"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "h5BCEFidmnEw",
+ "outputId": "202c64f7-ae95-4b0d-9f53-16138a680d7d"
+ },
+ "execution_count": 33,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Original Sharded EBC fused optimizer: : EmbeddingFusedOptimizer (\n",
+ "Parameter Group 0\n",
+ " lr: 0.01\n",
+ ")\n",
+ "Sharded EBC with fused parameters fused optimizer: : EmbeddingFusedOptimizer (\n",
+ "Parameter Group 0\n",
+ " lr: 0.02\n",
+ ")\n",
+ "Type of optimizer: \n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward\n",
+ "import copy\n",
+ "# Approach 2: applying optimizer through apply_optimizer_in_backward\n",
+ "# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it\n",
+ "\n",
+ "# We can achieve the same result as we did in the previous\n",
+ "ebc_apply_opt = copy.deepcopy(ebc)\n",
+ "optimizer_kwargs = {\"lr\": 0.5}\n",
+ "\n",
+ "for name, param in ebc_apply_opt.named_parameters():\n",
+ " print(f\"{name=}\")\n",
+ " apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)\n",
+ "\n",
+ "sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[\"\"], env, torch.device(\"cuda\"))\n",
+ "\n",
+ "# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted\n",
+ "print(sharded_ebc_apply_opt.fused_optimizer)\n",
+ "print(type(sharded_ebc_apply_opt.fused_optimizer))"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "T-xx724MmoKv",
+ "outputId": "0f58fb18-f423-4c84-ee57-d37bdba28eb8"
+ },
+ "execution_count": 34,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "name='embedding_bags.product_table.weight'\n",
+ "name='embedding_bags.user_table.weight'\n",
+ ": EmbeddingFusedOptimizer (\n",
+ "Parameter Group 0\n",
+ " lr: 0.5\n",
+ ")\n",
+ "\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ ":1: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.\n",
+ " from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# We can also check through the filter other parameters that aren't associated with the \"fused\" optimizer(s)\n",
+ "# Pratically, just non TorchRec module parameters. Since our module is just a TorchRec EBC\n",
+ "# there are no other parameters that aren't associated with TorchRec\n",
+ "print(\"Non Fused Model Parameters:\")\n",
+ "print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "UEyhlmbmlwsW",
+ "outputId": "f6219673-d14d-444e-a451-98f33ddeb54d"
+ },
+ "execution_count": 35,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Non Fused Model Parameters:\n",
+ "dict_keys([])\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Here we do a dummy backwards call and see that parameter updates for fused\n",
+ "# optimizers happen as a result of the backward pass\n",
+ "\n",
+ "ebc_output = sharded_ebc_fused_params(kjt).wait().values()\n",
+ "loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)\n",
+ "print(f\"First Iteration Loss: {loss}\")\n",
+ "\n",
+ "loss.backward()\n",
+ "\n",
+ "ebc_output = sharded_ebc_fused_params(kjt).wait().values()\n",
+ "loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)\n",
+ "# We don't call an optimizer.step(), so for the loss to have changed here,\n",
+ "# that means that the gradients were somehow updated, which is what the\n",
+ "# fused optimizer automatically handles for us\n",
+ "print(f\"Second Iteration Loss: {loss}\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Bga-zM2OfnMW",
+ "outputId": "6c5d45f2-c479-4932-b39d-5ff8abe27d3c"
+ },
+ "execution_count": 36,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "First Iteration Loss: 255.94378662109375\n",
+ "Second Iteration Loss: 245.72166442871094\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "e3bdc895-54c4-4fc6-9175-28dd75021c6a",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "Xc-RUDwDcspH"
+ },
+ "source": [
+ "## Inference\n",
+ "\n",
+ "Now that we are able to train distributed embeddings, how can we take the trained model and optimize it for inference? Inference is typically very sensitive to **performance and size of the model**. Running just the trained model in a Python environment is incredibly inefficient. There are two key differences between inference and training environments:\n",
+ "* **Quantization**: Inference models are typically quantized, where model parameters lose precision for lower latency in predictions and reduced model size. For example FP32 (4 bytes) in trained model to INT8 (1 byte) for each embedding weight. This is also necessary given the vast scale of embedding tables, as we want to use as few devices as possible for inference to minimize latency.\n",
+ "* **C++ environment**: Inference latency is a big deal, so in order to ensure ample performance, the model is typically ran in a C++ environment (along with situations where we don't have a Python runtime, like on device)\n",
+ "\n",
+ "TorchRec provides primitives for converting a TorchRec model into being inference ready with:\n",
+ "* APIs for quantizing the model, introducing optimizations automatically with FBGEMM TBE\n",
+ "* sharding embeddings for distributed inference\n",
+ "* compiling the model to [TorchScript](https://pytorch.org/docs/stable/jit.html) (compatible in C++)\n",
+ "\n",
+ "In this section, we will go over this entire workflow of:\n",
+ "* Quantizing the model\n",
+ "* Sharding the quantized model\n",
+ "* Compiling the sharded quantized model into TorchScript"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "aae8ef10-f7a4-421a-b71c-f177ff74e96a",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "outputsInitialized": true,
+ "executionStartTime": 1726000156892,
+ "executionStopTime": 1726000157069,
+ "serverExecutionDuration": 7.4504055082798,
+ "requestMsgId": "aae8ef10-f7a4-421a-b71c-f177ff74e96a",
+ "customOutput": null,
+ "output": {
+ "id": "456742254014129"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "8JypsUNmcspH",
+ "outputId": "0a745234-d316-4850-d84a-f08b0f045595"
+ },
+ "source": [
+ "ebc"
+ ],
+ "execution_count": 37,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "EmbeddingBagCollection(\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ " (user_table): EmbeddingBag(4096, 64, mode='sum')\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 37
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "30694976-da54-48d6-922e-ca53f22c385f",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000157071,
+ "executionStopTime": 1726000157317,
+ "serverExecutionDuration": 2.9501467943192,
+ "requestMsgId": "30694976-da54-48d6-922e-ca53f22c385f",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "id": "t2plfyrWcspH"
+ },
+ "source": [
+ "class InferenceModule(torch.nn.Module):\n",
+ " def __init__(self, ebc: torchrec.EmbeddingBagCollection):\n",
+ " super().__init__()\n",
+ " self.ebc_ = ebc\n",
+ "\n",
+ " def forward(self, kjt: KeyedJaggedTensor):\n",
+ " return self.ebc_(kjt)"
+ ],
+ "execution_count": 38,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "2a4a83f1-449d-493e-8f24-7c1975ecad9d",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000157320,
+ "executionStopTime": 1726000157494,
+ "serverExecutionDuration": 3.8229525089264,
+ "requestMsgId": "2a4a83f1-449d-493e-8f24-7c1975ecad9d",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1619365005294308"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "5FRioGEmcspH",
+ "outputId": "e4d9efd3-2427-4602-bb28-2f30c4f3f985"
+ },
+ "source": [
+ "module = InferenceModule(ebc)\n",
+ "for name, param in module.named_parameters():\n",
+ " # Here, the parameters should still be FP32, as we are using a standard EBC\n",
+ " # FP32 is default, regularly used for training\n",
+ " print(name, param.shape, param.dtype)"
+ ],
+ "execution_count": 39,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "ebc_.embedding_bags.product_table.weight torch.Size([4096, 64]) torch.float32\n",
+ "ebc_.embedding_bags.user_table.weight torch.Size([4096, 64]) torch.float32\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "665352e2-208f-4951-8601-282d036b0e4e",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "OSTy4SU8cspH"
+ },
+ "source": [
+ "### Quantization\n",
+ "\n",
+ "As you can see above, the normal EBC contains embedding table weights as FP32 precision (32 bits for each weight). Here, we will use the TorchRec inference library to quantize the embedding weights of the model to INT8"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "796919b4-f9dd-4d14-a40e-f20668c8257b",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000157499,
+ "executionStopTime": 1726000157696,
+ "serverExecutionDuration": 14.22468572855,
+ "requestMsgId": "796919b4-f9dd-4d14-a40e-f20668c8257b",
+ "customOutput": null,
+ "outputsInitialized": true,
+ "output": {
+ "id": "560049189691202"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "oV-KPRqDcspH",
+ "outputId": "e91220a1-a26c-4f2a-e7c8-e91a3d56d8dc"
+ },
+ "source": [
+ "from torch import quantization as quant\n",
+ "from torchrec.modules.embedding_configs import QuantConfig\n",
+ "from torchrec.quant.embedding_modules import (\n",
+ " EmbeddingBagCollection as QuantEmbeddingBagCollection,\n",
+ ")\n",
+ "\n",
+ "\n",
+ "quant_dtype = torch.int8\n",
+ "\n",
+ "\n",
+ "qconfig = QuantConfig(\n",
+ " # dtype of the result of the embedding lookup, post activation\n",
+ " # torch.float generally for compatability with rest of the model\n",
+ " # as rest of the model here usually isn't quantized\n",
+ " activation=quant.PlaceholderObserver.with_args(dtype=torch.float),\n",
+ " # quantized type for embedding weights, aka parameters to actually quantize\n",
+ " weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),\n",
+ ")\n",
+ "qconfig_spec = {\n",
+ " # Map of module type to qconfig\n",
+ " torchrec.EmbeddingBagCollection: qconfig,\n",
+ "}\n",
+ "mapping = {\n",
+ " # Map of module type to quantized module type\n",
+ " torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,\n",
+ "}\n",
+ "\n",
+ "\n",
+ "module = InferenceModule(ebc)\n",
+ "\n",
+ "# Quantize the module\n",
+ "qebc = quant.quantize_dynamic(\n",
+ " module,\n",
+ " qconfig_spec=qconfig_spec,\n",
+ " mapping=mapping,\n",
+ " inplace=False,\n",
+ ")\n",
+ "\n",
+ "\n",
+ "print(f\"Quantized EBC: {qebc}\")"
+ ],
+ "execution_count": 40,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Quantized EBC: InferenceModule(\n",
+ " (ebc_): QuantizedEmbeddingBagCollection(\n",
+ " (_kjt_to_jt_dict): ComputeKJTToJTDict()\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): Module()\n",
+ " (user_table): Module()\n",
+ " )\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "c1fdd88b-73af-47a8-8aec-4f9422051ee7",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000157700,
+ "executionStopTime": 1726000157862,
+ "serverExecutionDuration": 4.0535479784012,
+ "requestMsgId": "c1fdd88b-73af-47a8-8aec-4f9422051ee7",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "id": "fAztesVacspI"
+ },
+ "source": [
+ "kjt = kjt.to(\"cpu\")"
+ ],
+ "execution_count": 41,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "f5f911e8-ab78-4fd7-b4a1-7a545b5bd24b",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000157865,
+ "executionStopTime": 1726000158060,
+ "serverExecutionDuration": 9.1104581952095,
+ "requestMsgId": "f5f911e8-ab78-4fd7-b4a1-7a545b5bd24b",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "434299789062153"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Wnpwa0TmcspI",
+ "outputId": "88007466-88ce-4f6b-b7e8-22d042e5378b"
+ },
+ "source": [
+ "qebc(kjt)"
+ ],
+ "execution_count": 42,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "execution_count": 42
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "99559efa-baaa-4de1-91d3-7899f87fe659",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000158063,
+ "executionStopTime": 1726000158228,
+ "serverExecutionDuration": 3.4465603530407,
+ "requestMsgId": "99559efa-baaa-4de1-91d3-7899f87fe659",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "499581679596627"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "UUs5fXNncspI",
+ "outputId": "79d814ed-7772-4bc3-ba84-f5f0d0a45e36"
+ },
+ "source": [
+ "# Once quantized, goes from parameters -> buffers, as no longer trainable\n",
+ "for name, buffer in qebc.named_buffers():\n",
+ " # The shapes of the tables should be the same but the dtype should be int8 now\n",
+ " # post quantization\n",
+ " print(name, buffer.shape, buffer.dtype)"
+ ],
+ "execution_count": 43,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "ebc_.embedding_bags.product_table.weight torch.Size([4096, 80]) torch.uint8\n",
+ "ebc_.embedding_bags.user_table.weight torch.Size([4096, 80]) torch.uint8\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "2b1a9c89-b921-4a35-9f64-0c63b09a2579",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "fdM7UihocspI"
+ },
+ "source": [
+ "### Shard\n",
+ "\n",
+ "Here we perform sharding of the TorchRec quantized model. This is to ensure we are using the performant module through FBGEMM TBE. Here we are using one device to be consistent with training (1 TBE)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "19c18bbb-6376-468a-a6dc-8346d30ceb48",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000158234,
+ "executionStopTime": 1726000158552,
+ "serverExecutionDuration": 108.51271077991,
+ "requestMsgId": "19c18bbb-6376-468a-a6dc-8346d30ceb48",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "882684747065056"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "mha4FntncspI",
+ "outputId": "40a5557c-531f-4726-f06a-5f73546a8fe0"
+ },
+ "source": [
+ "from torchrec import distributed as trec_dist\n",
+ "from torchrec.distributed.shard import _shard_modules\n",
+ "\n",
+ "\n",
+ "sharded_qebc = _shard_modules(\n",
+ " module=qebc,\n",
+ " device=torch.device(\"cpu\"),\n",
+ " env=trec_dist.ShardingEnv.from_local(\n",
+ " 1,\n",
+ " 0,\n",
+ " ),\n",
+ ")\n",
+ "\n",
+ "\n",
+ "print(f\"Sharded Quantized EBC: {sharded_qebc}\")"
+ ],
+ "execution_count": 44,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Sharded Quantized EBC: InferenceModule(\n",
+ " (ebc_): ShardedQuantEmbeddingBagCollection(\n",
+ " (lookups): \n",
+ " InferGroupedPooledEmbeddingsLookup()\n",
+ " (_output_dists): ModuleList()\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (product_table): Module()\n",
+ " (user_table): Module()\n",
+ " )\n",
+ " (_input_dist_module): ShardedQuantEbcInputDist()\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "f00ae63f-0ac4-49c0-93fe-32d7fac76693",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000158555,
+ "executionStopTime": 1726000159111,
+ "serverExecutionDuration": 345.11629864573,
+ "requestMsgId": "f00ae63f-0ac4-49c0-93fe-32d7fac76693",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "876807203893705"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "0iBD90t3cspI",
+ "outputId": "86a17997-aa50-4427-cf2b-55f4c0aef456"
+ },
+ "source": [
+ "sharded_qebc(kjt)"
+ ],
+ "execution_count": 45,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "execution_count": 45
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "897037bb-9d81-4a33-aea1-de1691217d41",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "08ue1zeVcspI"
+ },
+ "source": [
+ "### Compilation\n",
+ "Now we have the optimized eager TorchRec inference model. The next step is to ensure that this model is loadable in C++, as currently it is only runnable in a Python runtime.\n",
+ "\n",
+ "The recommended method of compilation at Meta is two fold: [torch.fx tracing](https://pytorch.org/docs/stable/fx.html) (generate intermediate representation of model) and converting the result to TorchScript, where TorchScript is C++ compatible."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "bdab6e95-3a71-4c3d-b188-115873f1f5d5",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000159118,
+ "executionStopTime": 1726000159308,
+ "serverExecutionDuration": 28.788283467293,
+ "requestMsgId": "bdab6e95-3a71-4c3d-b188-115873f1f5d5",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "491668137118498"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "SRzo1jljcspI",
+ "outputId": "e0f94cf0-c5cf-4480-b0d9-d5dd2abb459b"
+ },
+ "source": [
+ "from torchrec.fx import Tracer\n",
+ "\n",
+ "\n",
+ "tracer = Tracer(leaf_modules=[\"IntNBitTableBatchedEmbeddingBagsCodegen\"])\n",
+ "\n",
+ "graph = tracer.trace(sharded_qebc)\n",
+ "gm = torch.fx.GraphModule(sharded_qebc, graph)\n",
+ "\n",
+ "print(\"Graph Module Created!\")"
+ ],
+ "execution_count": 46,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Graph Module Created!\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "909178d6-4dae-45da-9c39-6827019f53a3",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000159312,
+ "executionStopTime": 1726000159490,
+ "serverExecutionDuration": 2.2248737514019,
+ "requestMsgId": "909178d6-4dae-45da-9c39-6827019f53a3",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1555501808508272"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "NsgzbUdHcspI",
+ "outputId": "c5b67630-19c7-46df-c0ab-216d24309603"
+ },
+ "source": [
+ "print(gm.code)"
+ ],
+ "execution_count": 47,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n",
+ "torch.fx._symbolic_trace.wrap(\"torchrec_distributed_quant_embeddingbag_flatten_feature_lengths\")\n",
+ "torch.fx._symbolic_trace.wrap(\"torchrec_fx_utils__fx_marker\")\n",
+ "torch.fx._symbolic_trace.wrap(\"torchrec_distributed_quant_embedding_kernel__unwrap_kjt\")\n",
+ "torch.fx._symbolic_trace.wrap(\"torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference\")\n",
+ "\n",
+ "def forward(self, kjt : torchrec_sparse_jagged_tensor_KeyedJaggedTensor):\n",
+ " flatten_feature_lengths = torchrec_distributed_quant_embeddingbag_flatten_feature_lengths(kjt); kjt = None\n",
+ " _fx_marker = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_BEGIN', flatten_feature_lengths)\n",
+ " split = flatten_feature_lengths.split([2])\n",
+ " getitem = split[0]; split = None\n",
+ " to = getitem.to(device(type='cuda', index=0), non_blocking = True); getitem = None\n",
+ " _fx_marker_1 = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_END', flatten_feature_lengths); flatten_feature_lengths = None\n",
+ " _unwrap_kjt = torchrec_distributed_quant_embedding_kernel__unwrap_kjt(to); to = None\n",
+ " getitem_1 = _unwrap_kjt[0]\n",
+ " getitem_2 = _unwrap_kjt[1]\n",
+ " getitem_3 = _unwrap_kjt[2]; _unwrap_kjt = None\n",
+ " _tensor_constant0 = self._tensor_constant0\n",
+ " _tensor_constant1 = self._tensor_constant1\n",
+ " bounds_check_indices = torch.ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_1, getitem_2, 1, _tensor_constant1, None); _tensor_constant0 = _tensor_constant1 = None\n",
+ " _tensor_constant2 = self._tensor_constant2\n",
+ " _tensor_constant3 = self._tensor_constant3\n",
+ " _tensor_constant4 = self._tensor_constant4\n",
+ " _tensor_constant5 = self._tensor_constant5\n",
+ " _tensor_constant6 = self._tensor_constant6\n",
+ " _tensor_constant7 = self._tensor_constant7\n",
+ " _tensor_constant8 = self._tensor_constant8\n",
+ " _tensor_constant9 = self._tensor_constant9\n",
+ " int_nbit_split_embedding_codegen_lookup_function = torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(dev_weights = _tensor_constant2, uvm_weights = _tensor_constant3, weights_placements = _tensor_constant4, weights_offsets = _tensor_constant5, weights_tys = _tensor_constant6, D_offsets = _tensor_constant7, total_D = 128, max_int2_D = 0, max_int4_D = 0, max_int8_D = 64, max_float16_D = 0, max_float32_D = 0, indices = getitem_1, offsets = getitem_2, pooling_mode = 0, indice_weights = None, output_dtype = 0, lxu_cache_weights = _tensor_constant8, lxu_cache_locations = _tensor_constant9, row_alignment = 16, max_float8_D = 0, fp8_exponent_bits = -1, fp8_exponent_bias = -1); _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 = _tensor_constant6 = _tensor_constant7 = getitem_1 = getitem_2 = _tensor_constant8 = _tensor_constant9 = None\n",
+ " embeddings_cat_empty_rank_handle_inference = torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference([int_nbit_split_embedding_codegen_lookup_function], dim = 1, device = 'cuda:0', dtype = torch.float32); int_nbit_split_embedding_codegen_lookup_function = None\n",
+ " to_1 = embeddings_cat_empty_rank_handle_inference.to(device(type='cpu')); embeddings_cat_empty_rank_handle_inference = None\n",
+ " keyed_tensor = torchrec_sparse_jagged_tensor_KeyedTensor(keys = ['product', 'user'], length_per_key = [64, 64], values = to_1, key_dim = 1); to_1 = None\n",
+ " return keyed_tensor\n",
+ " \n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "ec77b6ea-f5b1-4c08-9cb9-93faf6a57532",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000159494,
+ "executionStopTime": 1726000160206,
+ "serverExecutionDuration": 540.64276814461,
+ "requestMsgId": "ec77b6ea-f5b1-4c08-9cb9-93faf6a57532",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "978016470760577"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "CjjJLc6pcspI",
+ "outputId": "be3a9486-e4b5-43f6-aed0-711e827a0040"
+ },
+ "source": [
+ "scripted_gm = torch.jit.script(gm)\n",
+ "print(\"Scripted Graph Module Created!\")"
+ ],
+ "execution_count": 48,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Scripted Graph Module Created!\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "9eb089f1-2771-419d-a48e-3b7330c0a1e4",
+ "showInput": true,
+ "customInput": null,
+ "language": "python",
+ "executionStartTime": 1726000160212,
+ "executionStopTime": 1726000160395,
+ "serverExecutionDuration": 2.8529539704323,
+ "requestMsgId": "9eb089f1-2771-419d-a48e-3b7330c0a1e4",
+ "outputsInitialized": true,
+ "customOutput": null,
+ "output": {
+ "id": "1020643789855657"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "BWKPRaI3cspI",
+ "outputId": "273181a2-7c91-4167-e814-4a07b51c6b10"
+ },
+ "source": [
+ "print(scripted_gm.code)"
+ ],
+ "execution_count": 49,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "def forward(self,\n",
+ " kjt: __torch__.torchrec.sparse.jagged_tensor.KeyedJaggedTensor) -> __torch__.torchrec.sparse.jagged_tensor.KeyedTensor:\n",
+ " _0 = __torch__.torchrec.distributed.quant_embeddingbag.flatten_feature_lengths\n",
+ " _1 = __torch__.torchrec.fx.utils._fx_marker\n",
+ " _2 = __torch__.torchrec.distributed.quant_embedding_kernel._unwrap_kjt\n",
+ " _3 = __torch__.torchrec.distributed.embedding_lookup.embeddings_cat_empty_rank_handle_inference\n",
+ " flatten_feature_lengths = _0(kjt, )\n",
+ " _fx_marker = _1(\"KJT_ONE_TO_ALL_FORWARD_BEGIN\", flatten_feature_lengths, )\n",
+ " split = (flatten_feature_lengths).split([2], )\n",
+ " getitem = split[0]\n",
+ " to = (getitem).to(torch.device(\"cuda\", 0), True, None, )\n",
+ " _fx_marker_1 = _1(\"KJT_ONE_TO_ALL_FORWARD_END\", flatten_feature_lengths, )\n",
+ " _unwrap_kjt = _2(to, )\n",
+ " getitem_1 = (_unwrap_kjt)[0]\n",
+ " getitem_2 = (_unwrap_kjt)[1]\n",
+ " _tensor_constant0 = self._tensor_constant0\n",
+ " _tensor_constant1 = self._tensor_constant1\n",
+ " ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_1, getitem_2, 1, _tensor_constant1)\n",
+ " _tensor_constant2 = self._tensor_constant2\n",
+ " _tensor_constant3 = self._tensor_constant3\n",
+ " _tensor_constant4 = self._tensor_constant4\n",
+ " _tensor_constant5 = self._tensor_constant5\n",
+ " _tensor_constant6 = self._tensor_constant6\n",
+ " _tensor_constant7 = self._tensor_constant7\n",
+ " _tensor_constant8 = self._tensor_constant8\n",
+ " _tensor_constant9 = self._tensor_constant9\n",
+ " int_nbit_split_embedding_codegen_lookup_function = ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(_tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5, _tensor_constant6, _tensor_constant7, 128, 0, 0, 64, 0, 0, getitem_1, getitem_2, 0, None, 0, _tensor_constant8, _tensor_constant9, 16)\n",
+ " _4 = [int_nbit_split_embedding_codegen_lookup_function]\n",
+ " embeddings_cat_empty_rank_handle_inference = _3(_4, 1, \"cuda:0\", 6, )\n",
+ " to_1 = torch.to(embeddings_cat_empty_rank_handle_inference, torch.device(\"cpu\"))\n",
+ " _5 = [\"product\", \"user\"]\n",
+ " _6 = [64, 64]\n",
+ " keyed_tensor = __torch__.torchrec.sparse.jagged_tensor.KeyedTensor.__new__(__torch__.torchrec.sparse.jagged_tensor.KeyedTensor)\n",
+ " _7 = (keyed_tensor).__init__(_5, _6, to_1, 1, None, None, )\n",
+ " return keyed_tensor\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "9a1dda10-b9cf-4d9f-b068-51ae3ce3ffc1",
+ "showInput": false,
+ "customInput": null,
+ "language": "markdown",
+ "outputsInitialized": false,
+ "id": "DQiGRYOgcspI"
+ },
+ "source": [
+ "## Congrats!\n",
+ "\n",
+ "You have now gone from training a distributed RecSys model all the way to making it inference ready. https://github.com/pytorch/torchrec/tree/main/torchrec/inference has a full example of how to load a TorchRec TorchScript model into C++ for inference."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ebXfh7oW9fHH",
+ "originalKey": "4ca6a593-9ac9-4e2f-bc9a-8c8a1887ad41",
+ "outputsInitialized": false,
+ "language": "markdown",
+ "showInput": false
+ },
+ "source": [
+ "## More resources\n",
+ "For more information, please see our [dlrm](https://github.com/facebookresearch/dlrm/tree/main/torchrec_dlrm/) example, which includes multinode training on the criteo terabyte dataset, using Meta’s [DLRM](https://arxiv.org/abs/1906.00091)."
+ ]
+ }
+ ]
+}
diff --git a/Torchrec_Introduction.ipynb b/Torchrec_Introduction.ipynb
index 5da052875..970fbf8cc 100644
--- a/Torchrec_Introduction.ipynb
+++ b/Torchrec_Introduction.ipynb
@@ -19,44 +19,11 @@
"source": [
"## **Installation**\n",
"Requirements:\n",
- "- python >= 3.7\n",
+ "- python >= 3.9\n",
"\n",
"We highly recommend CUDA when using TorchRec. If using CUDA:\n",
- "- cuda >= 11.0\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "BB2K68OYUJ_t"
- },
- "outputs": [],
- "source": [
- "# install conda to make installying pytorch with cudatoolkit 11.3 easier. \n",
- "!wget https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh\n",
- "!chmod +x Miniconda3-py37_4.9.2-Linux-x86_64.sh\n",
- "!bash ./Miniconda3-py37_4.9.2-Linux-x86_64.sh -b -f -p /usr/local"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "sFYvP95xaAER"
- },
- "outputs": [],
- "source": [
- "# install pytorch with cudatoolkit 11.6\n",
- "!conda install pytorch pytorch-cuda=11.6 -c pytorch-nightly -c nvidia -y"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "7iY7Uv11mJYK"
- },
- "source": [
+ "- cuda >= 11.8\n",
+ "\n",
"Installing TorchRec will also install [FBGEMM](https://github.com/pytorch/fbgemm), a collection of CUDA kernels and GPU enabled operations to run "
]
},
@@ -64,53 +31,14 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
- "id": "tUnIw-ZREQJy"
- },
- "outputs": [],
- "source": [
- "# install torchrec\n",
- "!pip3 install torchrec-nightly"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "b6EHgotRXFQh"
- },
- "source": [
- "The following steps are needed for the Colab runtime to detect the added shared libraries. The runtime searches for shared libraries in /usr/lib, so we copy over the libraries which were installed in /usr/local/lib/. **This is a very necessary step, only in the colab runtime**. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "_P45pDteRcWj"
- },
- "outputs": [],
- "source": [
- "!cp /usr/local/lib/lib* /usr/lib/"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "n5_X2WOAYG3c"
- },
- "source": [
- "\\**Restart your runtime at this point for the newly installed packages to be seen.** Run the step below immediately after restarting so that python knows where to look for packages. **Always run this step after restarting the runtime.**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "8cktNrh8R9rC"
+ "id": "sFYvP95xaAER"
},
"outputs": [],
"source": [
- "import sys\n",
- "sys.path = ['', '/env/python', '/usr/local/lib/python37.zip', '/usr/local/lib/python3.7', '/usr/local/lib/python3.7/lib-dynload', '/usr/local/lib/python3.7/site-packages']"
+ "!pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U\n",
+ "!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/nightly/cu121\n",
+ "!pip3 install torchmetrics==1.0.3\n",
+ "!pip3 install torchrec --index-url https://download.pytorch.org/whl/nightly/cu121"
]
},
{
@@ -236,11 +164,11 @@
"metadata": {},
"outputs": [],
"source": [
- "from torchrec.optim.apply_overlapped_optimizer import apply_overlapped_optimizer\n",
+ "from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward\n",
"\n",
- "apply_overlapped_optimizer(\n",
- " ebc.parameters(),\n",
- " optimizer_type=torch.optim.SGD,\n",
+ "apply_optimizer_in_backward(\n",
+ " optimizer_class=torch.optim.SGD,\n",
+ " params=ebc.parameters(),\n",
" optimizer_kwargs={\"lr\": 0.02},\n",
")"
]
@@ -625,35 +553,46 @@
}
],
"metadata": {
- "accelerator": "GPU",
- "colab": {
- "background_execution": "on",
- "collapsed_sections": [],
- "machine_shape": "hm",
- "name": "Torchrec Introduction.ipynb",
- "provenance": []
- },
- "interpreter": {
- "hash": "d4204deb07d30e7517ec64733b2d65f24aff851b061e21418071854b06459363"
- },
- "kernelspec": {
- "display_name": "Python 3.7.13 ('torchrec': conda)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
+ "custom": {
+ "cells": [],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "background_execution": "on",
+ "collapsed_sections": [],
+ "machine_shape": "hm",
+ "name": "Torchrec Introduction.ipynb",
+ "provenance": []
+ },
+ "fileHeader": "",
+ "fileUid": "c9a29462-2509-4adb-a539-0318cf56bb00",
+ "interpreter": {
+ "hash": "d4204deb07d30e7517ec64733b2d65f24aff851b061e21418071854b06459363"
+ },
+ "isAdHoc": false,
+ "kernelspec": {
+ "display_name": "Python 3.7.13 ('torchrec': conda)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ }
},
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.13"
- }
+ "nbformat": 4,
+ "nbformat_minor": 0
+ },
+ "indentAmount": 2
},
"nbformat": 4,
- "nbformat_minor": 0
+ "nbformat_minor": 2
}
diff --git a/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp b/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp
index b0f6a4bde..23001de25 100644
--- a/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp
+++ b/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp
@@ -12,7 +12,7 @@
namespace torchrec {
void BM_MixedLFULRUStrategy(benchmark::State& state) {
size_t num_ext_values = state.range(0);
- std::vector ext_values(num_ext_values);
+ std::vector ext_values(num_ext_values);
MixedLFULRUStrategy strategy;
for (auto& v : ext_values) {
diff --git a/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_evict_benchmark.cpp b/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_evict_benchmark.cpp
index 1f4183d74..f8e2a0cd2 100644
--- a/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_evict_benchmark.cpp
+++ b/benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_evict_benchmark.cpp
@@ -18,13 +18,13 @@ class RecordIterator {
public:
RecordIterator(Container::const_iterator begin, Container::const_iterator end)
: begin_(begin), end_(end) {}
- std::optional> operator()() {
+ std::optional operator()() {
if (begin_ == end_) {
return std::nullopt;
}
- TransformerRecord record{};
- record.global_id_ = next_global_id_++;
- record.lxu_record_ = *reinterpret_cast(&(*begin_++));
+ record_t record{};
+ record.global_id = next_global_id_++;
+ record.lxu_record = *reinterpret_cast(&(*begin_++));
return record;
}
@@ -52,8 +52,8 @@ class RandomizeMixedLXUSet {
std::uniform_int_distribution time_dist(0, max_time - 1);
for (size_t i = 0; i < n; ++i) {
MixedLFULRUStrategy::Record record{};
- record.freq_power_ = freq_dist(engine) + min_freq;
- record.time_ = time_dist(engine);
+ record.freq_power = freq_dist(engine) + min_freq;
+ record.time = time_dist(engine);
records_.emplace_back(record);
}
}
@@ -68,8 +68,9 @@ class RandomizeMixedLXUSet {
void BM_MixedLFULRUStrategyEvict(benchmark::State& state) {
RandomizeMixedLXUSet lxuSet(state.range(0), state.range(1), state.range(2));
+ MixedLFULRUStrategy strategy;
for (auto _ : state) {
- MixedLFULRUStrategy::evict(lxuSet.Iterator(), state.range(3));
+ strategy.evict(lxuSet.Iterator(), state.range(3));
}
}
diff --git a/benchmarks/cpp/dynamic_embedding/naive_id_transformer_benchmark.cpp b/benchmarks/cpp/dynamic_embedding/naive_id_transformer_benchmark.cpp
index 17faadb67..ae66fbfcc 100644
--- a/benchmarks/cpp/dynamic_embedding/naive_id_transformer_benchmark.cpp
+++ b/benchmarks/cpp/dynamic_embedding/naive_id_transformer_benchmark.cpp
@@ -13,8 +13,7 @@
namespace torchrec {
static void BM_NaiveIDTransformer(benchmark::State& state) {
- using Tag = int32_t;
- NaiveIDTransformer transformer(2e8);
+ NaiveIDTransformer transformer(2e8);
torch::Tensor global_ids = torch::empty({1024, 1024}, torch::kLong);
torch::Tensor cache_ids = torch::empty_like(global_ids);
for (auto _ : state) {
diff --git a/benchmarks/ebc_benchmarks.py b/benchmarks/ebc_benchmarks.py
index 082f05bc0..98a0ffcbd 100644
--- a/benchmarks/ebc_benchmarks.py
+++ b/benchmarks/ebc_benchmarks.py
@@ -5,13 +5,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import argparse
import sys
from typing import List, Tuple
+# pyre-fixme[21]: Could not find module `ebc_benchmarks_utils`.
+import ebc_benchmarks_utils
import torch
-from fbgemm_gpu.split_table_batched_embeddings_ops import EmbeddingLocation
-from torchrec.github.benchmarks import ebc_benchmarks_utils
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.fused_embedding_modules import FusedEmbeddingBagCollection
@@ -251,5 +254,9 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
return parser.parse_args(argv)
-if __name__ == "__main__":
+def invoke_main() -> None:
main(sys.argv[1:])
+
+
+if __name__ == "__main__":
+ invoke_main() # pragma: no cover
diff --git a/benchmarks/ebc_benchmarks_utils.py b/benchmarks/ebc_benchmarks_utils.py
index 74f7f67d8..b15ec2b66 100644
--- a/benchmarks/ebc_benchmarks_utils.py
+++ b/benchmarks/ebc_benchmarks_utils.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import time
from typing import Dict, List, Optional, Tuple
diff --git a/contrib/dynamic_embedding/README.md b/contrib/dynamic_embedding/README.md
new file mode 100644
index 000000000..975963f35
--- /dev/null
+++ b/contrib/dynamic_embedding/README.md
@@ -0,0 +1,80 @@
+# TorchRec Dynamic Embedding
+
+This folder contains the extension that supports dynamic embedding for torchrec. Specifically, it enables torchrec to attach an external parameter server (PS), so that when the local GPU embedding tables are not big enough, embeddings can be pulled from and evicted to the PS.
+
+## Installation
+
+After installing torchrec, clone the torchrec repo and manually install the dynamic embedding extension:
+
+```bash
+git clone git@github.com:pytorch/torchrec.git
+cd torchrec/contrib/dynamic_embedding
+python setup.py install
+```
+
+The dynamic embedding extension will be installed as a separate package named `torchrec_dynamic_embedding`.
+
+Note that for C++20 support we recommend gcc version 10 or higher. Conda users can install the latest gcc toolchain with:
+
+```bash
+conda install gxx_linux-64
+```
+
+We use `gtest` for the C++ code and unittest for the Python APIs. The tests make sure that the implementation does not introduce any precision loss. Please turn on the `TDE_WITH_TESTING` option in `setup.py` to run the tests. Note that for the Python tests, you need to set the environment variable `TDE_MEMORY_IO_PATH` to the path of the compiled `memory_io.so`.
+
+## Usage
+
+The dynamic embedding extension exposes a single API, `tde.wrap`. When the dataloader and model are wrapped with it, data processing and model training are automatically pipelined. An example of `tde.wrap`:
+
+```python
+import torchrec_dynamic_embedding as tde
+
+class Model(nn.Module):
+ def __init__(self, config1, config2):
+ super().__init__()
+ self.emb1 = EmbeddingCollection(tables=config1, device=torch.device("meta"))
+ self.emb2 = EmbeddingCollection(tables=config2, device=torch.device("meta"))
+ ...
+
+ def forward(self, kjt1, kjt2):
+ ...
+
+m = Model(config1, config2)
+m = DistributedModelParallel(m)
+dataloader = tde.wrap(
+ "redis://127.0.0.1:6379/?prefix=model",
+ dataloader,
+ m,
+ # configs of the embedding collections in the model
+ { "emb1": config1, "emb2": config2 })
+
+for label, kjt1, kjt2 in dataloader:
+ output = m(kjt1, kjt2)
+ ...
+```
+
+The internals of `tde.wrap` live in `src/torchrec_dynamic_embedding/dataloader.py`, where we attach hooks to the embedding tensors and create the dataloader thread for pipelining.
+
+## Custom PS Extension
+
+The dynamic embedding extension supports connecting to your own PS cluster. To write your own PS extension, create a dynamic library (`*.so`) that exports these 4 functions and 1 variable:
+
+```c++
+const char* IO_type = "your-ps";
+
+void* IO_Initialize(const char* cfg);
+
+void IO_Finalize(void* instance);
+
+void IO_Pull(void* instance, IOPullParameter cfg);
+
+void IO_Push(void* instance, IOPushParameter cfg);
+```
+
+Then use the following Python API to register it:
+
+```python
+torch.ops.tde.register_io(so_path)
+```
+
+After that, you can use your own PS extension by passing the corresponding URL into `tde.wrap`: the protocol name is the `IO_type` and the string after `"://"` is passed as the configuration to `IO_Initialize` (i.e. the URL has the form `"type://cfg"`).
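+
+For illustration, a minimal sketch of registering and then using a custom PS extension might look like the following, reusing `m`, `dataloader`, and the configs from the example above; the library path and host addresses are hypothetical, and `"your-ps"` must match the `IO_type` exported by your library:
+
+```python
+import torch
+import torchrec_dynamic_embedding as tde
+
+# Load and register the custom PS extension (hypothetical path).
+torch.ops.tde.register_io("/path/to/libyour_ps.so")
+
+# The protocol ("your-ps") selects the registered IO_type; everything after
+# "://" is forwarded to IO_Initialize as its cfg string.
+dataloader = tde.wrap(
+    "your-ps://host1:2000,host2:2000/?prefix=model",
+    dataloader,
+    m,
+    {"emb1": config1, "emb2": config2},
+)
+```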
diff --git a/contrib/dynamic_embedding/setup.py b/contrib/dynamic_embedding/setup.py
index b4a6e968d..c5e269200 100644
--- a/contrib/dynamic_embedding/setup.py
+++ b/contrib/dynamic_embedding/setup.py
@@ -1,7 +1,16 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
import os
import sys
import torch
+from setuptools import find_packages
from skbuild import setup
@@ -40,7 +49,7 @@
setup(
name="torchrec_dynamic_embedding",
package_dir={"": "src"},
- packages=["torchrec_dynamic_embedding"],
+ packages=find_packages("src"),
cmake_args=[
"-DCMAKE_BUILD_TYPE=Release",
f"-DTDE_TORCH_BASE_DIR={os.path.dirname(torch.__file__)}",
diff --git a/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer.h b/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer.h
index 6d6a328f3..bada56602 100644
--- a/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer.h
+++ b/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer.h
@@ -1,3 +1,11 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
#pragma once
#include
#include
@@ -30,7 +38,7 @@ class CachelineIDTransformerIterator {
continue;
}
TransformerRecord result{};
- result.global_id_ = -record.global_id_not_;
+ result.global_id_ = ~record.global_id_not_;
result.cache_id_ = record.cache_id_;
result.lxu_record_ = record.lxu_record_;
return result;
diff --git a/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer_impl.h b/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer_impl.h
index 5cf36a3b7..6e28fe170 100644
--- a/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer_impl.h
+++ b/contrib/dynamic_embedding/src/tde/details/cacheline_id_transformer_impl.h
@@ -1,3 +1,11 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
#pragma once
#include
#include
@@ -64,38 +72,40 @@ inline bool CachelineIDTransformer<
int64_t global_id_not = ~global_id;
CacheValue* group_begin = &cache_values_[group_id * group_size_];
int64_t k = 0;
+ int64_t empty_slot = -1;
+ int64_t cache_id = -1;
for (; k < group_size_; k++, intra_id++) {
intra_id %= group_size_;
auto& cache_value = group_begin[intra_id];
// tricky but fast :p
int64_t xor_value = cache_value.global_id_not_ ^ global_id_not;
- if (xor_value > 0) {
- continue;
- }
- int64_t cache_id;
if (xor_value == 0) { // found
cache_id = cache_value.cache_id_;
cache_value.lxu_record_ =
update(cache_value.lxu_record_, global_id, cache_id);
- } else { // empty slot
+ break;
+ } else if (xor_value < 0 && empty_slot < 0) { // empty slot
+ empty_slot = intra_id;
+ }
+ }
+ if (cache_id < 0) {
+ if (empty_slot >= 0) {
// The transformer is full.
if (C10_UNLIKELY(bitmap_.Full())) {
return false;
}
+ auto& cache_value = group_begin[empty_slot];
cache_id = bitmap_.NextFreeBit();
cache_value.global_id_not_ = global_id_not;
cache_value.cache_id_ = cache_id;
cache_value.lxu_record_ = update(std::nullopt, global_id, cache_id);
fetch(global_id, cache_id);
+ } else {
+ return false;
}
- cache_ids[i] = cache_id;
- break;
- }
-
- if (k == group_size_) {
- return false;
}
+ cache_ids[i] = cache_id;
}
return true;
}
@@ -115,14 +125,14 @@ inline void CachelineIDTransformer<
for (const int64_t global_id : global_ids) {
auto [group_id, intra_id] = FindGroupIndex(global_id);
+ int64_t global_id_not = ~global_id;
for (int64_t k = 0; k < group_size_; k++) {
int64_t offset = group_id * group_size_ + (intra_id + k) % group_size_;
auto& cache_value = cache_values_[offset];
// tricky but fast :p
- int64_t global_id_not = ~global_id;
int64_t xor_value = global_id_not ^ cache_value.global_id_not_;
if (xor_value < 0) { // not exist
- break;
+ continue;
} else if (xor_value == 0) { // found slot
bitmap_.FreeBit(cache_value.cache_id_);
cache_value.global_id_not_ = 0;
diff --git a/contrib/dynamic_embedding/src/tde/ps.cpp b/contrib/dynamic_embedding/src/tde/ps.cpp
index 4e3bdecb9..e9fcb3c25 100644
--- a/contrib/dynamic_embedding/src/tde/ps.cpp
+++ b/contrib/dynamic_embedding/src/tde/ps.cpp
@@ -1,3 +1,11 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
#include "tde/ps.h"
#include "tde/details/io.h"
@@ -17,9 +25,13 @@ c10::intrusive_ptr PS::Fetch(
if (cache_ids_to_fetch_or_evict_.empty()) {
return c10::make_intrusive(time, c10::intrusive_ptr());
}
- fetch_notifications_.emplace_back(time, c10::make_intrusive());
- c10::intrusive_ptr notification =
- fetch_notifications_.back().second;
+ c10::intrusive_ptr notification;
+ {
+ std::unique_lock lock_fetch(fetch_notifications_mutex_);
+ fetch_notifications_.emplace_back(
+ time, c10::make_intrusive());
+ notification = fetch_notifications_.back().second;
+ }
uint32_t num_os_ids = os_ids_.size();
io_.Pull(
table_name_,
@@ -97,9 +109,7 @@ void PS::Evict(torch::Tensor ids_to_evict) {
notification.Done();
// The shared data for all chunks.
std::vector offsets;
- offsets.reserve(num_ids_per_chunk_ * num_os_ids * col_ids.size() + 1);
- std::vector data(
- num_ids_per_chunk_ * num_os_ids * col_ids.size() * col_size_);
+ offsets.resize(num_ids_per_chunk_ * num_os_ids * col_ids.size() + 1);
for (uint32_t i = 0; i < num_ids_to_fetch; i += num_ids_per_chunk_) {
uint32_t num_ids_in_chunk = std::min(
@@ -107,22 +117,22 @@ void PS::Evict(torch::Tensor ids_to_evict) {
uint32_t data_size = num_ids_in_chunk * num_os_ids * col_ids.size();
uint32_t offsets_size = num_ids_in_chunk * num_os_ids * col_ids.size() + 1;
- offsets.clear();
- offsets.emplace_back(0);
+ std::vector all_tensors;
for (uint32_t j = i; j < i + num_ids_in_chunk; ++j) {
int64_t cache_id = cache_ids_to_fetch_or_evict_[j];
std::vector tensors = GetTensorViews(cache_id);
- for (uint32_t k : os_ids_) {
- // this cause 2 copy. is this avoidable?
- torch::Tensor tensor = tensors[k].cpu();
- // need to change this when considering col
- memcpy(
- reinterpret_cast(data.data()) + offsets.back(),
- tensor.data_ptr(),
- tensor.numel() * tensor.element_size());
- offsets.emplace_back(
- offsets.back() + tensor.numel() * tensor.element_size());
- }
+ all_tensors.insert(all_tensors.end(), tensors.begin(), tensors.end());
+ }
+ torch::Tensor data = torch::cat(all_tensors, 0).cpu();
+ TORCH_CHECK(data.numel() == data_size * col_size_);
+
+ // to prevent the original data from being prematurely recycled
+ auto data_shared_ptr = std::make_shared(data);
+
+ offsets[0] = 0;
+ for (uint32_t j = 0; j < all_tensors.size(); ++j) {
+ offsets[j + 1] =
+ offsets[j] + all_tensors[j].numel() * all_tensors[j].element_size();
}
// waiting for the Push of last chunk finishes.
notification.Wait();
@@ -133,21 +143,30 @@ void PS::Evict(torch::Tensor ids_to_evict) {
col_ids,
os_ids_,
tcb::span{
- reinterpret_cast(data.data()), data_size * sizeof(float)},
+ reinterpret_cast(data_shared_ptr->data_ptr()),
+ data_size * sizeof(float)},
tcb::span{offsets.data(), offsets_size},
- [¬ification] { notification.Done(); });
+ [¬ification, data_shared_ptr] { notification.Done(); });
}
notification.Wait();
}
void PS::SyncFetch(int64_t time) {
- while (!fetch_notifications_.empty()) {
- auto& [t, notification] = fetch_notifications_.front();
- if (t != time && time >= 0) {
+ std::unique_lock lock(
+ fetch_notifications_mutex_, std::defer_lock);
+
+ while (true) {
+ lock.lock();
+ if (fetch_notifications_.empty() ||
+ fetch_notifications_.front().first != time && time >= 0) {
+ lock.unlock();
break;
}
- notification->Wait();
+ auto notification = fetch_notifications_.front().second;
fetch_notifications_.pop_front();
+ lock.unlock();
+
+ notification->Wait();
}
}
diff --git a/contrib/dynamic_embedding/src/tde/ps.h b/contrib/dynamic_embedding/src/tde/ps.h
index aed24fdc6..cad040f4c 100644
--- a/contrib/dynamic_embedding/src/tde/ps.h
+++ b/contrib/dynamic_embedding/src/tde/ps.h
@@ -1,3 +1,11 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
#pragma once
#include
#include
@@ -99,6 +107,7 @@ class PS : public torch::CustomClassHolder {
void Filter(const torch::Tensor& tensor);
std::mutex mu_;
+ std::mutex fetch_notifications_mutex_;
std::string table_name_;
c10::intrusive_ptr shards_;
int64_t col_size_;
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/dataloader.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/dataloader.py
index 1a7645fd3..96911a449 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/dataloader.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/dataloader.py
@@ -1,3 +1,10 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
import queue
import threading
from typing import Dict, List, Union
@@ -34,6 +41,9 @@ def transform_loop(dataloader, transform_fn, out_queue, done_event):
# save memory
del transformed_data
+ if not done_event.is_set():
+ done_event.set()
+
class DataLoaderIter:
def __init__(self, dataloader, transform_fn, num_prefetch=0):
@@ -55,6 +65,8 @@ def __del__(self):
self._done_event.set()
def _get_data(self):
+ if self._done_event.is_set():
+ raise StopIteration
if not self._transform_thread.is_alive():
raise RuntimeError("Transform thread exited unexpectedly")
data, handles = self._data_queue.get()
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/__init__.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/__init__.py
index cdf9a93a9..594213334 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/__init__.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/__init__.py
@@ -1 +1,15 @@
-from .comm import default_group, gather_kjts
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#!/usr/bin/env python3
+
+from .comm import (
+ broadcast_ids_to_evict,
+ broadcast_transform_result,
+ gather_global_ids,
+ scatter_cache_ids,
+)
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/comm.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/comm.py
index b6da3be0a..f15f3ce18 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/comm.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/distributed/comm.py
@@ -1,127 +1,98 @@
-from dataclasses import dataclass
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
from typing import List, Optional
import torch
import torch.distributed as dist
-from torchrec import KeyedJaggedTensor
-
__all__ = []
-_group = [None]
-
-
-@dataclass
-class GatherResult:
- values_list: List[List[torch.Tensor]]
- offset_per_key_list: List[List[torch.Tensor]]
-
-
-def default_group(group):
- if group is not None:
- return group
-
- if _group[0] is None:
- _group[0] = dist.new_group(backend="gloo")
-
- return _group[0]
-
-def gather_tensor_list(
- tensors: List[torch.Tensor], numel_lists: List[List[int]], dst=0, group=None
-) -> Optional[List[List[torch.Tensor]]]:
- rank = dist.get_rank()
+def gather_global_ids(global_ids: List[torch.Tensor], group):
world_size = dist.get_world_size()
- if dst == rank:
- if len(numel_lists) != world_size:
- raise ValueError("dst rank should know size of tensors on all ranks")
- else:
- if len(numel_lists) != 1:
- raise ValueError("non dst rank should pass its own tensor sizes")
-
- group = default_group(group)
- dtype = tensors[0].dtype
- device = tensors[0].device
+ rank = dist.get_rank()
- concated_tensor = torch.cat(tensors)
+ concat_global_ids = torch.cat(global_ids)
- if rank == dst:
- # gather can only accept same-size tensors.
- max_numel = max(sum(numel_list) for numel_list in numel_lists)
- concat_results = [
- torch.empty(max_numel, dtype=dtype, device=device)
- for _ in range(world_size)
- ]
- dist.gather(
- concated_tensor,
- gather_list=concat_results,
- dst=dst,
- group=group,
- async_op=False,
- )
+ concat_numel = torch.tensor(concat_global_ids.numel(), dtype=torch.int64)
+ concat_numel_list = [torch.tensor(0, dtype=torch.int64) for _ in range(world_size)]
+ dist.all_gather(concat_numel_list, concat_numel, group=group, async_op=False)
- results = []
- for i in range(world_size):
- splited_tensors = []
- offset = 0
- for numel in numel_lists[i]:
- splited_tensors.append(concat_results[i][offset : offset + numel])
- offset += numel
- results.append(splited_tensors)
+ max_numel = max(concat_numel_list)
+ concat_global_ids.resize_(max_numel)
- return results
+ if rank == 0:
+ concat_global_ids_list = [
+ torch.empty_like(concat_global_ids) for _ in range(world_size)
+ ]
+ dist.gather(concat_global_ids, concat_global_ids_list, 0, group, async_op=False)
+ return [
+ concat_global_ids_list[i][: concat_numel_list[i]] for i in range(world_size)
+ ], concat_numel_list
else:
- dist.gather(
- concated_tensor, gather_list=None, dst=dst, group=group, async_op=False
- )
-
- return None
+ dist.gather(concat_global_ids, None, 0, group, async_op=False)
+ return None, concat_numel_list
-def gather_kjts(kjts: List[KeyedJaggedTensor], dst=0, group=None) -> GatherResult:
+def scatter_cache_ids(
+ cache_ids_list: Optional[List[torch.Tensor]], concat_numel_list: List[int], group
+):
world_size = dist.get_world_size()
- if world_size == 1:
- return GatherResult(
- values_list=[[kjt.values()] for kjt in kjts],
- offset_per_key_list=[[kjt.offset_per_key()] for kjt in kjts],
- )
-
rank = dist.get_rank()
- group = default_group(group)
- offset_per_key_list = [torch.tensor(kjt.offset_per_key()) for kjt in kjts]
- values_list = [kjt.values() for kjt in kjts]
+ max_numel = max(concat_numel_list)
- offset_numel_list = [tensor.numel() for tensor in offset_per_key_list]
- values_numel_list = [tensor.numel() for tensor in values_list]
-
- if rank == dst:
- global_offset_numel_list = [offset_numel_list] * world_size
- offset_results = gather_tensor_list(
- offset_per_key_list,
- numel_lists=global_offset_numel_list,
- dst=dst,
- group=group,
- )
-
- global_values_numel_list = [
- [offset[-1].item() for offset in offsets] for offsets in offset_results
+ concat_cache_ids = torch.empty(max_numel, dtype=torch.int64)
+ if rank == 0:
+ concat_cache_ids_list = [concat_cache_ids] + [
+ cache_ids.resize_(max_numel)
+ for cache_ids in cache_ids_list[-world_size + 1 :]
]
-
- values_result = gather_tensor_list(
- values_list, numel_lists=global_values_numel_list, dst=dst, group=group
- )
-
- return GatherResult(
- values_list=values_result, offset_per_key_list=offset_results
- )
+ assert len(concat_cache_ids_list) == world_size
+ dist.scatter(concat_cache_ids, concat_cache_ids_list, group=group)
else:
- gather_tensor_list(
- offset_per_key_list, numel_lists=[offset_numel_list], dst=dst, group=group
+ dist.scatter(concat_cache_ids, None, group=group)
+ offset = 0
+ for cache_ids in cache_ids_list:
+ cache_ids[:] = concat_cache_ids[offset : offset + cache_ids.numel()]
+ offset += cache_ids.numel()
+
+
+def broadcast_transform_result(
+ success: bool, ids_to_fetch: Optional[torch.Tensor], group
+):
+ if dist.get_rank() == 0:
+ success_and_numel = torch.tensor(
+ [1 if success else 0, ids_to_fetch.numel()], dtype=torch.int64
)
+ dist.broadcast(success_and_numel, src=0, group=group)
+ else:
+ success_and_numel = torch.tensor([0, 0], dtype=torch.int64)
+ dist.broadcast(success_and_numel, src=0, group=group)
+ success, numel = success_and_numel.tolist()
+ success = success != 0
+ ids_to_fetch = torch.empty((numel // 2, 2), dtype=torch.int64)
+
+ if ids_to_fetch.numel() > 0:
+ dist.broadcast(ids_to_fetch, src=0, group=group)
+ return success, ids_to_fetch
- gather_tensor_list(
- values_list, numel_lists=[values_numel_list], dst=dst, group=group
- )
- return None
+def broadcast_ids_to_evict(ids, group):
+ if dist.get_rank() == 0:
+ numel = torch.tensor(ids.numel(), dtype=torch.int64)
+ dist.broadcast(numel, src=0, group=group)
+ else:
+ numel = torch.tensor(0, dtype=torch.int64)
+ dist.broadcast(numel, src=0, group=group)
+ numel = numel.item()
+ ids = torch.empty((numel // 2, 2), dtype=torch.int64)
+
+ if numel > 0:
+ dist.broadcast(ids, src=0, group=group)
+ return ids
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer.py
index 6d7169183..735412d35 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer.py
@@ -1,3 +1,10 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
import json
import os
@@ -33,9 +40,10 @@ def transform(self, global_ids: TensorList, cache_ids: TensorList, time: int):
"""
Transform `global_ids` and store the results in `cache_ids`.
"""
- return self._transformer.transform(
+ result = self._transformer.transform(
global_ids.tensor_list, cache_ids.tensor_list, time
)
+ return result.success, result.ids_to_fetch
def evict(self, num_to_evict):
"""
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_collection.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_collection.py
index ced6780df..1132d6be6 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_collection.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_collection.py
@@ -1,8 +1,24 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#!/usr/bin/env python3
+
from typing import List, Tuple, Union
import torch
+import torch.distributed as dist
from torchrec import EmbeddingBagConfig, EmbeddingConfig, KeyedJaggedTensor
+from .distributed import (
+ broadcast_ids_to_evict,
+ broadcast_transform_result,
+ gather_global_ids,
+ scatter_cache_ids,
+)
from .id_transformer import IDTransformer, TensorList
from .ps import PSCollection
@@ -27,7 +43,7 @@ def __init__(
tables: list of `Embedding(Bag)Config` or `EmbeddingBagConfig` one passed to
`Embedding(Bag)Collection`.
eviction_config: config of the eviction strategy for IDTransformers.
- transformer_config: config of the transform strategy for IDTransformers.
+ transform_config: config of the transform strategy for IDTransformers.
ps_collection: `PSCollection` of the collection, if `None`, won't do eviction or fetch.
By default, IDTransformerCollection will evict half the ids when full.
"""
@@ -46,19 +62,61 @@ def __init__(
for feature_name in config.feature_names:
if feature_name in feature_names:
raise ValueError(f"Shared feature not allowed yet.")
- self._transformers.append(
- IDTransformer(
+ # only rank 0 will have the id transformer
+ # and other ranks will gather their to rank 0.
+ if dist.get_rank() == 0:
+ transformer = IDTransformer(
num_embedding=config.num_embeddings,
eviction_config=eviction_config,
transform_config=transform_config,
)
- )
+ else:
+ transformer = None
+ self._transformers.append(transformer)
self._feature_names: List[List[str]] = [
config.feature_names for config in tables
]
self._ever_evicted = False
self._time = 0
+ if dist.get_world_size() > 1:
+ self._pg = dist.new_group(backend="gloo")
+ self._stream = torch.cuda.Stream()
+
+ def _transform(
+ self, transformer, global_ids: List[torch.Tensor], cache_ids: List[torch.Tensor]
+ ):
+ with torch.cuda.stream(self._stream):
+ total_numel = sum([tensor.numel() for tensor in global_ids])
+ if total_numel > 1e6:
+ all_tensor = torch.cat(global_ids).to("cuda:0")
+ unique_all_tensor, index = torch.unique(all_tensor, return_inverse=True)
+ unique_all_tensor = unique_all_tensor.to("cpu")
+ all_cache = torch.empty_like(unique_all_tensor)
+ success, ids_to_fetch = transformer.transform(
+ TensorList([unique_all_tensor]),
+ TensorList([all_cache]),
+ self._time,
+ )
+ del all_tensor
+ all_tensor = torch.take(all_cache.to("cuda:0"), index)
+ offset = 0
+ for tensor in cache_ids:
+ numel = tensor.numel()
+ tensor.copy_(all_tensor[offset : offset + numel])
+ offset += numel
+ assert (
+ total_numel == offset
+ ), f"total_numel not equal offset, {total_numel} vs {offset}"
+ else:
+ # broadcast result
+ success, ids_to_fetch = transformer.transform(
+ TensorList(global_ids),
+ TensorList(cache_ids),
+ self._time,
+ )
+ return success, ids_to_fetch
+
def transform(
self, global_features: KeyedJaggedTensor
) -> Tuple[KeyedJaggedTensor, List[torch.classes.tde.FetchHandle]]:
@@ -92,47 +150,128 @@ def transform(
for idx in feature_indices
]
- result = transformer.transform(
- TensorList(global_ids), TensorList(cache_ids), self._time
- )
- if self._ps_collection is not None:
- table_name = self._table_names[i]
- ps = self._ps_collection[table_name]
- if result.ids_to_fetch.numel() > 0:
- handle = ps.fetch(
- result.ids_to_fetch,
- self._time,
- self._ever_evicted,
- self._configs[i].get_weight_init_min(),
- self._configs[i].get_weight_init_max(),
- )
- fetch_handles.append(handle)
- if not result.success:
- # TODO(zilinzhu): make this configurable
- ids_to_evict = transformer.evict(transformer._num_embedding // 2)
- ps.evict(ids_to_evict)
- self._ever_evicted = True
-
- # retry after eviction.
- result = transformer.transform(
- TensorList(global_ids), TensorList(cache_ids), self._time
+ if dist.get_world_size() > 1:
+ concat_global_ids, concat_numel_list = gather_global_ids(
+ global_ids, self._pg
+ )
+ if dist.get_rank() == 0:
+ global_ids = global_ids + concat_global_ids[1:]
+ cache_ids = cache_ids + [
+ torch.empty_like(tensor) for tensor in concat_global_ids[1:]
+ ]
+
+ success, ids_to_fetch = self._transform(
+ transformer, global_ids, cache_ids
)
- if not result.success:
- raise RuntimeError(
- "Failed to transform global ids after eviction. "
- f"Maybe the num_embedding of table {table_name} is too small?"
+ else:
+ success, ids_to_fetch = True, None
+ success, ids_to_fetch = broadcast_transform_result(
+ success, ids_to_fetch, self._pg
+ )
+
+ if self._ps_collection is not None:
+ table_name = self._table_names[i]
+ ps = self._ps_collection[table_name]
+ if ids_to_fetch.numel() > 0:
+ handle = ps.fetch(
+ ids_to_fetch,
+ self._time,
+ self._ever_evicted,
+ self._configs[i].get_weight_init_min(),
+ self._configs[i].get_weight_init_max(),
)
- if result.ids_to_fetch is not None:
- fetch_handles.append(
- ps.fetch(
- result.ids_to_fetch,
+ fetch_handles.append(handle)
+ if not success:
+ # TODO(zilinzhu): make this configurable
+ # broadcast ids_to_evict
+ if dist.get_rank() == 0:
+ ids_to_evict = transformer.evict(
+ transformer._num_embedding // 2
+ )
+ else:
+ ids_to_evict = None
+ ids_to_evict = broadcast_ids_to_evict(ids_to_evict, self._pg)
+
+ ps.evict(ids_to_evict)
+ self._ever_evicted = True
+
+ # retry after eviction.
+ # broadcast result
+ if dist.get_rank() == 0:
+ success, ids_to_fetch = transformer.transform(
+ TensorList(global_ids),
+ TensorList(cache_ids),
self._time,
- self._ever_evicted,
- self._configs[i].get_weight_init_min(),
- self._configs[i].get_weight_init_max(),
)
+ else:
+ success, ids_to_fetch = True, None
+ success, ids_to_fetch = broadcast_transform_result(
+ success, ids_to_fetch, self._pg
)
+ if not success:
+ raise RuntimeError(
+ "Failed to transform global ids after eviction. "
+ f"Maybe the num_embedding of table {table_name} is too small?"
+ )
+ if ids_to_fetch.numel() > 0:
+ fetch_handles.append(
+ ps.fetch(
+ ids_to_fetch,
+ self._time,
+ self._ever_evicted,
+ self._configs[i].get_weight_init_min(),
+ self._configs[i].get_weight_init_max(),
+ )
+ )
+
+ scatter_cache_ids(cache_ids, concat_numel_list, self._pg)
+ else:
+ success, ids_to_fetch = self._transform(
+ transformer, global_ids, cache_ids
+ )
+ if self._ps_collection is not None:
+ table_name = self._table_names[i]
+ ps = self._ps_collection[table_name]
+ if ids_to_fetch.numel() > 0:
+ handle = ps.fetch(
+ ids_to_fetch,
+ self._time,
+ self._ever_evicted,
+ self._configs[i].get_weight_init_min(),
+ self._configs[i].get_weight_init_max(),
+ )
+ fetch_handles.append(handle)
+ if not success:
+ # TODO(zilinzhu): make this configurable
+ ids_to_evict = transformer.evict(
+ transformer._num_embedding // 2
+ )
+ ps.evict(ids_to_evict)
+ self._ever_evicted = True
+
+ # retry after eviction.
+ success, ids_to_fetch = transformer.transform(
+ TensorList(global_ids),
+ TensorList(cache_ids),
+ self._time,
+ )
+ if not success:
+ raise RuntimeError(
+ "Failed to transform global ids after eviction. "
+ f"Maybe the num_embedding of table {table_name} is too small?"
+ )
+ if ids_to_fetch is not None:
+ fetch_handles.append(
+ ps.fetch(
+ ids_to_fetch,
+ self._time,
+ self._ever_evicted,
+ self._configs[i].get_weight_init_min(),
+ self._configs[i].get_weight_init_max(),
+ )
+ )
+
cache_values = KeyedJaggedTensor(
keys=global_features.keys(),
values=cache_values,
@@ -147,5 +286,18 @@ def save(self):
return
for i, transformer in enumerate(self._transformers):
table_name = self._table_names[i]
- ids = transformer.save()
+ if dist.get_world_size() > 1:
+ if dist.get_rank() == 0:
+ ids = transformer.save()
+ numel = torch.tensor(ids.numel())
+ dist.broadcast(numel, src=0, group=self._pg)
+ dist.broadcast(ids, src=0, group=self._pg)
+ else:
+ numel = torch.tensor(0)
+ dist.broadcast(numel, src=0, group=self._pg)
+ ids = torch.empty((numel // 2, 2), dtype=torch.int64)
+ dist.broadcast(ids, src=0, group=self._pg)
+ else:
+ ids = transformer.save()
+
self._ps_collection[table_name].evict(ids)
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_group.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_group.py
index c26b2ad4f..9c888a15f 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_group.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/id_transformer_group.py
@@ -57,8 +57,8 @@ def __init__(
configs or embeddingbag configs. The plan of `module` should contain the module path
in `configs_dict`.
eviction_config: configuration for eviction policy. Default is `{"type": "mixed_lru_lfu"}`
- transformer_config: configuration for the transformer. Default is `{"type": "naive"}`
- parallel: Whether the IDTransformerCollections will run paralell. When set to True,
+ transform_config: configuration for the transformer. Default is `{"type": "naive"}`
+ parallel: Whether the IDTransformerCollections will run in parallel. When set to True,
IDTransformerGroup will start a thread for each IDTransformerCollection.
Example:
@@ -166,12 +166,6 @@ def __contains__(self, path):
"""
return path in self._id_transformer_collections
- def __contains__(self, path):
- """
- Check if there is transformer for the path.
- """
- return path in self._id_transformer_collections
-
def __del__(self):
"""
Stop the parallel threads
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/ps.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/ps.py
index 690e89b97..763f3e8ef 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/ps.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/ps.py
@@ -1,10 +1,15 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
import os
from typing import Callable, Dict, List, Tuple, Union
import torch
-import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
-from torchrec.distributed.model_parallel import DistributedModelParallel as DMP
from torchrec.distributed.types import ParameterSharding
from .tensor_list import TensorList
@@ -54,7 +59,6 @@ def __init__(
# This assumes all shard have the same column size.
col_size = shard.tensor.shape[1]
elif isinstance(tensors[0], torch.Tensor):
- tensors
shards.append(
0,
0,
diff --git a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/tensor_list.py b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/tensor_list.py
index fc11cce9a..591eda4c3 100644
--- a/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/tensor_list.py
+++ b/contrib/dynamic_embedding/src/torchrec_dynamic_embedding/tensor_list.py
@@ -1,4 +1,9 @@
-import json
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
import os
from typing import List
diff --git a/contrib/dynamic_embedding/tests/test_integral_precision.py b/contrib/dynamic_embedding/tests/test_integral_precision.py
index 842a2b72d..965bb092c 100644
--- a/contrib/dynamic_embedding/tests/test_integral_precision.py
+++ b/contrib/dynamic_embedding/tests/test_integral_precision.py
@@ -1,3 +1,10 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
import unittest
import torch
@@ -14,6 +21,7 @@
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec_dynamic_embedding.id_transformer_group import IDTransformerGroup
from utils import init_dist, register_memory_io
@@ -93,7 +101,7 @@ def get_dmp(model):
model = DMP(module=model, device=device, plan=plan, sharders=sharders)
dense_optimizer = KeyedOptimizerWrapper(
- dict(model.named_parameters()),
+ dict(in_backward_optimizer_filter(model.named_parameters())),
lambda params: torch.optim.Adam(params, lr=1e-1),
)
optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
diff --git a/contrib/dynamic_embedding/tests/test_ps_collection.py b/contrib/dynamic_embedding/tests/test_ps_collection.py
index a7fffe1e0..d293749e1 100644
--- a/contrib/dynamic_embedding/tests/test_ps_collection.py
+++ b/contrib/dynamic_embedding/tests/test_ps_collection.py
@@ -1,4 +1,9 @@
-import os
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
import unittest
import torch
diff --git a/contrib/dynamic_embedding/tools/before_linux_build.sh b/contrib/dynamic_embedding/tools/before_linux_build.sh
index 7cf1d154f..527ee2128 100755
--- a/contrib/dynamic_embedding/tools/before_linux_build.sh
+++ b/contrib/dynamic_embedding/tools/before_linux_build.sh
@@ -1,13 +1,20 @@
#!/usr/bin/env bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
set -xe
distro=rhel7
arch=x86_64
-CUDA_VERSION="${CUDA_VERSION:-11.6}"
+CUDA_VERSION="${CUDA_VERSION:-11.8}"
CUDA_MAJOR_VERSION=$(echo "${CUDA_VERSION}" | tr '.' ' ' | awk '{print $1}')
CUDA_MINOR_VERSION=$(echo "${CUDA_VERSION}" | tr '.' ' ' | awk '{print $2}')
+yum install -y yum-utils
yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$distro/$arch/cuda-$distro.repo
yum install -y \
cuda-toolkit-"${CUDA_MAJOR_VERSION}"-"${CUDA_MINOR_VERSION}" \
diff --git a/contrib/dynamic_embedding/tools/build_wheels.sh b/contrib/dynamic_embedding/tools/build_wheels.sh
index 51cd42500..3647547de 100755
--- a/contrib/dynamic_embedding/tools/build_wheels.sh
+++ b/contrib/dynamic_embedding/tools/build_wheels.sh
@@ -6,6 +6,8 @@ export CIBW_BEFORE_BUILD="tools/before_linux_build.sh"
# all kinds of CPython.
export CIBW_BUILD=${CIBW_BUILD:-"cp39-manylinux_x86_64"}
+export CIBW_MANYLINUX_X86_64_IMAGE=${CIBW_MANYLINUX_X86_64_IMAGE:-"manylinux_2_28"}
+
# Do not auditwheels since tde uses torch's shared libraries.
export CIBW_REPAIR_WHEEL_COMMAND="tools/repair_wheel.sh {wheel} {dest_dir}"
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 000000000..fb00038dc
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,23 @@
+Docs
+==========
+
+
+## Building the docs
+
+To build and preview the docs run the following commands:
+
+```bash
+cd docs
+pip3 install -r requirements.txt
+make html
+python3 -m http.server 8082 --bind ::
+```
+
+Now you should be able to view the docs in your browser at the link provided in your terminal.
+
+To reload the preview after making changes, rerun:
+
+```bash
+make html
+python3 -m http.server 8082 --bind ::
+```
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 851694777..96d50ad98 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,4 +1,7 @@
-sphinx
+sphinx==5.0.0
+pyre-extensions
+sphinx-design
+sphinx_copybutton
# torch
# PyTorch Theme
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css
new file mode 100644
index 000000000..55fee00b4
--- /dev/null
+++ b/docs/source/_static/css/custom.css
@@ -0,0 +1,114 @@
+/**
+* Copyright (c) Meta Platforms, Inc. and affiliates.
+* All rights reserved.
+*
+* This source code is licensed under the BSD-style license found in the
+* LICENSE file in the root directory of this source tree.
+*/
+
+/* sphinx-design styles for cards/tabs */
+
+:root {
+ --sd-color-info: #ee4c2c;
+ --sd-color-info-highlight: #ee4c2c;
+ --sd-color-primary: #6c6c6d;
+ --sd-color-primary-highlight: #f3f4f7;
+ --sd-color-card-border-hover: #ee4c2c;
+ --sd-color-card-border: #f3f4f7;
+ --sd-color-card-background: #fff;
+ --sd-color-card-text: inherit;
+ --sd-color-card-header: transparent;
+ --sd-color-card-footer: transparent;
+ --sd-color-tabs-label-active: #ee4c2c;
+ --sd-color-tabs-label-hover: #ee4c2c;
+ --sd-color-tabs-label-inactive: #6c6c6d;
+ --sd-color-tabs-underline-active: #ee4c2c;
+ --sd-color-tabs-underline-hover: #fabdbd;
+ --sd-color-tabs-underline-inactive: transparent;
+ --sd-color-tabs-overline: rgb(222, 222, 222);
+ --sd-color-tabs-underline: rgb(222, 222, 222);
+}
+
+.sd-text-info {
+ color: #ee4c2c;
+}
+
+.sd-card-img-top {
+ background: #ee4c2c;
+ height: 5px !important;
+}
+
+.sd-card {
+ position: relative;
+ background-color: #fff;
+ opacity: 1.0;
+ border-radius: 0px;
+ width: 30%;
+ border: none;
+ padding-bottom: 0px;
+}
+
+.sd-card-img {
+ opacity: 0.5;
+ width: 200px;
+ padding: 0px;
+}
+
+.sd-card-img:hover {
+ opacity: 1.0;
+ background-color: #f3f4f7;
+}
+
+
+.sd-card:after {
+ display: block;
+ opacity: 1;
+ content: '';
+ border-bottom: solid 1px #ee4c2c;
+ background-color: #fff;
+ transform: scaleX(0);
+ transition: transform .250s ease-in-out;
+ transform-origin: 0% 50%;
+}
+
+.sd-card:hover {
+ background-color: #fff;
+ opacity: 1;
+ border-top: 1px solid #f3f4f7;
+ border-left: 1px solid #f3f4f7;
+ border-right: 1px solid #f3f4f7;
+}
+
+.sd-card:hover:after {
+ transform: scaleX(1);
+}
+
+.card-prerequisites:hover {
+ transition: none;
+ border: none;
+}
+
+.card-prerequisites:hover:after {
+ transition: none;
+ transform: none;
+}
+
+.card-prerequisites:after {
+ display: block;
+ content: '';
+ border-bottom: none;
+ background-color: #fff;
+ transform: none;
+ transition: none;
+ transform-origin: none;
+}
+
+details.sd-dropdown {
+ font-weight: 300;
+ width: auto;
+}
+
+.center-content {
+ display: flex;
+ justify-content: center;
+}
diff --git a/docs/source/_static/img/card-background.svg b/docs/source/_static/img/card-background.svg
new file mode 100644
index 000000000..d97193223
--- /dev/null
+++ b/docs/source/_static/img/card-background.svg
@@ -0,0 +1,13 @@
+
+
diff --git a/docs/source/_static/img/full_training_loop.png b/docs/source/_static/img/full_training_loop.png
new file mode 100644
index 000000000..221e7c387
Binary files /dev/null and b/docs/source/_static/img/full_training_loop.png differ
diff --git a/docs/source/_static/img/fused_backward_optimizer.png b/docs/source/_static/img/fused_backward_optimizer.png
new file mode 100644
index 000000000..3c3d3593e
Binary files /dev/null and b/docs/source/_static/img/fused_backward_optimizer.png differ
diff --git a/docs/source/_static/img/fused_embedding_tables.png b/docs/source/_static/img/fused_embedding_tables.png
new file mode 100644
index 000000000..6874e4d75
Binary files /dev/null and b/docs/source/_static/img/fused_embedding_tables.png differ
diff --git a/docs/source/_static/img/model_parallel.png b/docs/source/_static/img/model_parallel.png
new file mode 100644
index 000000000..81d18e9e1
Binary files /dev/null and b/docs/source/_static/img/model_parallel.png differ
diff --git a/docs/source/_static/img/sharding.png b/docs/source/_static/img/sharding.png
new file mode 100644
index 000000000..653eb7cdb
Binary files /dev/null and b/docs/source/_static/img/sharding.png differ
diff --git a/docs/source/_static/img/torchrec_forward.png b/docs/source/_static/img/torchrec_forward.png
new file mode 100644
index 000000000..156f13e8f
Binary files /dev/null and b/docs/source/_static/img/torchrec_forward.png differ
diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html
new file mode 100644
index 000000000..dbc08744a
--- /dev/null
+++ b/docs/source/_templates/layout.html
@@ -0,0 +1,10 @@
+{% extends "!layout.html" %}
+
+{% block footer %}
+{{ super() }}
+
+
+
+{% endblock %}
diff --git a/docs/source/concepts.rst b/docs/source/concepts.rst
new file mode 100644
index 000000000..61806512e
--- /dev/null
+++ b/docs/source/concepts.rst
@@ -0,0 +1,311 @@
+.. meta::
+ :description: TorchRec Concepts
+ :keywords: recommendation systems, sharding, distributed training, torchrec, embedding bags, embeddings, keyedjaggedtensor, row wise, table wise, column wise, table row wise, planner, sharder
+
+###################
+TorchRec Concepts
+###################
+
+In this section, we will learn about the key concepts of TorchRec,
+designed to optimize large-scale recommendation systems using PyTorch.
+We will learn how each concept works in detail and review how it is used
+with the rest of TorchRec.
+
+TorchRec has specific input/output data types of its modules to
+efficiently represent sparse features, including:
+
+- **JaggedTensor:** a wrapper around the lengths/offsets and values
+ tensors for a singular sparse feature.
+- **KeyedJaggedTensor:** efficiently represents multiple sparse
+  features; you can think of it as multiple ``JaggedTensor``\s.
+- **KeyedTensor:** a wrapper around ``torch.Tensor`` that allows access
+ to tensor values through keys.
+
+The canonical ``torch.Tensor`` is highly inefficient for representing
+sparse data, so TorchRec introduces these new data types to provide
+efficient storage and representation of sparse input data. As you will
+see later on, the ``KeyedJaggedTensor`` makes communication of input
+data in a distributed environment very efficient, leading to one of the
+key performance advantages that TorchRec provides.
+
+In the end-to-end training loop, TorchRec comprises the following
+main components:
+
+- **Planner:** Takes in the configuration of embedding tables and the
+ environment setup, and generates an optimized sharding plan for the
+ model.
+
+- **Sharder:** Shards model according to sharding plan with different
+ sharding strategies including data-parallel, table-wise, row-wise,
+ table-wise-row-wise, column-wise, and table-wise-column-wise
+ sharding.
+
+- **DistributedModelParallel:** Combines the sharder and optimizer, and
+  provides an entry point into training the model in a distributed
+  manner.
+
+**************
+JaggedTensor
+**************
+
+A ``JaggedTensor`` represents a sparse feature through lengths, values,
+and offsets. It is called "jagged" because it efficiently represents
+data with variable-length sequences. In contrast, a canonical
+``torch.Tensor`` assumes that each sequence has the same length, which
+is often not the case with real-world data. A ``JaggedTensor``
+facilitates the representation of such data without padding, making it
+highly efficient.
+
+Key Components:
+
+- ``Lengths``: A list of integers representing the number of elements
+ for each entity.
+
+- ``Offsets``: A list of integers representing the starting index of
+ each sequence in the flattened values tensor. These provide an
+ alternative to lengths.
+
+- ``Values``: A 1D tensor containing the actual values for each entity,
+ stored contiguously.
+
+Here is a simple example demonstrating what each of the components
+looks like:
+
+.. code:: python
+
+ # User interactions:
+ # - User 1 interacted with 2 items
+ # - User 2 interacted with 3 items
+ # - User 3 interacted with 1 item
+ lengths = [2, 3, 1]
+ offsets = [0, 2, 5] # Starting index of each user's interactions
+ values = torch.Tensor([101, 102, 201, 202, 203, 301]) # Item IDs interacted with
+ jt = JaggedTensor(lengths=lengths, values=values)
+ # OR
+ jt = JaggedTensor(offsets=offsets, values=values)
+
+*******************
+KeyedJaggedTensor
+*******************
+
+A ``KeyedJaggedTensor`` extends the functionality of ``JaggedTensor`` by
+introducing keys (which are typically feature names) to label different
+groups of features, for example, user features and item features. This
+is the data type used in ``forward`` of ``EmbeddingBagCollection`` and
+``EmbeddingCollection`` as they are used to represent multiple features
+in a table.
+
+A ``KeyedJaggedTensor`` has an implied batch size, which is the length
+of the ``lengths`` tensor divided by the number of keys. The example
+below has a batch size of 2. The ``offsets`` and ``lengths`` behave in
+the same manner as in a ``JaggedTensor``. You can also
+access the ``lengths``, ``offsets``, and ``values`` of a feature by
+accessing the key from the ``KeyedJaggedTensor``.
+
+.. code:: python
+
+ keys = ["user_features", "item_features"]
+ # Lengths of interactions:
+ # - User features: 2 users, with 2 and 3 interactions respectively
+ # - Item features: 2 items, with 1 and 2 interactions respectively
+ lengths = [2, 3, 1, 2]
+ values = torch.Tensor([11, 12, 21, 22, 23, 101, 201, 202])
+ # Create a KeyedJaggedTensor
+ kjt = KeyedJaggedTensor(keys=keys, lengths=lengths, values=values)
+ # Access the features by key
+ print(kjt["user_features"])
+ # Outputs user features
+ print(kjt["item_features"])
+
+*********
+Planner
+*********
+
+The TorchRec planner helps determine the best sharding configuration for
+a model. It evaluates multiple possibilities for sharding embedding
+tables and optimizes for performance. The planner performs the
+following:
+
+- Assesses the memory constraints of the hardware.
+- Estimates compute requirements based on memory fetches, such as
+ embedding lookups.
+- Addresses data-specific factors.
+- Considers other hardware specifics, such as bandwidth, to generate an
+ optimal sharding plan.
+
+To ensure accurate consideration of these factors, the Planner can
+incorporate data about the embedding tables, constraints, hardware
+information, and topology to help in generating an optimal plan.
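+
+As an illustrative sketch (not a complete recipe), a planner can be built
+from a ``Topology`` describing the hardware and then asked for a plan; the
+world size, device type, and the ``model``/``sharders`` variables below are
+assumptions for the example:
+
+.. code:: python
+
+    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
+
+    # Describe the hardware the model will be sharded across
+    # (the values here are illustrative).
+    topology = Topology(world_size=2, compute_device="cuda")
+
+    planner = EmbeddingShardingPlanner(topology=topology)
+
+    # `model` holds the unsharded embedding modules and `sharders` are the
+    # sharders used to shard them (both assumed to be defined elsewhere).
+    plan = planner.plan(model, sharders)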
+
+*****************************
+Sharding of EmbeddingTables
+*****************************
+
+The TorchRec sharder provides multiple sharding strategies for various
+use cases. We outline some of these strategies, how they work, and
+their benefits and limitations. Generally, we recommend using the
+TorchRec planner to generate a sharding plan for you, as it will find
+the optimal sharding strategy for each embedding table in your model.
+
+Each sharding strategy determines how to split the table: whether and
+how the table should be cut up, whether to keep one or a few copies of
+some tables, and so on. Each piece of the table resulting from
+sharding, whether it is a whole embedding table or part of it, is
+referred to as a shard.
+
+.. figure:: _static/img/sharding.png
+ :alt: Visualizing the difference of sharding types offered in TorchRec
+ :align: center
+
+ *Figure 1: Visualizing the placement of table shards under different sharding schemes offered in TorchRec*
+
+Here is the list of all sharding types available in TorchRec:
+
+- Table-wise (TW): as the name suggests, the embedding table is kept as a
+ whole piece and placed on one rank.
+
+- Column-wise (CW): the table is split along the ``emb_dim`` dimension,
+ for example, ``emb_dim=256`` is split into 4 shards: ``[64, 64, 64,
+ 64]``.
+
+- Row-wise (RW): the table is split along the ``hash_size`` dimension,
+ usually split evenly among all the ranks.
+
+- Table-wise-row-wise (TWRW): the table is placed on one host and split
+ row-wise among the ranks on that host.
+
+- Grid-shard (GS): a table is CW sharded and each CW shard is placed
+ TWRW on a host.
+
+- Data parallel (DP): each rank keeps a copy of the table.
+
+Once sharded, the modules are converted to sharded versions of
+themselves, known as ``ShardedEmbeddingCollection`` and
+``ShardedEmbeddingBagCollection`` in TorchRec. These modules handle the
+communication of input data, embedding lookups, and gradients.
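+
+As a hedged sketch, per-table ``ParameterConstraints`` can be used to steer
+the planner toward a specific sharding type; the table name
+``"product_table"`` below is a made-up example:
+
+.. code:: python
+
+    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
+    from torchrec.distributed.planner.types import ParameterConstraints
+    from torchrec.distributed.types import ShardingType
+
+    # Constrain a hypothetical table named "product_table" to row-wise sharding.
+    constraints = {
+        "product_table": ParameterConstraints(
+            sharding_types=[ShardingType.ROW_WISE.value],
+        ),
+    }
+
+    planner = EmbeddingShardingPlanner(
+        topology=Topology(world_size=2, compute_device="cuda"),
+        constraints=constraints,
+    )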
+
+****************************************************
+Distributed Training with TorchRec Sharded Modules
+****************************************************
+
+With many sharding strategies available, how do we determine which one
+to use? There is a cost associated with each sharding scheme, which in
+conjunction with model size and number of GPUs determines which sharding
+strategy is best for a model.
+
+Without sharding, where each GPU keeps a copy of the embedding table
+(DP), the main cost is computation in which each GPU looks up the
+embedding vectors in its memory in the forward pass and updates the
+gradients in the backward pass.
+
+With sharding, there is an added communication cost: each GPU needs to
+ask the other GPUs for embedding vector lookup and communicate the
+gradients computed as well. This is typically referred to as ``all2all``
+communication. In TorchRec, for input data on a given GPU, we determine
+where the embedding shard for each part of the data is located and send
+it to the target GPU. That target GPU then returns the embedding vectors
+back to the original GPU. In the backward pass, the gradients are sent
+back to the target GPU and the shards are updated accordingly with the
+optimizer.
+
+As described above, sharding requires us to communicate the input data
+and embedding lookups. TorchRec handles this in three main stages; we
+will refer to this as the sharded embedding module forward pass, which is
+used in training and inference of a TorchRec model:
+
+- Feature All to All/Input distribution (``input_dist``)
+
+ - Communicate input data (in the form of a ``KeyedJaggedTensor``) to
+ the appropriate device containing relevant embedding table shard
+
+- Embedding Lookup
+
+ - Lookup embeddings with new input data formed after feature all to
+ all exchange
+
+- Embedding All to All/Output Distribution (``output_dist``)
+
+ - Communicate embedding lookup data back to the appropriate device
+ that asked for it (in accordance with the input data the device
+ received)
+
+- The backward pass does the same operations but in reverse order.
+
+The diagram below demonstrates how it works:
+
+.. figure:: _static/img/torchrec_forward.png
+ :alt: Visualizing the forward pass including the input_dist, lookup, and output_dist of a sharded TorchRec module
+ :align: center
+
+ *Figure 2: Forward pass of a table wise sharded table including the input_dist, lookup, and output_dist of a sharded TorchRec module*
+
+**************************
+DistributedModelParallel
+**************************
+
+All of the above culminates in the main entry point that TorchRec uses
+to shard modules and apply the sharding plan. At a high level,
+``DistributedModelParallel`` does the following (a usage sketch follows
+the list):
+
+- Initializes the environment by setting up process groups and
+ assigning device type.
+
+- Uses default sharders if none are provided; the defaults include
+  ``EmbeddingBagCollectionSharder``.
+
+- Takes in the provided sharding plan; if none is provided, it
+  generates one.
+
+- Creates sharded versions of the modules and replaces the original
+  modules with them, for example, converting ``EmbeddingCollection`` to
+  ``ShardedEmbeddingCollection``.
+
+- By default, wraps the module with ``DistributedDataParallel`` to make
+  it both model and data parallel.
+
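+A minimal usage sketch (assuming the process group is already
+initialized and ``model`` contains TorchRec embedding modules):
+
+.. code-block:: python
+
+   import torch
+   from torchrec.distributed.model_parallel import DistributedModelParallel
+
+   # Sketch: shard ``model`` across all ranks in the default process group.
+   # With no plan or sharders passed, defaults are used as described above.
+   sharded_model = DistributedModelParallel(
+       module=model,
+       device=torch.device("cuda"),
+   )
+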
+***********
+Optimizer
+***********
+
+TorchRec modules provide a seamless API to fuse the backward pass and
+optimizer step in training. This significantly improves performance and
+reduces memory usage, while also allowing distinct optimizers to be
+assigned to distinct model parameters.
+
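+A sketch of registering a fused (in-backward) optimizer on the sparse
+parameters, following the pattern used in the example scripts in this
+repository (``sparse_arch`` is an attribute of the DLRM example model
+and is assumed here):
+
+.. code-block:: python
+
+   from torch.distributed.optim import (
+       _apply_optimizer_in_backward as apply_optimizer_in_backward,
+   )
+   from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
+
+   # Sketch: fuse the optimizer update for the embedding (sparse)
+   # parameters into the backward pass; dense parameters keep a regular
+   # optimizer applied separately.
+   apply_optimizer_in_backward(
+       RowWiseAdagrad,
+       model.sparse_arch.parameters(),  # assumed attribute from the DLRM example
+       {"lr": 0.1},
+   )
+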
+.. figure:: _static/img/fused_backward_optimizer.png
+ :alt: Visualizing fusing of optimizer in backward to update sparse embedding table
+ :align: center
+
+ *Figure 3: Fusing embedding backward with sparse optimizer*
+
+***********
+Inference
+***********
+
+Inference environments are different from training: they are very
+sensitive to performance and to the size of the model. There are two
+key differences that TorchRec inference optimizes for:
+
+- **Quantization:** inference models are quantized for lower latency
+ and reduced model size. This optimization lets us use as few devices
+ as possible for inference to minimize latency.
+
+- **C++ environment:** to minimize latency even further, the model is
+  run in a C++ environment.
+
+TorchRec provides the following to make a TorchRec model
+inference-ready:
+
+- APIs for quantizing the model, including automatic optimizations
+  with FBGEMM TBE
+- Sharding embeddings for distributed inference
+- Compiling the model to TorchScript (compatible with C++)
+
+*********
+See Also
+*********
+
+- `TorchRec Interactive Notebook using the concepts
+ `_
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 4011be287..0f64cefe7 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -30,19 +30,34 @@
# -- Project information -----------------------------------------------------
project = "TorchRec"
-copyright = "2022, Meta"
+copyright = "2024, Meta"
author = "Meta"
+try:
+ # pyre-ignore
+ version = "1.0.0" # TODO: Hardcode stable version for now
+except Exception:
+ # when run internally, we don't have a version yet
+ version = "0.0.0"
# The full version, including alpha/beta/rc tags
-release = "0.0.1"
-
+# First 3 as format is 0.x.x.*
+release = ".".join(version.split(".")[:3])
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
-extensions = ["sphinx.ext.napoleon", "sphinx.ext.autodoc"]
+extensions = [
+ "sphinx.ext.napoleon",
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.doctest",
+ "sphinx.ext.intersphinx",
+ "sphinx_copybutton",
+ "sphinx.ext.mathjax",
+ "sphinx_design",
+]
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
@@ -62,7 +77,21 @@
html_theme = "pytorch_sphinx_theme"
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
+# Theme options are theme-specific and customize the look and feel of a theme
+# further. For a list of options available for each theme, see the
+# documentation.
+#
+html_theme_options = {
+ "pytorch_project": "torchrec",
+ "display_version": True,
+ "logo_only": True,
+ "collapse_navigation": False,
+ "includehidden": True,
+}
+
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
+
+html_css_files = ["css/custom.css"]
diff --git a/docs/source/datatypes-api-reference.rst b/docs/source/datatypes-api-reference.rst
new file mode 100644
index 000000000..8454a15f2
--- /dev/null
+++ b/docs/source/datatypes-api-reference.rst
@@ -0,0 +1,22 @@
+Data Types
+-------------------
+
+
+TorchRec contains data types for representing embedding inputs, otherwise known as sparse features.
+Sparse features are typically indices that are meant to be fed into embedding tables. For a given
+batch, the number of embedding lookup indices is variable. Therefore, there is a need for a **jagged**
+dimension to represent the variable number of embedding lookup indices per batch.
+
+This section covers the classes for the 3 TorchRec data types for representing sparse features:
+**JaggedTensor**, **KeyedJaggedTensor**, and **KeyedTensor**.
+
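+A minimal example of constructing these types (``JaggedTensor`` is used
+the same way in the setup example later in these docs; the direct
+``KeyedJaggedTensor`` constructor shown here is one of several ways to
+build one):
+
+.. code-block:: python
+
+   import torch
+   from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
+
+   # A jagged batch of 3 examples with 2, 0, and 1 lookup indices each.
+   jt = JaggedTensor(
+       values=torch.tensor([10, 11, 12]), lengths=torch.tensor([2, 0, 1])
+   )
+
+   # A KeyedJaggedTensor bundles one jagged batch per feature name.
+   # Here: batch size 2, feature "product" has 3 and 1 indices per example,
+   # feature "user" has 1 and 0 indices per example.
+   kjt = KeyedJaggedTensor(
+       keys=["product", "user"],
+       values=torch.tensor([1, 2, 1, 5, 7]),
+       lengths=torch.tensor([3, 1, 1, 0]),
+   )
+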
+.. automodule:: torchrec.sparse.jagged_tensor
+
+.. autoclass:: JaggedTensor
+ :members:
+
+.. autoclass:: KeyedJaggedTensor
+ :members:
+
+.. autoclass:: KeyedTensor
+ :members:
diff --git a/docs/source/high-level-arch.rst b/docs/source/high-level-arch.rst
new file mode 100644
index 000000000..2c2263b6c
--- /dev/null
+++ b/docs/source/high-level-arch.rst
@@ -0,0 +1,129 @@
+.. meta::
+ :description: TorchRec High Level Architecture
+ :keywords: recommendation systems, sharding, distributed training, torchrec, architecture
+
+##################################
+ TorchRec High Level Architecture
+##################################
+
+In this section, you will learn about the high-level architecture of
+TorchRec, designed to optimize large-scale recommendation systems using
+PyTorch. You will learn how TorchRec employs model parallelism to
+distribute complex models across multiple GPUs, enhancing memory
+management and GPU utilization, as well as get introduced to TorchRec's
+base components and sharding strategies.
+
+In effect, TorchRec provides parallelism primitives for hybrid data
+parallelism/model parallelism, embedding table sharding, a planner to
+generate sharding plans, pipelined training, and more.
+
+****************************************************
+ TorchRec's Parallelism Strategy: Model Parallelism
+****************************************************
+
+As modern deep learning models have scaled, distributed deep learning
+has become necessary to train models in a reasonable amount of time. In
+this paradigm, two main approaches have been developed: data parallelism
+and model parallelism. TorchRec focuses on the latter for the sharding
+of embedding tables.
+
+.. figure:: _static/img/model_parallel.png
+ :alt: Visualizing the difference of sharding a model in model parallel or data parallel approach
+ :align: center
+
+ *Figure 1. Comparison between model parallelism and data parallelism approach*
+
+As you can see in the diagram above, model parallelism and data
+parallelism are two approaches to distributing workloads across
+multiple GPUs:
+
+- **Model Parallelism**
+
+ - Divide the model into segments and distribute them across GPUs
+ - Each segment processes data independently
+ - Suitable for large models that don't fit on a single GPU
+
+- **Data Parallelism**
+
+   - Distribute a copy of the entire model to each GPU
+   - Each GPU processes a subset of the data and contributes to the
+     overall computation
+   - Effective for models that fit on a single GPU but need to handle
+     large datasets
+
+- **Benefits of Model Parallelism**
+
+ - Optimizes memory usage and computational efficiency for large
+ models
+ - Particularly beneficial for recommendation systems with large
+ embedding tables
+ - Enables parallel computation of embeddings in DLRM-type
+ architectures
+
+******************
+ Embedding Tables
+******************
+
+For TorchRec to figure out what to recommend, we need to be able to
+represent entities and their relationships; this is what embeddings are
+used for. Embeddings are vectors of real numbers in a high-dimensional
+space that represent meaning in complex data like words, images, or
+users. An embedding table is an aggregation of multiple embeddings into
+one matrix. Most commonly, embedding tables are represented as a 2D
+matrix with dimensions (B, N), where:
+
+- *B* is the number of embeddings stored in the table
+- *N* is the number of dimensions per embedding.
+
+Each of the *B* rows corresponds to an ID (representing information
+such as a movie title, user, ad, and so on); when we look up an ID, we
+are returned the corresponding embedding vector, which has the
+embedding dimension *N*.
+
+There is also the choice of pooling embeddings. Often, we look up
+multiple rows for a given feature, which raises the question of what to
+do with the multiple embedding vectors we get back. Pooling is a common
+technique where we combine them, usually by summing or averaging the
+rows, to produce one embedding vector. This is the main difference
+between PyTorch's ``nn.Embedding`` and ``nn.EmbeddingBag``.
+
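+A small example illustrating the difference (these are standard PyTorch
+modules, not TorchRec-specific):
+
+.. code-block:: python
+
+   import torch
+   import torch.nn as nn
+
+   # An embedding table with B=10 rows (IDs) and N=4 dimensions per embedding.
+   emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
+   bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")
+
+   ids = torch.tensor([[1, 3, 5]])  # three IDs looked up for one sample
+
+   print(emb(ids).shape)  # torch.Size([1, 3, 4]) - one vector per ID
+   print(bag(ids).shape)  # torch.Size([1, 4])    - vectors pooled into one
+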
+PyTorch represents embeddings through ``nn.Embedding`` and
+``nn.EmbeddingBag``. Building on these modules, TorchRec introduces
+``EmbeddingCollection`` and ``EmbeddingBagCollection``, which are
+collections of the corresponding PyTorch modules. This extension enables
+TorchRec to batch tables and perform lookups on multiple embeddings in a
+single kernel call, improving efficiency.
+
+Here is the end-to-end flow diagram that describes how embeddings are
+used in the training process for recommendation models:
+
+.. figure:: _static/img/full_training_loop.png
+ :alt: Demonstrating the full training loop from embedding lookup to optimizer update in backward
+ :align: center
+
+ *Figure 2. TorchRec End-to-end Embedding Flow*
+
+In the diagram above, we show the general TorchRec end-to-end embedding
+lookup process:
+
+- In the forward pass we do the embedding lookup and pooling
+- In the backward pass we compute the gradients of the output lookups
+ and pass them into the optimizer to update the embedding tables
+
+**Note that the embedding gradients are grayed out since we do not
+fully materialize them in memory and instead fuse them with the
+optimizer update. This results in a significant memory reduction, which
+we detail later in the optimizer concepts section.**
+
+We recommend going through the TorchRec Concepts page to get an
+understanding of the fundamentals of how everything ties together
+end-to-end. It contains lots of useful information for getting the most
+out of TorchRec.
+
+**********
+ See also
+**********
+
+- `What is Distributed Data Parallel (DDP) Tutorial
+ `_
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 09f3a6f2c..c6fa49282 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -3,51 +3,109 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
+.. meta::
+ :description: TorchRec documentation homepage
+ :keywords: recommendation systems, sharding, distributed training
+
Welcome to the TorchRec documentation!
======================================
-TorchRec is a PyTorch domain library built to provide common
-sparsity & parallelism primitives needed for large-scale recommender
-systems (RecSys). It allows authors to train models with large
-embedding tables sharded across many GPUs.
+TorchRec is a specialized library within the PyTorch ecosystem,
+tailored for building, scaling, and deploying large-scale
+**recommendation systems**, a niche not directly addressed by standard
+PyTorch. TorchRec offers advanced features such as complex sharding
+techniques for massive embedding tables, and enhanced distributed
+training capabilities.
+
+Getting Started
+---------------
+
+Topics in this section will help you get started with TorchRec.
+
+.. grid:: 3
+
+ .. grid-item-card:: :octicon:`file-code;1em`
+ TorchRec Overview
+ :img-top: _static/img/card-background.svg
+ :link: overview.html
+ :link-type: url
+
+ A short intro to TorchRec and why you need it.
+
+ .. grid-item-card:: :octicon:`file-code;1em`
+ Set up TorchRec
+ :img-top: _static/img/card-background.svg
+ :link: setup-torchrec.html
+ :link-type: url
+
+ Learn how to install and start using TorchRec
+ in your environment.
+
+ .. grid-item-card:: :octicon:`file-code;1em`
+ Getting Started with TorchRec Tutorial
+ :img-top: _static/img/card-background.svg
+ :link: https://colab.research.google.com/github/pytorch/torchrec/blob/main/Torchrec_Introduction.ipynb
+ :link-type: url
+
+ Follow our interactive step-by-step tutorial
+ to learn how to use TorchRec in a real-life
+ example.
+
+
-For installation instructions, visit
+How to Contribute
+-----------------
-https://github.com/pytorch/torchrec#readme
+We welcome contributions and feedback from the PyTorch community!
+If you are interested in helping improve the TorchRec project, here is
+how you can contribute:
-Tutorial
---------
-In this tutorial, we introduce the primary torchRec
-API called DistributedModelParallel, or DMP.
-Like pytorch’s DistributedDataParallel,
-DMP wraps a model to enable distributed training.
+1. **Visit Our** `GitHub Repository <https://github.com/pytorch/torchrec>`__:
+   There you can find the source code, issues, and ongoing projects.
-* `Tutorial Source `_
-* Open in `Google Colab `_
+2. **Submit Feedback or Issues**: If you encounter any bugs or have
+   suggestions for improvements, please submit an issue through the
+   `GitHub issue tracker `__.
-TorchRec API
-------------
+3. **Propose changes**: Fork the repository and submit pull requests.
+   Whether it's fixing a bug, adding new features, or improving
+   documentation, your contributions are always welcome! Please make sure to
+   review our `CONTRIBUTING.md `__
+
+|
+|
+
+.. container:: center-content
+
+ .. button-link:: https://github.com/pytorch/torchrec
+ :color: info
+
+ :octicon:`mark-github` Go to TorchRec Repo
+
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Introduction
+ :hidden:
+
+ overview.rst
+ high-level-arch.rst
+ concepts.rst
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Getting Started
+ :hidden:
+
+ setup-torchrec.rst
.. toctree::
:maxdepth: 2
- :caption: Contents:
-
- torchrec.datasets.rst
- torchrec.datasets.scripts.rst
- torchrec.distributed.rst
- torchrec.distributed.planner.rst
- torchrec.distributed.sharding.rst
- torchrec.fx.rst
- torchrec.inference.rst
- torchrec.models.rst
- torchrec.modules.rst
- torchrec.optim.rst
- torchrec.quant.rst
- torchrec.sparse.rst
-
-Indices and tables
-==================
-
-* :ref:`genindex`
-* :ref:`modindex`
-* :ref:`search`
+ :caption: API References
+ :hidden:
+
+ datatypes-api-reference.rst
+ modules-api-reference.rst
+ planner-api-reference.rst
+ model-parallel-api-reference.rst
+ inference-api-reference.rst
diff --git a/docs/source/inference-api-reference.rst b/docs/source/inference-api-reference.rst
new file mode 100644
index 000000000..0e10067c7
--- /dev/null
+++ b/docs/source/inference-api-reference.rst
@@ -0,0 +1,19 @@
+Inference
+----------------------------------
+
+TorchRec provides easy-to-use APIs for transforming an authored TorchRec model
+into an optimized inference model for distributed inference, via eager module swaps.
+
+This transforms TorchRec modules like ``EmbeddingBagCollection`` in the model to
+a quantized, sharded version that can be compiled using torch.fx and TorchScript
+for inference in a C++ environment.
+
+The intended use is calling ``quantize_inference_model`` on the model followed by
+``shard_quant_model``.
+
+
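+For example, a sketch of this flow (assuming ``model`` is an authored
+TorchRec model containing an ``EmbeddingBagCollection``):
+
+.. code-block:: python
+
+   from torchrec.inference.modules import (
+       quantize_inference_model,
+       shard_quant_model,
+   )
+
+   # Sketch: quantize the embedding modules in place of the originals,
+   # then shard the quantized model for distributed inference.
+   quant_model = quantize_inference_model(model)
+   sharded_model, _ = shard_quant_model(quant_model)
+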
+.. automodule:: torchrec.inference.modules
+
+.. autofunction:: quantize_inference_model
+.. autofunction:: shard_quant_model
diff --git a/docs/source/model-parallel-api-reference.rst b/docs/source/model-parallel-api-reference.rst
new file mode 100644
index 000000000..196fd27fb
--- /dev/null
+++ b/docs/source/model-parallel-api-reference.rst
@@ -0,0 +1,10 @@
+Model Parallel
+----------------------------------
+
+``DistributedModelParallel`` is the main API for distributed training with TorchRec optimizations.
+
+
+.. automodule:: torchrec.distributed.model_parallel
+
+.. autoclass:: DistributedModelParallel
+ :members:
diff --git a/docs/source/modules-api-reference.rst b/docs/source/modules-api-reference.rst
new file mode 100644
index 000000000..98cff7ad8
--- /dev/null
+++ b/docs/source/modules-api-reference.rst
@@ -0,0 +1,30 @@
+Modules
+----------------------------------
+
+Standard TorchRec modules represent collections of embedding tables:
+
+* ``EmbeddingBagCollection`` is a collection of ``torch.nn.EmbeddingBag``
+* ``EmbeddingCollection`` is a collection of ``torch.nn.Embedding``
+
+These modules are constructed through standardized config classes (see the construction sketch below):
+
+* ``EmbeddingBagConfig`` for ``EmbeddingBagCollection``
+* ``EmbeddingConfig`` for ``EmbeddingCollection``
+
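+For example, a minimal construction from these configs (a fuller,
+runnable example appears in the setup guide):
+
+.. code-block:: python
+
+   from torchrec.modules.embedding_configs import EmbeddingBagConfig
+   from torchrec.modules.embedding_modules import EmbeddingBagCollection
+
+   # One config per table; the collection batches the tables together.
+   ebc = EmbeddingBagCollection(
+       tables=[
+           EmbeddingBagConfig(
+               name="product_table",
+               embedding_dim=16,
+               num_embeddings=4096,
+               feature_names=["product"],
+           ),
+       ],
+       device="cpu",
+   )
+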
+.. automodule:: torchrec.modules.embedding_configs
+
+.. autoclass:: EmbeddingBagConfig
+ :show-inheritance:
+
+.. autoclass:: EmbeddingConfig
+ :show-inheritance:
+
+.. autoclass:: BaseEmbeddingConfig
+
+.. automodule:: torchrec.modules.embedding_modules
+
+.. autoclass:: EmbeddingBagCollection
+ :members:
+
+.. autoclass:: EmbeddingCollection
+ :members:
diff --git a/docs/source/overview.rst b/docs/source/overview.rst
new file mode 100644
index 000000000..0098c14df
--- /dev/null
+++ b/docs/source/overview.rst
@@ -0,0 +1,23 @@
+.. _overview_label:
+
+==================
+TorchRec Overview
+==================
+
+TorchRec is the PyTorch recommendation system library, designed to provide common primitives
+for creating state-of-the-art personalization models and a path to production. TorchRec is
+widely adopted in many Meta production recommendation system models for training and inference workflows.
+
+Why TorchRec?
+------------------
+
+TorchRec is designed to address the unique challenges of building, scaling and deploying massive,
+large-scale recommendation system models, which is not a focus of regular PyTorch. More specifically,
+TorchRec provides the following primitives for general recommendation systems:
+
+- **Specialized Components**: TorchRec provides simple, specialized modules that are common in authoring recommendation systems, with a focus on embedding tables
+- **Advanced Sharding Techniques**: TorchRec provides flexible and customizable methods for sharding massive embedding tables: Row-Wise, Column-Wise, Table-Wise, and so on. TorchRec can automatically determine the best plan for a device topology for efficient training and memory balance
+- **Distributed Training**: While PyTorch supports basic distributed training, TorchRec extends these capabilities with more sophisticated model parallelism techniques specifically designed for the massive scale of recommendation systems
+- **Incredibly Optimized**: TorchRec training and inference components are incredibly optimized on top of FBGEMM. After all, TorchRec powers some of the largest recommendation system models at Meta
+- **Frictionless Path to Deployment**: TorchRec provides simple APIs for transforming a trained model for inference and loading it into a C++ environment for the most optimal inference model
+- **Integration with PyTorch Ecosystem**: TorchRec is built on top of PyTorch, meaning it integrates seamlessly with existing PyTorch code, tools, and workflows. This allows developers to leverage their existing knowledge and codebase while utilizing advanced features for recommendation systems. By being a part of the PyTorch ecosystem, TorchRec benefits from the robust community support, continuous updates, and improvements that come with PyTorch.
diff --git a/docs/source/planner-api-reference.rst b/docs/source/planner-api-reference.rst
new file mode 100644
index 000000000..e6cc7f4d2
--- /dev/null
+++ b/docs/source/planner-api-reference.rst
@@ -0,0 +1,50 @@
+Planner
+----------------------------------
+
+The TorchRec Planner is responsible for determining the most performant, balanced
+sharding plan for distributed training and inference.
+
+The main API for generating a sharding plan is ``EmbeddingShardingPlanner.plan``
+
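+A minimal sketch of its use (assuming ``model`` contains TorchRec
+embedding modules and a process group has been initialized):
+
+.. code-block:: python
+
+   from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+   from torchrec.distributed.planner import EmbeddingShardingPlanner
+
+   # Sketch: enumerate sharding options for each table and choose a
+   # performant, balanced plan for the current topology.
+   planner = EmbeddingShardingPlanner()
+   plan = planner.plan(model, sharders=[EmbeddingBagCollectionSharder()])
+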
+.. automodule:: torchrec.distributed.types
+
+.. autoclass:: ShardingPlan
+ :members:
+
+.. automodule:: torchrec.distributed.planner.planners
+
+.. autoclass:: EmbeddingShardingPlanner
+ :members:
+
+.. automodule:: torchrec.distributed.planner.enumerators
+
+.. autoclass:: EmbeddingEnumerator
+ :members:
+
+.. automodule:: torchrec.distributed.planner.partitioners
+
+.. autoclass:: GreedyPerfPartitioner
+ :members:
+
+
+.. automodule:: torchrec.distributed.planner.storage_reservations
+
+.. autoclass:: HeuristicalStorageReservation
+ :members:
+
+.. automodule:: torchrec.distributed.planner.proposers
+
+.. autoclass:: GreedyProposer
+ :members:
+
+
+.. automodule:: torchrec.distributed.planner.shard_estimators
+
+.. autoclass:: EmbeddingPerfEstimator
+ :members:
+
+
+.. automodule:: torchrec.distributed.planner.shard_estimators
+
+.. autoclass:: EmbeddingStorageEstimator
+ :members:
diff --git a/docs/source/setup-torchrec.rst b/docs/source/setup-torchrec.rst
new file mode 100644
index 000000000..7a2cbf969
--- /dev/null
+++ b/docs/source/setup-torchrec.rst
@@ -0,0 +1,139 @@
+
+===================
+Setting up TorchRec
+===================
+
+In this section, we will:
+
+* Understand requirements for using TorchRec
+* Set up an environment to integrate TorchRec
+* Run basic TorchRec code
+
+
+System Requirements
+-------------------
+
+TorchRec is routinely tested only on AWS Linux and should work in similar environments.
+Below is the compatibility matrix that is currently tested:
+
+.. list-table::
+ :widths: 25 75
+ :header-rows: 0
+
+ * - Python Version
+ - 3.9, 3.10, 3.11, 3.12
+ * - Compute Platform
+ - CPU, CUDA 11.8, CUDA 12.1, CUDA 12.4
+
+Aside from those requirements, TorchRec's core dependencies are PyTorch and FBGEMM.
+If your system is compatible with both libraries generally, then it should be sufficient for TorchRec.
+
+* `PyTorch requirements `_
+* `FBGEMM requirements `_
+
+
+Version Compatibility
+---------------------
+
+TorchRec and FBGEMM have matching version numbers that are tested together upon release:
+
+* TorchRec 1.0 is compatible with FBGEMM 1.0
+* TorchRec 0.8 is compatible with FBGEMM 0.8
+* TorchRec 0.8 may not be compatible with FBGEMM 0.7
+
+Furthermore, TorchRec and FBGEMM are released only when a new PyTorch release happens.
+Therefore, specific versions of TorchRec and FBGEMM should correspond to a specific PyTorch version:
+
+* TorchRec 1.0 is compatible with PyTorch 2.5
+* TorchRec 0.8 is compatible with PyTorch 2.4
+* TorchRec 0.8 may not be compatible with PyTorch 2.3
+
+Installation
+------------
+Below we show installations for CUDA 12.1 as an example. For CPU, CUDA 11.8, or CUDA 12.4, swap ``cu121`` for ``cpu``, ``cu118``, or ``cu124`` respectively.
+
+.. tab-set::
+
+ .. tab-item:: **Stable via pytorch.org**
+
+ .. code-block:: bash
+
+ pip install torch --index-url https://download.pytorch.org/whl/cu121
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/cu121
+ pip install torchmetrics==1.0.3
+ pip install torchrec --index-url https://download.pytorch.org/whl/cu121
+
+ .. tab-item:: **Stable via PyPI (Only for CUDA 12.4)**
+
+ .. code-block:: bash
+
+ pip install torch
+ pip install fbgemm-gpu
+ pip install torchrec
+
+ .. tab-item:: **Nightly**
+
+ .. code-block:: bash
+
+ pip install torch --index-url https://download.pytorch.org/whl/nightly/cu121
+ pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu121
+ pip install torchmetrics==1.0.3
+ pip install torchrec --index-url https://download.pytorch.org/whl/nightly/cu121
+
+ .. tab-item:: **Building From Source**
+
+ You also have the ability to build TorchRec from source to develop with the latest
+ changes in TorchRec. To build from source, check out this `reference `_.
+
+
+Run a Simple TorchRec Example
+------------------------------
+Now that we have TorchRec properly set up, let's run some TorchRec code!
+Below, we'll run a simple forward pass with TorchRec data types: ``KeyedJaggedTensor`` and ``EmbeddingBagCollection``:
+
+.. code-block:: python
+
+ import torch
+
+ import torchrec
+ from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
+
+ ebc = torchrec.EmbeddingBagCollection(
+ device="cpu",
+ tables=[
+ torchrec.EmbeddingBagConfig(
+ name="product_table",
+ embedding_dim=16,
+ num_embeddings=4096,
+ feature_names=["product"],
+ pooling=torchrec.PoolingType.SUM,
+ ),
+ torchrec.EmbeddingBagConfig(
+ name="user_table",
+ embedding_dim=16,
+ num_embeddings=4096,
+ feature_names=["user"],
+ pooling=torchrec.PoolingType.SUM,
+ )
+ ]
+ )
+
+ product_jt = JaggedTensor(
+ values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
+ )
+ user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))
+
+ # Q1: How many batches are there, and which values are in the first batch for product_jt and user_jt?
+ kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})
+
+ print("Call EmbeddingBagCollection Forward: ", ebc(kjt))
+
+Save the above code to a file named ``torchrec_example.py``. Then, you should be able to
+execute it from your terminal with:
+
+.. code-block:: bash
+
+ python torchrec_example.py
+
+You should see the output ``KeyedTensor`` with the resulting embeddings.
+Congrats! You have correctly installed and run your first TorchRec program!
diff --git a/docs/source/torchrec.datasets.rst b/docs/source/torchrec.datasets.rst
deleted file mode 100644
index a16be3c4c..000000000
--- a/docs/source/torchrec.datasets.rst
+++ /dev/null
@@ -1,36 +0,0 @@
-torchrec.datasets
-=================
-
-.. automodule:: torchrec.datasets
-
-torchrec.datasets.criteo
-------------------------
-
-.. automodule:: torchrec.datasets.criteo
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.datasets.movielens
----------------------------
-
-.. automodule:: torchrec.datasets.movielens
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.datasets.random
-------------------------
-
-.. automodule:: torchrec.datasets.random
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.datasets.utils
------------------------
-
-.. automodule:: torchrec.datasets.utils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.datasets.scripts.rst b/docs/source/torchrec.datasets.scripts.rst
deleted file mode 100644
index 947250d67..000000000
--- a/docs/source/torchrec.datasets.scripts.rst
+++ /dev/null
@@ -1,22 +0,0 @@
-torchrec.datasets.scripts
-=========================
-
-.. automodule:: torchrec.datasets.scripts
-
-
-
-torchrec.datasets.scripts.contiguous\_preproc\_criteo
------------------------------------------------------
-
-.. automodule:: torchrec.datasets.scripts.contiguous_preproc_criteo
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.datasets.scripts.npy\_preproc\_criteo
-----------------------------------------------
-
-.. automodule:: torchrec.datasets.scripts.npy_preproc_criteo
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.distributed.planner.rst b/docs/source/torchrec.distributed.planner.rst
deleted file mode 100644
index 82200125e..000000000
--- a/docs/source/torchrec.distributed.planner.rst
+++ /dev/null
@@ -1,93 +0,0 @@
-torchrec.distributed.planner
-============================
-
-.. automodule:: torchrec.distributed.planner
-
-
-torchrec.distributed.planner.constants
---------------------------------------
-
-.. automodule:: torchrec.distributed.planner.constants
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.enumerators
-----------------------------------------
-
-.. automodule:: torchrec.distributed.planner.enumerators
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.partitioners
------------------------------------------
-
-.. automodule:: torchrec.distributed.planner.partitioners
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.perf\_models
------------------------------------------
-
-.. automodule:: torchrec.distributed.planner.perf_models
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.planners
--------------------------------------
-
-.. automodule:: torchrec.distributed.planner.planners
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.proposers
---------------------------------------
-
-.. automodule:: torchrec.distributed.planner.proposers
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.shard\_estimators
-----------------------------------------------
-
-.. automodule:: torchrec.distributed.planner.shard_estimators
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.stats
-----------------------------------
-
-.. automodule:: torchrec.distributed.planner.stats
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.storage\_reservations
---------------------------------------------------
-
-.. automodule:: torchrec.distributed.planner.storage_reservations
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.types
-----------------------------------
-
-.. automodule:: torchrec.distributed.planner.types
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.planner.utils
-----------------------------------
-
-.. automodule:: torchrec.distributed.planner.utils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.distributed.rst b/docs/source/torchrec.distributed.rst
deleted file mode 100644
index 9fdcb4282..000000000
--- a/docs/source/torchrec.distributed.rst
+++ /dev/null
@@ -1,126 +0,0 @@
-torchrec.distributed
-====================
-
-.. automodule:: torchrec.distributed
-
-
-torchrec.distributed.collective\_utils
---------------------------------------
-
-.. automodule:: torchrec.distributed.collective_utils
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.comm
--------------------------
-
-.. automodule:: torchrec.distributed.comm
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.comm\_ops
-------------------------------
-
-.. automodule:: torchrec.distributed.comm_ops
- :members:
- :undoc-members:
- :show-inheritance:
-
-
-torchrec.distributed.dist\_data
--------------------------------
-
-.. automodule:: torchrec.distributed.dist_data
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.embedding
-------------------------------
-
-.. automodule:: torchrec.distributed.embedding
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.embedding\_lookup
---------------------------------------
-
-.. automodule:: torchrec.distributed.embedding_lookup
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.embedding\_sharding
-----------------------------------------
-
-.. automodule:: torchrec.distributed.embedding_sharding
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.embedding\_types
--------------------------------------
-
-.. automodule:: torchrec.distributed.embedding_types
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.embeddingbag
----------------------------------
-
-.. automodule:: torchrec.distributed.embeddingbag
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.grouped\_position\_weighted
-------------------------------------------------
-
-.. automodule:: torchrec.distributed.grouped_position_weighted
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.model\_parallel
-------------------------------------
-
-.. automodule:: torchrec.distributed.model_parallel
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.quant\_embeddingbag
-----------------------------------------
-
-.. automodule:: torchrec.distributed.quant_embeddingbag
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.train\_pipeline
-------------------------------------
-
-.. automodule:: torchrec.distributed.train_pipeline
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.types
---------------------------
-
-.. automodule:: torchrec.distributed.types
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.utils
---------------------------
-
-.. automodule:: torchrec.distributed.utils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.distributed.sharding.rst b/docs/source/torchrec.distributed.sharding.rst
deleted file mode 100644
index d0e7162ae..000000000
--- a/docs/source/torchrec.distributed.sharding.rst
+++ /dev/null
@@ -1,63 +0,0 @@
-torchrec.distributed.sharding
-=============================
-
-.. automodule:: torchrec.distributed.sharding
-
-
-torchrec.distributed.sharding.cw\_sharding
-------------------------------------------
-
-.. automodule:: torchrec.distributed.sharding.cw_sharding
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.dist\_data
--------------------------------
-
-.. automodule:: torchrec.distributed.dist_data
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.sharding.dp\_sharding
-------------------------------------------
-
-.. automodule:: torchrec.distributed.sharding.dp_sharding
- :members:
- :undoc-members:
- :show-inheritance:
-
-
-torchrec.distributed.sharding.rw\_sharding
-------------------------------------------
-
-.. automodule:: torchrec.distributed.sharding.rw_sharding
- :members:
- :undoc-members:
- :show-inheritance:
-
-
-torchrec.distributed.sharding.tw\_sharding
-------------------------------------------
-
-.. automodule:: torchrec.distributed.sharding.tw_sharding
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.sharding.twcw\_sharding
---------------------------------------------
-
-.. automodule:: torchrec.distributed.sharding.twcw_sharding
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.distributed.sharding.twrw\_sharding
---------------------------------------------
-
-.. automodule:: torchrec.distributed.sharding.twrw_sharding
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.fx.rst b/docs/source/torchrec.fx.rst
deleted file mode 100644
index b06656fbc..000000000
--- a/docs/source/torchrec.fx.rst
+++ /dev/null
@@ -1,20 +0,0 @@
-torchrec.fx
-===========
-
-.. automodule:: torchrec.fx
-
-torchrec.fx.tracer
-------------------
-
-.. automodule:: torchrec.fx.tracer
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: torchrec.fx
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.inference.rst b/docs/source/torchrec.inference.rst
deleted file mode 100644
index 846a520ec..000000000
--- a/docs/source/torchrec.inference.rst
+++ /dev/null
@@ -1,28 +0,0 @@
-torchrec.inference
-===========
-
-.. automodule:: torchrec.inference
-
-torchrec.inference.model_packager
-------------------
-
-.. automodule:: torchrec.inference.model_packager
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.inference.modules
-------------------
-
-.. automodule:: torchrec.inference.modules
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: torchrec.inference
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.models.rst b/docs/source/torchrec.models.rst
deleted file mode 100644
index 4d3ad3040..000000000
--- a/docs/source/torchrec.models.rst
+++ /dev/null
@@ -1,28 +0,0 @@
-torchrec.models
-===============
-
-.. automodule:: torchrec.models
-
-torchrec.models.deepfm
-----------------------
-
-.. automodule:: torchrec.models.deepfm
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.models.dlrm
---------------------
-
-.. automodule:: torchrec.models.dlrm
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: torchrec.models
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.modules.rst b/docs/source/torchrec.modules.rst
deleted file mode 100644
index 95dcf35db..000000000
--- a/docs/source/torchrec.modules.rst
+++ /dev/null
@@ -1,85 +0,0 @@
-torchrec.modules
-================
-
-.. automodule:: torchrec.modules
-
-torchrec.modules.activation
----------------------------
-
-.. automodule:: torchrec.modules.activation
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.modules.crossnet
--------------------------
-
-.. automodule:: torchrec.modules.crossnet
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.modules.deepfm
------------------------
-
-.. automodule:: torchrec.modules.deepfm
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.modules.embedding\_configs
------------------------------------
-
-.. automodule:: torchrec.modules.embedding_configs
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.modules.embedding\_modules
------------------------------------
-
-.. automodule:: torchrec.modules.embedding_modules
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.modules.feature\_processor
------------------------------------
-
-.. automodule:: torchrec.modules.feature_processor
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.modules.lazy\_extension
---------------------------------
-
-.. automodule:: torchrec.modules.lazy_extension
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.modules.mlp
---------------------
-
-.. automodule:: torchrec.modules.mlp
- :members:
- :undoc-members:
- :show-inheritance:
-
-
-torchrec.modules.utils
-----------------------
-
-.. automodule:: torchrec.modules.utils
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: torchrec.modules
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.optim.rst b/docs/source/torchrec.optim.rst
deleted file mode 100644
index df3a147d0..000000000
--- a/docs/source/torchrec.optim.rst
+++ /dev/null
@@ -1,44 +0,0 @@
-torchrec.optim
-==============
-
-.. automodule:: torchrec.optim
-
-torchrec.optim.clipping
------------------------
-
-.. automodule:: torchrec.optim.clipping
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.optim.fused
---------------------
-
-.. automodule:: torchrec.optim.fused
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.optim.keyed
---------------------
-
-.. automodule:: torchrec.optim.keyed
- :members:
- :undoc-members:
- :show-inheritance:
-
-torchrec.optim.warmup
----------------------
-
-.. automodule:: torchrec.optim.warmup
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: torchrec.optim
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.quant.rst b/docs/source/torchrec.quant.rst
deleted file mode 100644
index df4b34bd4..000000000
--- a/docs/source/torchrec.quant.rst
+++ /dev/null
@@ -1,20 +0,0 @@
-torchrec.quant
-==============
-
-.. automodule:: torchrec.quant
-
-torchrec.quant.embedding\_modules
----------------------------------
-
-.. automodule:: torchrec.quant.embedding_modules
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: torchrec.quant
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/torchrec.sparse.rst b/docs/source/torchrec.sparse.rst
deleted file mode 100644
index efbd3eaec..000000000
--- a/docs/source/torchrec.sparse.rst
+++ /dev/null
@@ -1,20 +0,0 @@
-torchrec.sparse
-===============
-
-.. automodule:: torchrec.sparse
-
-torchrec.sparse.jagged\_tensor
-------------------------------
-
-.. automodule:: torchrec.sparse.jagged_tensor
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: torchrec.sparse
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/examples/bert4rec/bert4rec_main.py b/examples/bert4rec/bert4rec_main.py
index 1f3f1319b..f5c8ccff5 100644
--- a/examples/bert4rec/bert4rec_main.py
+++ b/examples/bert4rec/bert4rec_main.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
#!/usr/bin/env python3
@@ -26,6 +28,7 @@
from torchrec.distributed.model_parallel import DistributedModelParallel as DMP
from torchrec.distributed.types import ModuleSharder
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from tqdm import tqdm
@@ -497,11 +500,12 @@ def main(argv: List[str]) -> None:
],
)
dense_optimizer = KeyedOptimizerWrapper(
- dict(model.named_parameters()),
+ dict(in_backward_optimizer_filter(model.named_parameters())),
lambda params: optim.Adam(
params, lr=args.lr, weight_decay=args.weight_decay
),
)
+
optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
else:
device_ids = [rank] if backend == "nccl" else None
diff --git a/examples/bert4rec/bert4rec_metrics.py b/examples/bert4rec/bert4rec_metrics.py
index 09140d62e..6ad244fdc 100644
--- a/examples/bert4rec/bert4rec_metrics.py
+++ b/examples/bert4rec/bert4rec_metrics.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
from typing import Dict, List
import torch
diff --git a/examples/bert4rec/data/bert4rec_movielens_datasets.py b/examples/bert4rec/data/bert4rec_movielens_datasets.py
index fb912126d..18144cc48 100644
--- a/examples/bert4rec/data/bert4rec_movielens_datasets.py
+++ b/examples/bert4rec/data/bert4rec_movielens_datasets.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import random
from collections import Counter
from pathlib import Path
diff --git a/examples/bert4rec/data/tests/test_bert4rec_movielens_datasets.py b/examples/bert4rec/data/tests/test_bert4rec_movielens_datasets.py
index daca1c4e7..06c8deef8 100644
--- a/examples/bert4rec/data/tests/test_bert4rec_movielens_datasets.py
+++ b/examples/bert4rec/data/tests/test_bert4rec_movielens_datasets.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import unittest
from ..bert4rec_movielens_datasets import Bert4RecPreprocsser, get_raw_dataframe
diff --git a/examples/bert4rec/dataloader/bert4rec_movielens_dataloader.py b/examples/bert4rec/dataloader/bert4rec_movielens_dataloader.py
index af7a3b22f..a9303d218 100644
--- a/examples/bert4rec/dataloader/bert4rec_movielens_dataloader.py
+++ b/examples/bert4rec/dataloader/bert4rec_movielens_dataloader.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
from typing import Any, Dict, Tuple
import pandas as pd
diff --git a/examples/bert4rec/dataloader/tests/test_bert4rec_movielens_dataloader.py b/examples/bert4rec/dataloader/tests/test_bert4rec_movielens_dataloader.py
index 0a5526af3..a47c0fe7d 100644
--- a/examples/bert4rec/dataloader/tests/test_bert4rec_movielens_dataloader.py
+++ b/examples/bert4rec/dataloader/tests/test_bert4rec_movielens_dataloader.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import unittest
from ...data.bert4rec_movielens_datasets import Bert4RecPreprocsser, get_raw_dataframe
diff --git a/examples/bert4rec/models/bert4rec.py b/examples/bert4rec/models/bert4rec.py
index 0ca8b83b0..b90ac4305 100644
--- a/examples/bert4rec/models/bert4rec.py
+++ b/examples/bert4rec/models/bert4rec.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import copy
import math
diff --git a/examples/bert4rec/models/tests/test_bert4rec.py b/examples/bert4rec/models/tests/test_bert4rec.py
index ba46c299d..5149f188b 100644
--- a/examples/bert4rec/models/tests/test_bert4rec.py
+++ b/examples/bert4rec/models/tests/test_bert4rec.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
#!/usr/bin/env python3
import unittest
diff --git a/examples/bert4rec/tests/test_bert4rec_main.py b/examples/bert4rec/tests/test_bert4rec_main.py
index d53198c6e..705f4b64c 100644
--- a/examples/bert4rec/tests/test_bert4rec_main.py
+++ b/examples/bert4rec/tests/test_bert4rec_main.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import os
import tempfile
import unittest
diff --git a/examples/datasets/README b/examples/datasets/README
deleted file mode 100644
index 74210008a..000000000
--- a/examples/datasets/README
+++ /dev/null
@@ -1 +0,0 @@
-Datasets under this directory are prototyping with libaries under active development (e.g. TorchArrow DataFrame) which may not have best performance yet, and subject to change.
diff --git a/examples/datasets/criteo_dataframes.py b/examples/datasets/criteo_dataframes.py
deleted file mode 100644
index dec708602..000000000
--- a/examples/datasets/criteo_dataframes.py
+++ /dev/null
@@ -1,101 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Iterable, List, Tuple, Union
-
-import numpy as np
-import torch.utils.data.datapipes as dp
-import torcharrow as ta
-import torcharrow.dtypes as dt
-from torch.utils.data import IterDataPipe
-from torchrec.datasets.criteo import (
- CAT_FEATURE_COUNT,
- CriteoIterDataPipe,
- DEFAULT_CAT_NAMES,
- DEFAULT_INT_NAMES,
- INT_FEATURE_COUNT,
-)
-from torchrec.datasets.utils import safe_cast
-
-DTYPE = dt.Struct(
- [
- dt.Field("labels", dt.int8),
- dt.Field(
- "dense_features",
- dt.Struct(
- [
- dt.Field(int_name, dt.Int32(nullable=True))
- for int_name in DEFAULT_INT_NAMES
- ]
- ),
- ),
- dt.Field(
- "sparse_features",
- dt.Struct(
- [
- dt.Field(cat_name, dt.Int32(nullable=True))
- for cat_name in DEFAULT_CAT_NAMES
- ]
- ),
- ),
- ]
-)
-
-
-def _torcharrow_row_mapper(
- row: List[str],
-) -> Tuple[int, Tuple[int, ...], Tuple[int, ...]]:
- # TODO: Fix safe_cast type annotation
- label = int(safe_cast(row[0], int, 0))
- dense = tuple(
- (int(safe_cast(row[i], int, 0)) for i in range(1, 1 + INT_FEATURE_COUNT))
- )
- sparse = tuple(
- (
- int(safe_cast(row[i], str, "0") or "0", 16)
- for i in range(
- 1 + INT_FEATURE_COUNT, 1 + INT_FEATURE_COUNT + CAT_FEATURE_COUNT
- )
- )
- )
- # TorchArrow doesn't support uint32, but we can save memory
- # by not using int64. Numpy will automatically handle sparse values >= 2 ** 31.
- sparse = tuple(np.array(sparse, dtype=np.int32).tolist())
-
- return (label, dense, sparse)
-
-
-def criteo_dataframes_from_tsv(
- paths: Union[str, Iterable[str]],
- *,
- batch_size: int = 128,
-) -> IterDataPipe:
- """
- Load Criteo dataset (Kaggle or Terabyte) as TorchArrow DataFrame streams from TSV file(s)
-
- This implementaiton is inefficient and is used for prototype and test only.
-
- Args:
- paths (str or Iterable[str]): local paths to TSV files that constitute
- the Kaggle or Criteo 1TB dataset.
-
- Example::
-
- datapipe = criteo_dataframes_from_tsv(
- ["/home/datasets/criteo/day_0.tsv", "/home/datasets/criteo/day_1.tsv"]
- )
- for df in datapipe:
- print(df)
- """
- if isinstance(paths, str):
- paths = [paths]
-
- datapipe = CriteoIterDataPipe(paths, row_mapper=_torcharrow_row_mapper)
- datapipe = dp.iter.Batcher(datapipe, batch_size)
- datapipe = dp.iter.Mapper(datapipe, lambda batch: ta.dataframe(batch, dtype=DTYPE))
-
- return datapipe
diff --git a/examples/datasets/tests/test_criteo_dataframes.py b/examples/datasets/tests/test_criteo_dataframes.py
deleted file mode 100644
index 18cd5079b..000000000
--- a/examples/datasets/tests/test_criteo_dataframes.py
+++ /dev/null
@@ -1,80 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import contextlib
-
-import torcharrow as ta
-from torch.utils.data import IterDataPipe
-from torchrec.datasets.test_utils.criteo_test_utils import CriteoTest
-
-from ..criteo_dataframes import criteo_dataframes_from_tsv
-
-
-class CriteoDataFramesTest(CriteoTest):
- BATCH_SIZE = 4
-
- def test_single_file(self) -> None:
- with self._create_dataset_tsv() as dataset_pathname:
- dataset = criteo_dataframes_from_tsv(
- dataset_pathname, batch_size=self.BATCH_SIZE
- )
-
- self._validate_dataset(dataset, 10)
-
- def test_multiple_files(self) -> None:
- with contextlib.ExitStack() as stack:
- pathnames = [
- stack.enter_context(self._create_dataset_tsv()) for _ in range(3)
- ]
- dataset = criteo_dataframes_from_tsv(pathnames, batch_size=self.BATCH_SIZE)
-
- self._validate_dataset(dataset, 30)
-
- def _validate_dataset(
- self, dataset: IterDataPipe, expected_total_length: int
- ) -> None:
- last_batch = False
- total_length = 0
-
- for df in dataset:
- self.assertFalse(last_batch)
- self.assertTrue(isinstance(df, ta.DataFrame))
- self.assertLessEqual(len(df), self.BATCH_SIZE)
-
- total_length += len(df)
- if len(df) < self.BATCH_SIZE:
- last_batch = True
-
- self._validate_dataframe(df)
- self.assertEqual(total_length, expected_total_length)
-
- def _validate_dataframe(self, df: ta.DataFrame, train: bool = True) -> None:
- if train:
- self.assertEqual(len(df.columns), 3)
- labels = df["labels"]
- for label_val in labels:
- self.assertTrue(
- self.LABEL_VAL_RANGE[0] <= label_val <= self.LABEL_VAL_RANGE[1]
- )
- else:
- self.assertEqual(len(df.columns), 2)
-
- # Validations for both train and test
- dense_features = df["dense_features"]
- for idx in range(self.INT_FEATURE_COUNT):
- int_vals = dense_features[f"int_{idx}"]
- for int_val in int_vals:
- self.assertTrue(
- self.INT_VAL_RANGE[0] <= int_val <= self.INT_VAL_RANGE[1]
- )
-
- sparse_features = df["sparse_features"]
- for idx in range(self.CAT_FEATURE_COUNT):
- cat_vals = sparse_features[f"cat_{idx}"]
- for cat_val in cat_vals:
- # stored as int32
- self.assertTrue(-(2**31) <= cat_val <= 2**31 - 1)
diff --git a/examples/golden_training/tests/test_train_dlrm.py b/examples/golden_training/tests/test_train_dlrm.py
index 691d52937..25e1cc7fc 100644
--- a/examples/golden_training/tests/test_train_dlrm.py
+++ b/examples/golden_training/tests/test_train_dlrm.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import os
import tempfile
import unittest
diff --git a/examples/golden_training/train_dlrm.py b/examples/golden_training/train_dlrm.py
index 511b71bce..55514c53d 100644
--- a/examples/golden_training/train_dlrm.py
+++ b/examples/golden_training/train_dlrm.py
@@ -5,12 +5,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import os
from typing import List, Optional
import torch
from torch import distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.optim import (
+ _apply_optimizer_in_backward as apply_optimizer_in_backward,
+)
from torch.utils.data import IterableDataset
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.datasets.random import RandomRecDataset
@@ -25,8 +30,8 @@
from torchrec.models.dlrm import DLRM, DLRMTrain
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
-from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
from tqdm import tqdm
@@ -132,7 +137,7 @@ def train(
)
non_fused_optimizer = KeyedOptimizerWrapper(
- dict(model.named_parameters()),
+ dict(in_backward_optimizer_filter(model.named_parameters())),
lambda params: torch.optim.Adagrad(params, lr=learning_rate),
)
# Overlap comm/compute/device transfer during training through train_pipeline
diff --git a/examples/golden_training/train_dlrm_data_parallel.py b/examples/golden_training/train_dlrm_data_parallel.py
new file mode 100644
index 000000000..0d57f9617
--- /dev/null
+++ b/examples/golden_training/train_dlrm_data_parallel.py
@@ -0,0 +1,199 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import os
+from typing import Iterator, List, Optional, Tuple
+
+import torch
+from torch import distributed as dist, nn
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.optim import (
+ _apply_optimizer_in_backward as apply_optimizer_in_backward,
+)
+from torch.utils.data import IterableDataset
+from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
+from torchrec.datasets.random import RandomRecDataset
+from torchrec.distributed import TrainPipelineSparseDist
+from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.fbgemm_qcomm_codec import (
+ CommType,
+ get_qcomm_codecs_registry,
+ QCommsConfig,
+)
+from torchrec.distributed.model_parallel import DistributedModelParallel
+from torchrec.distributed.planner import EmbeddingShardingPlanner
+from torchrec.distributed.planner.types import ParameterConstraints, ShardingPlan
+from torchrec.models.dlrm import DLRM, DLRMTrain
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
+from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
+from tqdm import tqdm
+
+
+def _get_random_dataset(
+ num_embeddings: int,
+ batch_size: int = 32,
+) -> IterableDataset:
+ return RandomRecDataset(
+ keys=DEFAULT_CAT_NAMES,
+ batch_size=batch_size,
+ hash_size=num_embeddings,
+ ids_per_feature=1,
+ num_dense=len(DEFAULT_INT_NAMES),
+ )
+
+
+@record
+def main() -> None:
+ train()
+
+
+def train(
+ num_embeddings: int = 1024**2,
+ embedding_dim: int = 128,
+ dense_arch_layer_sizes: Optional[List[int]] = None,
+ over_arch_layer_sizes: Optional[List[int]] = None,
+ learning_rate: float = 0.1,
+ num_iterations: int = 1000,
+ qcomm_forward_precision: Optional[CommType] = CommType.FP16,
+ qcomm_backward_precision: Optional[CommType] = CommType.BF16,
+) -> None:
+ """
+ Duplicate of train_dlrm.py, but manually forces one table to be data_parallel.
+ That table is optimized with RowWiseAdagrad, while the remaining dense params (i.e. the dense, inter, and over archs) are optimized with vanilla Adagrad.
+
+ Constructs and trains a DLRM model (using random dummy data). The script is run on each process (rank) in SPMD fashion.
+ The embedding layers will be sharded across the available ranks.
+
+ qcomm_forward_precision: Compression used in the forward pass. FP16 is the recommended setting; INT8 and FP8 are in development, but feel free to try them out.
+ qcomm_backward_precision: Compression used in the backward pass. We recommend BF16 to ensure training stability.
+
+ The effects of quantized comms are most apparent in large training jobs across multiple nodes, where inter-host communication is expensive.
+ """
+ if dense_arch_layer_sizes is None:
+ dense_arch_layer_sizes = [64, embedding_dim]
+ if over_arch_layer_sizes is None:
+ over_arch_layer_sizes = [64, 1]
+
+ # Init process_group, device, rank, and backend
+ rank = int(os.environ["LOCAL_RANK"])
+ if torch.cuda.is_available():
+ device: torch.device = torch.device(f"cuda:{rank}")
+ backend = "nccl"
+ torch.cuda.set_device(device)
+ else:
+ device: torch.device = torch.device("cpu")
+ backend = "gloo"
+ dist.init_process_group(backend=backend)
+
+ # Construct DLRM module
+ eb_configs = [
+ EmbeddingBagConfig(
+ name=f"t_{feature_name}",
+ embedding_dim=embedding_dim,
+ num_embeddings=num_embeddings,
+ feature_names=[feature_name],
+ )
+ for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES)
+ ]
+ dlrm_model = DLRM(
+ embedding_bag_collection=EmbeddingBagCollection(
+ tables=eb_configs, device=torch.device("meta")
+ ),
+ dense_in_features=len(DEFAULT_INT_NAMES),
+ dense_arch_layer_sizes=dense_arch_layer_sizes,
+ over_arch_layer_sizes=over_arch_layer_sizes,
+ dense_device=device,
+ )
+ train_model = DLRMTrain(dlrm_model)
+ # Optional: force some tables (here, t_cat_10) to be data_parallel via planner ParameterConstraints
+ planner = EmbeddingShardingPlanner(
+ constraints={
+ "t_cat_10": ParameterConstraints(compute_kernels=["dense"]),
+ }
+ )
+ plan: ShardingPlan = planner.collective_plan(train_model)
+
+ apply_optimizer_in_backward(
+ RowWiseAdagrad,
+ train_model.model.sparse_arch.parameters(),
+ {"lr": learning_rate},
+ )
+ qcomm_codecs_registry = (
+ get_qcomm_codecs_registry(
+ qcomms_config=QCommsConfig(
+ # pyre-ignore
+ forward_precision=qcomm_forward_precision,
+ # pyre-ignore
+ backward_precision=qcomm_backward_precision,
+ )
+ )
+ if backend == "nccl"
+ else None
+ )
+ sharder = EmbeddingBagCollectionSharder(qcomm_codecs_registry=qcomm_codecs_registry)
+
+ model = DistributedModelParallel(
+ plan=plan,
+ module=train_model,
+ device=device,
+ # pyre-ignore
+ sharders=[sharder],
+ )
+
+ # Non-fused (dense) embeddings are data parallel and use the RowWiseAdagrad optimizer
+ non_fused_embedding_optimizer = KeyedOptimizerWrapper(
+ dict(
+ in_backward_optimizer_filter(
+ model.module.model.sparse_arch.named_parameters()
+ )
+ ),
+ lambda params: RowWiseAdagrad(params, lr=learning_rate),
+ )
+
+ def dense_filter(
+ named_parameters: Iterator[Tuple[str, nn.Parameter]]
+ ) -> Iterator[Tuple[str, nn.Parameter]]:
+ for fqn, param in named_parameters:
+ if "sparse" not in fqn:
+ yield fqn, param
+
+ # DLRM dense params (over, inter, and dense archs) are optimized with vanilla Adagrad
+ dense_optimizer = KeyedOptimizerWrapper(
+ dict(dense_filter(model.named_parameters())),
+ lambda params: torch.optim.Adagrad(params, lr=learning_rate),
+ )
+ combined_optimizer = CombinedOptimizer(
+ [
+ ("non_fused_embedding_optim", non_fused_embedding_optimizer),
+ ("dense_optim", dense_optimizer),
+ ("fused_embedding_optim", model.fused_optimizer),
+ ]
+ )
+ # Overlap comm/compute/device transfer during training through train_pipeline
+ train_pipeline = TrainPipelineSparseDist(
+ model,
+ combined_optimizer,
+ device,
+ )
+
+ # train model
+ train_iterator = iter(
+ _get_random_dataset(
+ num_embeddings=num_embeddings,
+ )
+ )
+ for _ in tqdm(range(int(num_iterations)), mininterval=5.0):
+ train_pipeline.progress(train_iterator)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/inference/README.md b/examples/inference_legacy/README.md
similarity index 100%
rename from examples/inference/README.md
rename to examples/inference_legacy/README.md
diff --git a/examples/inference/dlrm_client.py b/examples/inference_legacy/dlrm_client.py
similarity index 100%
rename from examples/inference/dlrm_client.py
rename to examples/inference_legacy/dlrm_client.py
diff --git a/examples/inference/dlrm_packager.py b/examples/inference_legacy/dlrm_packager.py
similarity index 100%
rename from examples/inference/dlrm_packager.py
rename to examples/inference_legacy/dlrm_packager.py
diff --git a/examples/inference/dlrm_predict.py b/examples/inference_legacy/dlrm_predict.py
similarity index 96%
rename from examples/inference/dlrm_predict.py
rename to examples/inference_legacy/dlrm_predict.py
index c36bb613d..e2bdcef9d 100644
--- a/examples/inference/dlrm_predict.py
+++ b/examples/inference_legacy/dlrm_predict.py
@@ -139,9 +139,11 @@ def create_predict_module(self, world_size: int) -> torch.nn.Module:
EmbeddingBagConfig(
name=f"t_{feature_name}",
embedding_dim=self.model_config.embedding_dim,
- num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx]
- if self.model_config.num_embeddings is None
- else self.model_config.num_embeddings,
+ num_embeddings=(
+ self.model_config.num_embeddings_per_feature[feature_idx]
+ if self.model_config.num_embeddings is None
+ else self.model_config.num_embeddings
+ ),
feature_names=[feature_name],
)
for feature_idx, feature_name in enumerate(
diff --git a/examples/inference/dlrm_predict_single_gpu.py b/examples/inference_legacy/dlrm_predict_single_gpu.py
similarity index 94%
rename from examples/inference/dlrm_predict_single_gpu.py
rename to examples/inference_legacy/dlrm_predict_single_gpu.py
index 753cb0dcc..ba5323247 100644
--- a/examples/inference/dlrm_predict_single_gpu.py
+++ b/examples/inference_legacy/dlrm_predict_single_gpu.py
@@ -50,9 +50,11 @@ def create_predict_module(self, world_size: int) -> torch.nn.Module:
EmbeddingBagConfig(
name=f"t_{feature_name}",
embedding_dim=self.model_config.embedding_dim,
- num_embeddings=self.model_config.num_embeddings_per_feature[feature_idx]
- if self.model_config.num_embeddings is None
- else self.model_config.num_embeddings,
+ num_embeddings=(
+ self.model_config.num_embeddings_per_feature[feature_idx]
+ if self.model_config.num_embeddings is None
+ else self.model_config.num_embeddings
+ ),
feature_names=[feature_name],
)
for feature_idx, feature_name in enumerate(
diff --git a/examples/nvt_dataloader/nvt_binary_dataloader.py b/examples/nvt_dataloader/nvt_binary_dataloader.py
index 2286b5610..f88ff1feb 100644
--- a/examples/nvt_dataloader/nvt_binary_dataloader.py
+++ b/examples/nvt_dataloader/nvt_binary_dataloader.py
@@ -94,7 +94,8 @@ def __getitem__(self, idx: int):
"""Numerical features are returned in the order they appear in the channel spec section
For performance reasons, this is required to be the order they are saved in, as specified
by the relevant chunk in source spec.
- Categorical features are returned in the order they appear in the channel spec section"""
+ Categorical features are returned in the order they appear in the channel spec section
+ """
if idx >= self._num_entries:
raise IndexError()
diff --git a/examples/nvt_dataloader/train_torchrec.py b/examples/nvt_dataloader/train_torchrec.py
index ddf53c89b..abdf3a67d 100644
--- a/examples/nvt_dataloader/train_torchrec.py
+++ b/examples/nvt_dataloader/train_torchrec.py
@@ -19,8 +19,6 @@
import torchrec
import torchrec.distributed as trec_dist
import torchrec.optim as trec_optim
-
-from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from nvt_binary_dataloader import NvtBinaryDataloader
from pyre_extensions import none_throws
from torchrec import EmbeddingBagCollection
@@ -40,6 +38,7 @@
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.fused_embedding_modules import fuse_embedding_optimizer
from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
def parse_args(argv: List[str]) -> argparse.Namespace:
@@ -209,9 +208,11 @@ def main(argv: List[str]):
EmbeddingBagConfig(
name=f"t_{feature_name}",
embedding_dim=args.embedding_dim,
- num_embeddings=none_throws(num_embeddings_per_feature)[feature_idx]
- if num_embeddings_per_feature is not None
- else args.num_embeddings,
+ num_embeddings=(
+ none_throws(num_embeddings_per_feature)[feature_idx]
+ if num_embeddings_per_feature is not None
+ else args.num_embeddings
+ ),
feature_names=[feature_name],
)
for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES)
@@ -233,9 +234,9 @@ def main(argv: List[str]):
train_model = fuse_embedding_optimizer(
train_model,
- optimizer_type=torchrec.optim.RowWiseAdagrad
- if args.adagrad
- else torch.optim.SGD,
+ optimizer_type=(
+ torchrec.optim.RowWiseAdagrad if args.adagrad else torch.optim.SGD
+ ),
optimizer_kwargs={"learning_rate": args.learning_rate},
device=torch.device("meta"),
)
@@ -270,10 +271,12 @@ def main(argv: List[str]):
)
non_fused_optimizer = KeyedOptimizerWrapper(
- dict(model.named_parameters()),
- lambda params: torch.optim.Adagrad(params, lr=args.learning_rate)
- if args.adagrad
- else torch.optim.SGD(params, lr=args.learning_rate),
+ dict(in_backward_optimizer_filter(model.named_parameters())),
+ lambda params: (
+ torch.optim.Adagrad(params, lr=args.learning_rate)
+ if args.adagrad
+ else torch.optim.SGD(params, lr=args.learning_rate)
+ ),
)
opt = trec_optim.keyed.CombinedOptimizer(
diff --git a/examples/prediction/README.md b/examples/prediction/README.md
new file mode 100644
index 000000000..9a3fb58c7
--- /dev/null
+++ b/examples/prediction/README.md
@@ -0,0 +1,107 @@
+# DLRM Prediction Example
+
+This example demonstrates how to use a Deep Learning Recommendation Model (DLRM) for making predictions using TorchRec capabilities. The code includes:
+
+1. A DLRM implementation using TorchRec's EmbeddingBagCollection and KeyedJaggedTensor
+2. Training with random data
+3. Evaluation
+4. Making sample predictions
+
+## TorchRec Integration
+
+This implementation has been updated to use TorchRec's capabilities:
+- Uses `KeyedJaggedTensor` for sparse features
+- Uses `EmbeddingBagCollection` for embedding tables
+- Follows the DLRM architecture as described in the paper: https://arxiv.org/abs/1906.00091
+
+The example demonstrates how to leverage TorchRec's efficient sparse feature handling for recommendation models.
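+
+As a minimal sketch (not code lifted from `predict_using_torchrec.py`; the feature names and table sizes below are illustrative), sparse IDs are packed into a `KeyedJaggedTensor` and pooled through an `EmbeddingBagCollection`:
+
+```python
+import torch
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+
+# One embedding table per categorical feature (names and sizes are illustrative).
+ebc = EmbeddingBagCollection(
+    tables=[
+        EmbeddingBagConfig(
+            name="user_id", embedding_dim=64, num_embeddings=1000, feature_names=["user_id"]
+        ),
+        EmbeddingBagConfig(
+            name="item_id", embedding_dim=64, num_embeddings=500, feature_names=["item_id"]
+        ),
+    ],
+    device=torch.device("cpu"),
+)
+
+# Two samples with exactly one id per feature, so all lengths are 1.
+# Values are laid out key-major: user ids first, then item ids.
+sparse_features = KeyedJaggedTensor(
+    keys=["user_id", "item_id"],
+    values=torch.tensor([7, 42, 3, 9]),
+    lengths=torch.ones(4, dtype=torch.int32),
+)
+
+pooled = ebc(sparse_features)   # KeyedTensor of pooled embeddings
+user_emb = pooled["user_id"]    # shape: (2, 64)
+```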
+
+## Dependencies
+
+Install the required dependencies:
+
+```bash
+# Install PyTorch
+pip install torch torchvision
+
+# Install NumPy
+pip install numpy
+
+# Install TorchRec
+pip install torchrec
+```
+
+**Important**: This implementation now requires torchrec to run, as it uses TorchRec's specialized modules for recommendation systems.
+
+## Running the Example Locally
+
+1. Download the `predict_using_torchrec.py` file to your local machine.
+
+2. Run the example:
+
+```bash
+python3 predict_using_torchrec.py
+```
+
+3. If you're using a different Python environment:
+
+```bash
+# For conda environments
+conda activate your_environment_name
+python predict_using_torchrec.py
+
+# For virtual environments
+source your_venv/bin/activate
+python predict_using_torchrec.py
+```
+
+## What to Expect
+
+When you run the example, you'll see:
+
+1. Training progress for 10 epochs with loss and learning rate information
+2. Evaluation results showing MSE and RMSE metrics
+3. Sample predictions for a specific user on multiple items
+
+## Implementation Details
+
+This example uses TorchRec's capabilities to implement a DLRM model that:
+
+- Takes dense features and sparse features (as KeyedJaggedTensor) as input
+- Processes dense features through a bottom MLP
+- Processes sparse features through EmbeddingBagCollection
+- Computes feature interactions using dot products
+- Processes the interactions through a top MLP
+- Outputs rating predictions on a 0-5 scale
+
+The implementation demonstrates how to use TorchRec's specialized modules for recommendation systems, making it more efficient and scalable than a custom implementation.
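+
+As a rough illustration of the interaction step described above (a simplified sketch with made-up shapes, not the exact code from the example), each pair of feature vectors is reduced to a dot product and concatenated with the dense representation before the top MLP:
+
+```python
+import torch
+
+# Illustrative shapes: batch size 2, embedding_dim 4.
+dense_out = torch.rand(2, 4)                       # output of the bottom MLP
+embeddings = [torch.rand(2, 4), torch.rand(2, 4)]  # one pooled embedding per sparse feature
+
+features = [dense_out] + embeddings
+pairwise = [
+    torch.sum(features[i] * features[j], dim=1, keepdim=True)  # (2, 1) dot product per pair
+    for i in range(len(features))
+    for j in range(i + 1, len(features))
+]
+
+# Input to the top MLP: the dense representation plus all pairwise interactions.
+interaction_output = torch.cat([dense_out] + pairwise, dim=1)  # shape: (2, 4 + 3)
+```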
+
+## Key TorchRec Components Used
+
+1. **KeyedJaggedTensor**: Efficiently represents sparse features with variable lengths
+2. **EmbeddingBagConfig**: Configures embedding tables with parameters like dimensions and feature names
+3. **EmbeddingBagCollection**: Manages multiple embedding tables for different categorical features
+
+## Troubleshooting
+
+If you encounter any issues:
+
+1. **Python version**: This code has been tested with Python 3.8+. Make sure you're using a compatible version.
+
+2. **PyTorch and TorchRec installation**: If you have issues with PyTorch or TorchRec, try installing specific versions:
+ ```bash
+ pip install torch==2.0.0 torchvision==0.15.0
+ pip install torchrec==0.5.0
+ ```
+
+3. **Memory issues**: If you run out of memory, try reducing the batch size by modifying this line in the code:
+ ```python
+ batch_size = 256 # Try a smaller value like 64 or 32
+ ```
+
+4. **CPU vs GPU**: The code automatically uses CUDA if available. To force CPU usage, modify:
+ ```python
+ device = torch.device("cpu")
+ ```
+
+5. **TorchRec compatibility**: If you encounter compatibility issues with TorchRec, make sure you're using compatible versions of PyTorch and TorchRec.
diff --git a/examples/prediction/__init__.py b/examples/prediction/__init__.py
new file mode 100644
index 000000000..ea13082f2
--- /dev/null
+++ b/examples/prediction/__init__.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+
+def main() -> None:
+ """DOC_STRING"""
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/prediction/predict_using_torchrec.py b/examples/prediction/predict_using_torchrec.py
new file mode 100644
index 000000000..ba212b25b
--- /dev/null
+++ b/examples/prediction/predict_using_torchrec.py
@@ -0,0 +1,699 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import os
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, Dataset
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_modules import EmbeddingBagCollection
+
+# TorchRec imports
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+
+
+# Mock dataset for demonstration.
+class RecommendationDataset(Dataset):
+ """
+ A PyTorch Dataset class for generating random user-item interaction data
+ for recommendation systems.
+
+ Attributes:
+ num_samples (int): Number of samples in the dataset.
+ num_users (int): Number of unique users.
+ num_items (int): Number of unique items.
+ user_ids (torch.Tensor): Tensor of user IDs for each sample.
+ item_ids (torch.Tensor): Tensor of item IDs for each sample.
+ ratings (torch.Tensor): Tensor of ratings for each sample.
+ user_categories (torch.Tensor): Tensor of user categories for each sample.
+ item_categories (torch.Tensor): Tensor of item categories for each sample.
+ dense_features (torch.Tensor): Tensor of dense features for each sample.
+ """
+
+ def __init__(
+ self, num_users: int = 1000, num_items: int = 500, num_samples: int = 10000
+ ) -> None:
+ """
+ Initializes the RecommendationDataset with random data.
+
+ Args:
+ num_users (int): Number of unique users.
+ num_items (int): Number of unique items.
+ num_samples (int): Number of samples to generate.
+ """
+ self.num_samples: int = num_samples
+ self.num_users: int = num_users
+ self.num_items: int = num_items
+
+ # Generate random user-item interactions
+ self.user_ids: torch.Tensor = torch.randint(0, num_users, (num_samples,))
+ self.item_ids: torch.Tensor = torch.randint(0, num_items, (num_samples,))
+
+ # Generate random ratings (0-5)
+ self.ratings: torch.Tensor = torch.randint(0, 6, (num_samples,)).float()
+
+ # Generate some categorical features
+ self.user_categories: torch.Tensor = torch.randint(
+ 0, 10, (num_samples,)
+ ) # 10 user categories
+ self.item_categories: torch.Tensor = torch.randint(
+ 0, 20, (num_samples,)
+ ) # 20 item categories
+
+ # Generate dense features (normalized to [0, 1])
+ self.dense_features: torch.Tensor = torch.rand(
+ num_samples, 4
+ ) # 4 dense features
+
+ def __len__(self) -> int:
+ """
+ Returns the number of samples in the dataset.
+
+ Returns:
+ int: Number of samples.
+ """
+ return self.num_samples
+
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+ """
+ Retrieves a sample from the dataset at the specified index.
+
+ Args:
+ idx (int): Index of the sample to retrieve.
+
+ Returns:
+ dict: A dictionary containing user_id, item_id, user_category,
+ item_category, dense_features, and rating for the sample.
+ """
+ return {
+ "user_id": self.user_ids[idx],
+ "item_id": self.item_ids[idx],
+ "user_category": self.user_categories[idx],
+ "item_category": self.item_categories[idx],
+ "dense_features": self.dense_features[idx],
+ "rating": self.ratings[idx],
+ }
+
+
+# TorchRec DLRM model using KeyedJaggedTensor and EmbeddingBagCollection
+class TorchRecDLRM(nn.Module):
+ """
+ A DLRM model implementation using TorchRec's EmbeddingBagCollection and KeyedJaggedTensor.
+ This model follows the architecture described in the DLRM paper:
+ https://arxiv.org/abs/1906.00091
+
+ Args:
+ embedding_bag_collection: EmbeddingBagCollection for sparse features
+ dense_in_features: Number of dense features
+ dense_arch_layer_sizes: Layer sizes for the dense (bottom) MLP
+ over_arch_layer_sizes: Layer sizes for the over (top) MLP
+
+ Example:
+ ```
+ # Create embedding bag configs
+ eb_configs = [
+ EmbeddingBagConfig(
+ name="user_id",
+ embedding_dim=64,
+ num_embeddings=1000,
+ feature_names=["user_id"],
+ ),
+ EmbeddingBagConfig(
+ name="item_id",
+ embedding_dim=64,
+ num_embeddings=500,
+ feature_names=["item_id"],
+ ),
+ ]
+
+ # Create EmbeddingBagCollection
+ ebc = EmbeddingBagCollection(
+ tables=eb_configs,
+ device=torch.device("cpu"),
+ )
+
+ # Create TorchRecDLRM model
+ model = TorchRecDLRM(
+ embedding_bag_collection=ebc,
+ dense_in_features=4,
+ dense_arch_layer_sizes=[32, 64],
+ over_arch_layer_sizes=[128, 64, 1],
+ )
+
+ # Forward pass
+ batch_size = 2
+ dense_features = torch.rand(batch_size, 4)
+
+ # Create KeyedJaggedTensor for sparse features
+ values = torch.tensor([0, 1, 2, 3])
+ lengths = torch.ones(4, dtype=torch.int32)
+ sparse_features = KeyedJaggedTensor(
+ keys=["user_id", "item_id"],
+ values=values,
+ lengths=lengths,
+ )
+
+ # Get predictions
+ logits = model(dense_features, sparse_features)
+ ```
+ """
+
+ def __init__(
+ self,
+ embedding_bag_collection: EmbeddingBagCollection,
+ dense_in_features: int = 4,
+ dense_arch_layer_sizes: Optional[List[int]] = None,
+ over_arch_layer_sizes: Optional[List[int]] = None,
+ ) -> None:
+ super().__init__()
+
+ if dense_arch_layer_sizes is None:
+ dense_arch_layer_sizes = [32, 64]
+
+ if over_arch_layer_sizes is None:
+ over_arch_layer_sizes = [128, 64, 1]
+
+ self.embedding_bag_collection: EmbeddingBagCollection = embedding_bag_collection
+
+ # Get embedding dimension from the first embedding table
+ embedding_dim = self.embedding_bag_collection.embedding_bag_configs()[
+ 0
+ ].embedding_dim
+
+ # Dense arch (bottom MLP)
+ layers: List[nn.Module] = []
+ input_dim = dense_in_features
+ for output_dim in dense_arch_layer_sizes:
+ layers.append(nn.Linear(input_dim, output_dim))
+ layers.append(nn.ReLU())
+ input_dim = output_dim
+ self.dense_arch: nn.Sequential = nn.Sequential(*layers)
+
+ # Feature interaction: dot product of all pairs
+ num_sparse_features = len(self.embedding_bag_collection.embedding_bag_configs())
+ num_interactions = num_sparse_features + 1 # +1 for dense features
+ num_pairs = (num_interactions * (num_interactions - 1)) // 2
+
+ # Over arch (top MLP)
+ over_input_dim = embedding_dim + num_pairs
+ over_layers: List[nn.Module] = []
+ input_dim = over_input_dim
+ for i, output_dim in enumerate(over_arch_layer_sizes):
+ over_layers.append(nn.Linear(input_dim, output_dim))
+ if i < len(over_arch_layer_sizes) - 1:
+ over_layers.append(nn.ReLU())
+ input_dim = output_dim
+ self.over_arch: nn.Sequential = nn.Sequential(*over_layers)
+
+ def forward(
+ self, dense_features: torch.Tensor, sparse_features: KeyedJaggedTensor
+ ) -> torch.Tensor:
+ """
+ Forward pass of the DLRM model.
+
+ Args:
+ dense_features: Dense input features
+ sparse_features: Sparse input features as KeyedJaggedTensor
+
+ Returns:
+ torch.Tensor: Model output logits
+ """
+ # Process dense features
+ dense_output = self.dense_arch(dense_features)
+
+ # Process sparse features
+ sparse_output = self.embedding_bag_collection(sparse_features)
+
+ # Get embeddings as a list
+ embeddings = [sparse_output[f] for f in sparse_output.keys()]
+
+ # Feature interaction
+ all_features = [dense_output] + embeddings
+ interactions = []
+
+ # Add original dense output
+ interactions.append(dense_output)
+
+ # Compute pairwise dot products
+ for i in range(len(all_features)):
+ for j in range(i + 1, len(all_features)):
+ dot_product = torch.sum(
+ all_features[i] * all_features[j], dim=1, keepdim=True
+ )
+ interactions.append(dot_product)
+
+ # Concatenate all interactions
+ interaction_output = torch.cat(interactions, dim=1)
+
+ # Over arch
+ logits = self.over_arch(interaction_output)
+
+ return logits
+
+
+# DLRM wrapper for rating prediction
+class DLRMRatingWrapper(nn.Module):
+ """
+ Wrapper for DLRM model to scale the output to [0, 5] for rating prediction.
+
+ Args:
+ dlrm_model: The DLRM model to wrap
+
+ Example:
+ ```
+ # Create embedding bag configs
+ eb_configs = [
+ EmbeddingBagConfig(
+ name="user_id",
+ embedding_dim=64,
+ num_embeddings=1000,
+ feature_names=["user_id"],
+ ),
+ EmbeddingBagConfig(
+ name="item_id",
+ embedding_dim=64,
+ num_embeddings=500,
+ feature_names=["item_id"],
+ ),
+ ]
+
+ # Create EmbeddingBagCollection
+ ebc = EmbeddingBagCollection(
+ tables=eb_configs,
+ device=torch.device("cpu"),
+ )
+
+ # Create base model
+ base_model = TorchRecDLRM(
+ embedding_bag_collection=ebc,
+ dense_in_features=4,
+ )
+
+ # Create wrapper
+ model_wrapper = DLRMRatingWrapper(base_model)
+
+ # Forward pass
+ batch_size = 2
+ dense_features = torch.rand(batch_size, 4)
+
+ # Create KeyedJaggedTensor for sparse features
+ values = torch.tensor([0, 1, 2, 3])
+ lengths = torch.ones(4, dtype=torch.int32)
+ sparse_features = KeyedJaggedTensor(
+ keys=["user_id", "item_id"],
+ values=values,
+ lengths=lengths,
+ )
+
+ # Get predictions (scaled to 0-5 range)
+ predictions = model_wrapper(dense_features, sparse_features)
+ ```
+ """
+
+ def __init__(self, dlrm_model: nn.Module) -> None:
+ super().__init__()
+ self.model: nn.Module = dlrm_model
+ self.sigmoid: nn.Sigmoid = nn.Sigmoid()
+
+ def forward(
+ self, dense_features: torch.Tensor, sparse_features: KeyedJaggedTensor
+ ) -> torch.Tensor:
+ """
+ Forward pass of the DLRM wrapper.
+
+ Args:
+ dense_features: Dense input features
+ sparse_features: Sparse input features as KeyedJaggedTensor
+
+ Returns:
+ torch.Tensor: Rating prediction scaled to [0, 5]
+ """
+ logits = self.model(dense_features, sparse_features)
+ # Scale output to [0, 5] for rating prediction
+ return self.sigmoid(logits.squeeze()) * 5.0
+
+
+def create_kjt_from_batch(
+ batch: Dict[str, torch.Tensor], device: torch.device
+) -> KeyedJaggedTensor:
+ """
+ Create a KeyedJaggedTensor from a batch of data.
+
+ Args:
+ batch: Batch of data containing categorical features
+ device: Device to place the KeyedJaggedTensor on
+
+ Returns:
+ KeyedJaggedTensor: Sparse features in KeyedJaggedTensor format
+ """
+ # For this example, each categorical feature has exactly one value per sample
+ # So lengths are all 1s
+ batch_size = batch["user_id"].size(0)
+ lengths = torch.ones(batch_size * 4, dtype=torch.int32, device=device)
+
+ # Concatenate all values
+ values = torch.cat(
+ [
+ batch["user_id"],
+ batch["item_id"],
+ batch["user_category"],
+ batch["item_category"],
+ ]
+ ).to(device)
+
+ # Create KeyedJaggedTensor
+ return KeyedJaggedTensor(
+ keys=["user_id", "item_id", "user_category", "item_category"],
+ values=values,
+ lengths=lengths,
+ )
+
+
+def train_dlrm_model() -> Tuple[DLRMRatingWrapper, str]:
+ """
+ Trains the Deep Learning Recommendation Model (DLRM) using a specified dataset and hyperparameters.
+
+ Returns:
+ tuple: A tuple containing the trained DLRM model and the model filename.
+ """
+ # Set device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Training DLRM on device: {device}")
+
+ # Hyperparameters
+ num_users = 1000
+ num_items = 500
+ num_user_categories = 10
+ num_item_categories = 20
+ embedding_dim = 64
+ batch_size = 256
+ learning_rate = 0.001
+ num_epochs = 10
+
+ # Create dataset and dataloader
+ dataset = RecommendationDataset(num_users, num_items, 10000)
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+ # Create embedding bag configs
+ eb_configs = [
+ EmbeddingBagConfig(
+ name="user_id",
+ embedding_dim=embedding_dim,
+ num_embeddings=num_users,
+ feature_names=["user_id"],
+ ),
+ EmbeddingBagConfig(
+ name="item_id",
+ embedding_dim=embedding_dim,
+ num_embeddings=num_items,
+ feature_names=["item_id"],
+ ),
+ EmbeddingBagConfig(
+ name="user_category",
+ embedding_dim=embedding_dim,
+ num_embeddings=num_user_categories,
+ feature_names=["user_category"],
+ ),
+ EmbeddingBagConfig(
+ name="item_category",
+ embedding_dim=embedding_dim,
+ num_embeddings=num_item_categories,
+ feature_names=["item_category"],
+ ),
+ ]
+
+ # Create EmbeddingBagCollection
+ ebc = EmbeddingBagCollection(
+ tables=eb_configs,
+ device=device,
+ )
+
+ # Create TorchRecDLRM model
+ model = TorchRecDLRM(
+ embedding_bag_collection=ebc,
+ dense_in_features=4,
+ dense_arch_layer_sizes=[32, embedding_dim],
+ over_arch_layer_sizes=[128, 64, 1],
+ ).to(device)
+
+ # Create a wrapper to scale the output to [0, 5] for rating prediction
+ model_wrapper = DLRMRatingWrapper(model).to(device)
+
+ # Loss function and optimizer
+ criterion = nn.MSELoss()
+ optimizer = optim.Adam(
+ model_wrapper.parameters(), lr=learning_rate, weight_decay=1e-5
+ )
+ scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
+
+ print(
+ f"DLRM Model parameters: {sum(p.numel() for p in model_wrapper.parameters()):,}"
+ )
+
+ # Training loop
+ model_wrapper.train()
+ for epoch in range(num_epochs):
+ total_loss = 0
+ num_batches = 0
+
+ for batch in dataloader:
+ # Move data to device
+ dense_features = batch["dense_features"].to(device)
+ ratings = batch["rating"].to(device)
+
+ # Create KeyedJaggedTensor for sparse features
+ sparse_features = create_kjt_from_batch(batch, device)
+
+ # Forward pass
+ predictions = model_wrapper(dense_features, sparse_features)
+ loss = criterion(predictions, ratings)
+
+ # Backward pass
+ optimizer.zero_grad()
+ loss.backward()
+
+ # Gradient clipping
+ torch.nn.utils.clip_grad_norm_(model_wrapper.parameters(), max_norm=1.0)
+
+ optimizer.step()
+
+ total_loss += loss.item()
+ num_batches += 1
+
+ # Update learning rate
+ scheduler.step()
+
+ avg_loss = total_loss / num_batches
+ current_lr = optimizer.param_groups[0]["lr"]
+ print(
+ f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, LR: {current_lr:.6f}"
+ )
+
+ print("DLRM Training completed!")
+
+ # Save the model
+ model_filename = "dlrm_model.pth"
+ torch.save(
+ {
+ "model_state_dict": model_wrapper.state_dict(),
+ "embedding_dim": embedding_dim,
+ "num_users": num_users,
+ "num_items": num_items,
+ "num_user_categories": num_user_categories,
+ "num_item_categories": num_item_categories,
+ "dense_in_features": 4,
+ "dense_arch_layer_sizes": [32, embedding_dim],
+ "over_arch_layer_sizes": [128, 64, 1],
+ },
+ model_filename,
+ )
+ print(f"DLRM Model saved as {model_filename}")
+ return model_wrapper, model_filename
+
+
+def evaluate_dlrm_model(
+ model: nn.Module,
+ dataloader: DataLoader[RecommendationDataset],
+ device: torch.device,
+) -> Tuple[float, float]:
+ """
+ Evaluate the DLRM model on a dataset.
+
+ Args:
+ model (nn.Module): The DLRM model to evaluate.
+ dataloader (DataLoader): DataLoader providing the evaluation dataset.
+ device (torch.device): The device to perform evaluation on (CPU or GPU).
+
+ Returns:
+ tuple: A tuple containing the average loss and root mean square error (RMSE).
+ """
+ model.eval()
+ total_loss = 0
+ num_batches = 0
+ criterion = nn.MSELoss()
+
+ with torch.no_grad():
+ for batch in dataloader:
+ # Move data to device
+ dense_features = batch["dense_features"].to(device)
+ ratings = batch["rating"].to(device)
+
+ # Create KeyedJaggedTensor for sparse features
+ sparse_features = create_kjt_from_batch(batch, device)
+
+ # Forward pass
+ predictions = model(dense_features, sparse_features)
+ loss = criterion(predictions, ratings)
+
+ total_loss += loss.item()
+ num_batches += 1
+
+ avg_loss = total_loss / num_batches
+ rmse = np.sqrt(avg_loss)
+ print(f"DLRM Evaluation - MSE: {avg_loss:.4f}, RMSE: {rmse:.4f}")
+ return avg_loss, rmse
+
+
+def make_dlrm_predictions(
+ model: DLRMRatingWrapper,
+ user_id: int,
+ item_ids: List[int],
+ user_category: int,
+ item_categories: List[int],
+ device: torch.device,
+) -> np.ndarray:
+ """Make predictions using DLRM for a user on multiple items"""
+ model.eval()
+ with torch.no_grad():
+ batch_size = len(item_ids)
+
+ # Prepare inputs
+ user_ids_tensor = torch.tensor([user_id] * batch_size).to(device)
+ item_ids_tensor = torch.tensor(item_ids).to(device)
+ user_cats = torch.tensor([user_category] * batch_size).to(device)
+ item_cats = torch.tensor(item_categories).to(device)
+
+ # Generate random dense features for demonstration
+ dense_features = torch.rand(batch_size, 4).to(device)
+
+ # Create KeyedJaggedTensor for sparse features
+ # For this example, each categorical feature has exactly one value per sample
+ lengths = torch.ones(batch_size * 4, dtype=torch.int32, device=device)
+
+ # Concatenate all values
+ values = torch.cat(
+ [
+ user_ids_tensor,
+ item_ids_tensor,
+ user_cats,
+ item_cats,
+ ]
+ )
+
+ # Create KeyedJaggedTensor
+ sparse_features = KeyedJaggedTensor(
+ keys=["user_id", "item_id", "user_category", "item_category"],
+ values=values,
+ lengths=lengths,
+ )
+
+ # Make predictions
+ predictions = model(dense_features, sparse_features)
+
+ # Convert to numpy array and ensure it's a 1D array
+ numpy_predictions = predictions.cpu().numpy()
+ # Flatten in case it's not already 1D
+ return numpy_predictions.flatten()
+
+
+def remove_model_file(model_filename: str) -> None:
+ """
+ Removes the model file if it exists.
+
+ Args:
+ model_filename (str): The filename of the model to be removed.
+ """
+ if os.path.exists(model_filename):
+ try:
+ os.remove(model_filename)
+ print(f"Successfully removed the file: {model_filename}")
+ except PermissionError:
+ print(f"Permission denied: {model_filename}")
+ except Exception as e:
+ print(f"An error occurred while trying to remove the file: {e}")
+ else:
+ print(f"File does not exist: {model_filename}")
+
+
+def main() -> None:
+ """
+ Main function to orchestrate the training, evaluation, and prediction
+ processes of the Deep Learning Recommendation Model (DLRM).
+
+ This function performs the following steps:
+ 1. Trains the DLRM model using a specified dataset and hyperparameters.
+ 2. Evaluates the trained model on a separate evaluation dataset.
+ 3. Makes sample predictions for a specific user on multiple items.
+ 4. Cleans up the model after training to free up resources.
+ """
+ print("Starting DLRM (Deep Learning Recommendation Model) Training...")
+
+ try:
+ # Train the DLRM model
+ print("Starting training...")
+ trained_model, model_filename = train_dlrm_model()
+ print("Training completed successfully!")
+ except Exception as e:
+ print(f"Error during training: {e}")
+ import traceback
+
+ traceback.print_exc()
+ return
+
+ # Create evaluation dataset
+ eval_dataset = RecommendationDataset(1000, 500, 2000)
+ eval_dataloader = DataLoader(eval_dataset, batch_size=256, shuffle=False)
+
+ # Evaluate the model
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print("\nEvaluating DLRM model...")
+ evaluate_dlrm_model(trained_model, eval_dataloader, device)
+
+ # Example prediction
+ print("\nMaking sample DLRM predictions...")
+ sample_user_id = 42
+ sample_items = [10, 25, 50, 100, 200]
+ sample_user_cat = 3
+ sample_item_cats = [5, 12, 8, 15, 2]
+
+ predictions = make_dlrm_predictions(
+ trained_model,
+ sample_user_id,
+ sample_items,
+ sample_user_cat,
+ sample_item_cats,
+ device,
+ )
+
+ print(f"DLRM Predictions for user {sample_user_id}:")
+ for item_id, pred in zip(sample_items, predictions):
+ print(f" Item {item_id}: {pred:.2f}")
+
+ # Clean up the saved model file after training
+ print(f"Cleaning the model {model_filename}")
+ # Comment out the next line if you want to keep the model file
+ remove_model_file(model_filename)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/prediction/test_predict_using_torchrec.py b/examples/prediction/test_predict_using_torchrec.py
new file mode 100644
index 000000000..e8c12467d
--- /dev/null
+++ b/examples/prediction/test_predict_using_torchrec.py
@@ -0,0 +1,320 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import os
+import sys
+import unittest
+
+import torch
+
+# Add the current directory to sys.path
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+# Import using the full module path
+from torchrec.github.examples.prediction.predict_using_torchrec import (
+ create_kjt_from_batch,
+ DLRMRatingWrapper,
+ RecommendationDataset,
+ TorchRecDLRM,
+)
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_modules import EmbeddingBagCollection
+
+# TorchRec imports
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+
+
+class TestRecommendationDataset(unittest.TestCase):
+ """Test cases for the RecommendationDataset class."""
+
+ def setUp(self) -> None:
+ """Set up test fixtures."""
+ self.num_users = 100
+ self.num_items = 50
+ self.num_samples = 1000
+ self.dataset = RecommendationDataset(
+ num_users=self.num_users,
+ num_items=self.num_items,
+ num_samples=self.num_samples,
+ )
+
+ def test_init(self) -> None:
+ """Test initialization of RecommendationDataset."""
+ self.assertEqual(self.dataset.num_users, self.num_users)
+ self.assertEqual(self.dataset.num_items, self.num_items)
+ self.assertEqual(self.dataset.num_samples, self.num_samples)
+
+ # Check tensor shapes
+ self.assertEqual(self.dataset.user_ids.shape, (self.num_samples,))
+ self.assertEqual(self.dataset.item_ids.shape, (self.num_samples,))
+ self.assertEqual(self.dataset.ratings.shape, (self.num_samples,))
+ self.assertEqual(self.dataset.user_categories.shape, (self.num_samples,))
+ self.assertEqual(self.dataset.item_categories.shape, (self.num_samples,))
+ self.assertEqual(self.dataset.dense_features.shape, (self.num_samples, 4))
+
+ def test_len(self) -> None:
+ """Test __len__ method."""
+ self.assertEqual(len(self.dataset), self.num_samples)
+
+ def test_getitem(self) -> None:
+ """Test __getitem__ method."""
+ item = self.dataset[0]
+ self.assertIsInstance(item, dict)
+ self.assertIn("user_id", item)
+ self.assertIn("item_id", item)
+ self.assertIn("user_category", item)
+ self.assertIn("item_category", item)
+ self.assertIn("dense_features", item)
+ self.assertIn("rating", item)
+
+ # Check tensor shapes for a single item
+ self.assertEqual(item["dense_features"].shape, (4,))
+ self.assertEqual(item["user_category"].shape, ())
+
+
+class TestTorchRecDLRM(unittest.TestCase):
+ """Test cases for the TorchRecDLRM class."""
+
+ def setUp(self) -> None:
+ """Set up test fixtures."""
+ self.embedding_dim = 32
+ self.dense_in_features = 4
+ self.dense_arch_layer_sizes = [16, self.embedding_dim]
+ self.over_arch_layer_sizes = [64, 32, 1]
+
+ # Create embedding bag configs
+ self.num_users = 100
+ self.num_items = 50
+ self.num_user_categories = 10
+ self.num_item_categories = 20
+
+ self.eb_configs = [
+ EmbeddingBagConfig(
+ name="user_id",
+ embedding_dim=self.embedding_dim,
+ num_embeddings=self.num_users,
+ feature_names=["user_id"],
+ ),
+ EmbeddingBagConfig(
+ name="item_id",
+ embedding_dim=self.embedding_dim,
+ num_embeddings=self.num_items,
+ feature_names=["item_id"],
+ ),
+ EmbeddingBagConfig(
+ name="user_category",
+ embedding_dim=self.embedding_dim,
+ num_embeddings=self.num_user_categories,
+ feature_names=["user_category"],
+ ),
+ EmbeddingBagConfig(
+ name="item_category",
+ embedding_dim=self.embedding_dim,
+ num_embeddings=self.num_item_categories,
+ feature_names=["item_category"],
+ ),
+ ]
+
+ # Create EmbeddingBagCollection
+ self.device = torch.device("cpu")
+ self.ebc = EmbeddingBagCollection(
+ tables=self.eb_configs,
+ device=self.device,
+ )
+
+ # Create model
+ self.model = TorchRecDLRM(
+ embedding_bag_collection=self.ebc,
+ dense_in_features=self.dense_in_features,
+ dense_arch_layer_sizes=self.dense_arch_layer_sizes,
+ over_arch_layer_sizes=self.over_arch_layer_sizes,
+ )
+
+ # Test data
+ self.batch_size = 8
+ self.dense_features = torch.rand(self.batch_size, self.dense_in_features)
+
+ # Create batch data
+ self.batch = {
+ "user_id": torch.randint(0, self.num_users, (self.batch_size,)),
+ "item_id": torch.randint(0, self.num_items, (self.batch_size,)),
+ "user_category": torch.randint(
+ 0, self.num_user_categories, (self.batch_size,)
+ ),
+ "item_category": torch.randint(
+ 0, self.num_item_categories, (self.batch_size,)
+ ),
+ }
+
+ # Create KeyedJaggedTensor for sparse features
+ self.sparse_features = create_kjt_from_batch(self.batch, self.device)
+
+ def test_init(self) -> None:
+ """Test initialization of TorchRecDLRM."""
+ self.assertIsInstance(
+ self.model.embedding_bag_collection, EmbeddingBagCollection
+ )
+ self.assertEqual(
+ len(self.model.embedding_bag_collection.embedding_bag_configs()), 4
+ )
+
+ # Check embedding dimensions
+ for config in self.model.embedding_bag_collection.embedding_bag_configs():
+ self.assertEqual(config.embedding_dim, self.embedding_dim)
+
+ def test_forward(self) -> None:
+ """Test forward pass of TorchRecDLRM."""
+ # Run forward pass
+ output = self.model(self.dense_features, self.sparse_features)
+
+ # Check output shape
+ self.assertEqual(output.shape, (self.batch_size, 1))
+
+ # Check output is not NaN
+ self.assertFalse(torch.isnan(output).any())
+
+
+class TestDLRMRatingWrapper(unittest.TestCase):
+ """Test cases for the DLRMRatingWrapper class."""
+
+ def setUp(self) -> None:
+ """Set up test fixtures."""
+ self.embedding_dim = 32
+ self.dense_in_features = 4
+ self.device = torch.device("cpu")
+
+ # Create embedding bag configs
+ self.num_users = 100
+ self.num_items = 50
+ self.num_user_categories = 10
+ self.num_item_categories = 20
+
+ self.eb_configs = [
+ EmbeddingBagConfig(
+ name="user_id",
+ embedding_dim=self.embedding_dim,
+ num_embeddings=self.num_users,
+ feature_names=["user_id"],
+ ),
+ EmbeddingBagConfig(
+ name="item_id",
+ embedding_dim=self.embedding_dim,
+ num_embeddings=self.num_items,
+ feature_names=["item_id"],
+ ),
+ EmbeddingBagConfig(
+ name="user_category",
+ embedding_dim=self.embedding_dim,
+ num_embeddings=self.num_user_categories,
+ feature_names=["user_category"],
+ ),
+ EmbeddingBagConfig(
+ name="item_category",
+ embedding_dim=self.embedding_dim,
+ num_embeddings=self.num_item_categories,
+ feature_names=["item_category"],
+ ),
+ ]
+
+ # Create EmbeddingBagCollection
+ self.ebc = EmbeddingBagCollection(
+ tables=self.eb_configs,
+ device=self.device,
+ )
+
+ # Create base model with the same architecture as in TestTorchRecDLRM
+ self.base_model = TorchRecDLRM(
+ embedding_bag_collection=self.ebc,
+ dense_in_features=self.dense_in_features,
+ dense_arch_layer_sizes=[16, self.embedding_dim],
+ over_arch_layer_sizes=[64, 32, 1],
+ )
+
+ # Create wrapper
+ self.model_wrapper = DLRMRatingWrapper(self.base_model).to(self.device)
+
+ # Test data
+ self.batch_size = 8
+ self.dense_features = torch.rand(self.batch_size, self.dense_in_features)
+
+ # Create batch data
+ self.batch = {
+ "user_id": torch.randint(0, self.num_users, (self.batch_size,)),
+ "item_id": torch.randint(0, self.num_items, (self.batch_size,)),
+ "user_category": torch.randint(
+ 0, self.num_user_categories, (self.batch_size,)
+ ),
+ "item_category": torch.randint(
+ 0, self.num_item_categories, (self.batch_size,)
+ ),
+ }
+
+ # Create KeyedJaggedTensor for sparse features
+ self.sparse_features = create_kjt_from_batch(self.batch, self.device)
+
+ def test_forward(self) -> None:
+ """Test forward pass of DLRMRatingWrapper."""
+ # Run forward pass
+ output = self.model_wrapper(self.dense_features, self.sparse_features)
+
+ # Check output has the correct batch size
+ self.assertEqual(output.numel(), self.batch_size)
+
+ # Ensure output is 1D or 2D with second dimension of 1
+ if output.dim() == 2:
+ self.assertEqual(output.shape[1], 1)
+ # Squeeze to make it 1D for further checks
+ output = output.squeeze()
+
+ # Check output is not NaN
+ self.assertFalse(torch.isnan(output).any())
+
+ # Since we're using random initialization, we can't guarantee exact output range
+ # Just check that the output is finite
+ self.assertTrue(torch.isfinite(output).all())
+
+
+class TestCreateKJTFromBatch(unittest.TestCase):
+ """Test cases for the create_kjt_from_batch function."""
+
+ def setUp(self) -> None:
+ """Set up test fixtures."""
+ self.batch_size = 8
+ self.device = torch.device("cpu")
+
+ # Create batch data
+ self.batch = {
+ "user_id": torch.randint(0, 100, (self.batch_size,)),
+ "item_id": torch.randint(0, 50, (self.batch_size,)),
+ "user_category": torch.randint(0, 10, (self.batch_size,)),
+ "item_category": torch.randint(0, 20, (self.batch_size,)),
+ }
+
+ def test_create_kjt_from_batch(self) -> None:
+ """Test create_kjt_from_batch function."""
+ kjt = create_kjt_from_batch(self.batch, self.device)
+
+ # Check that it's a KeyedJaggedTensor
+ self.assertIsInstance(kjt, KeyedJaggedTensor)
+
+ # Check keys - use set comparison to avoid order issues
+ self.assertEqual(
+ set(kjt.keys()), {"user_id", "item_id", "user_category", "item_category"}
+ )
+
+ # Check values length
+ self.assertEqual(kjt.values().shape[0], self.batch_size * 4)
+
+ # Check lengths
+ self.assertEqual(kjt.lengths().shape[0], self.batch_size * 4)
+ self.assertTrue((kjt.lengths() == 1).all())
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/examples/ray/train_torchrec.py b/examples/ray/train_torchrec.py
index 2406c6796..4ee5fdeb9 100644
--- a/examples/ray/train_torchrec.py
+++ b/examples/ray/train_torchrec.py
@@ -24,6 +24,7 @@
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
from tqdm import tqdm
@@ -110,7 +111,7 @@ def train(
# Overlap comm/compute/device transfer during training through train_pipeline
non_fused_optimizer = KeyedOptimizerWrapper(
- dict(model.named_parameters()),
+ dict(in_backward_optimizer_filter(model.named_parameters())),
lambda params: torch.optim.Adagrad(params, lr=learning_rate),
)
train_pipeline = TrainPipelineSparseDist(
diff --git a/examples/retrieval/data/dataloader.py b/examples/retrieval/data/dataloader.py
index 5727910c7..39e2ca1c4 100644
--- a/examples/retrieval/data/dataloader.py
+++ b/examples/retrieval/data/dataloader.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
from torch.utils.data import DataLoader
from torchrec.datasets.movielens import DEFAULT_RATINGS_COLUMN_NAMES
from torchrec.datasets.random import RandomRecDataset
diff --git a/examples/retrieval/knn_index.py b/examples/retrieval/knn_index.py
index a1323f1ca..9db4a6d7d 100644
--- a/examples/retrieval/knn_index.py
+++ b/examples/retrieval/knn_index.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
from typing import Optional, Union
import faiss # @manual=//faiss/python:pyfaiss_gpu
diff --git a/examples/retrieval/modules/tests/test_two_tower.py b/examples/retrieval/modules/tests/test_two_tower.py
index 9ac08b10f..237ae7ee8 100644
--- a/examples/retrieval/modules/tests/test_two_tower.py
+++ b/examples/retrieval/modules/tests/test_two_tower.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
#!/usr/bin/env python3
import unittest
diff --git a/examples/retrieval/modules/two_tower.py b/examples/retrieval/modules/two_tower.py
index 5ee2b22af..704e02b63 100644
--- a/examples/retrieval/modules/two_tower.py
+++ b/examples/retrieval/modules/two_tower.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
from typing import Any, List, Mapping, Optional, OrderedDict, Tuple, Union
import faiss # @manual=//faiss/python:pyfaiss_gpu
@@ -71,12 +73,12 @@ def __init__(
embedding_dim: int = embedding_bag_collection.embedding_bag_configs()[
0
].embedding_dim
- self._feature_names_query: List[
- str
- ] = embedding_bag_collection.embedding_bag_configs()[0].feature_names
- self._candidate_feature_names: List[
- str
- ] = embedding_bag_collection.embedding_bag_configs()[1].feature_names
+ self._feature_names_query: List[str] = (
+ embedding_bag_collection.embedding_bag_configs()[0].feature_names
+ )
+ self._candidate_feature_names: List[str] = (
+ embedding_bag_collection.embedding_bag_configs()[1].feature_names
+ )
self.ebc = embedding_bag_collection
self.query_proj = MLP(
in_size=embedding_dim, layer_sizes=layer_sizes, device=device
@@ -174,6 +176,7 @@ def __init__(
layer_sizes: List[int],
k: int,
device: Optional[torch.device] = None,
+ dtype: torch.dtype = torch.float32,
) -> None:
super().__init__()
self.embedding_dim: int = query_ebc.embedding_bag_configs()[0].embedding_dim
@@ -186,10 +189,16 @@ def __init__(
self.query_ebc = query_ebc
self.candidate_ebc = candidate_ebc
self.query_proj = MLP(
- in_size=self.embedding_dim, layer_sizes=layer_sizes, device=device
+ in_size=self.embedding_dim,
+ layer_sizes=layer_sizes,
+ device=device,
+ dtype=dtype,
)
self.candidate_proj = MLP(
- in_size=self.embedding_dim, layer_sizes=layer_sizes, device=device
+ in_size=self.embedding_dim,
+ layer_sizes=layer_sizes,
+ device=device,
+ dtype=dtype,
)
self.faiss_index: Union[faiss.GpuIndexIVFPQ, faiss.IndexIVFPQ] = faiss_index
self.k = k
@@ -212,6 +221,7 @@ def forward(self, query_kjt: KeyedJaggedTensor) -> torch.Tensor:
candidates = torch.empty(
(batch_size, self.k), device=self.device, dtype=torch.int64
)
+ query_embedding = query_embedding.to(torch.float32) # required by faiss
self.faiss_index.search(query_embedding, self.k, distances, candidates)
# candidate lookup
@@ -227,5 +237,7 @@ def forward(self, query_kjt: KeyedJaggedTensor) -> torch.Tensor:
# return logit (dot product)
return (query_embedding * candidate_embedding).sum(dim=1).squeeze()
+ # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`
+ # inconsistently.
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool) -> None:
super().load_state_dict(state_dict, strict)
diff --git a/examples/retrieval/tests/test_two_tower_retrieval.py b/examples/retrieval/tests/test_two_tower_retrieval.py
index eef1ef455..77952e0fa 100644
--- a/examples/retrieval/tests/test_two_tower_retrieval.py
+++ b/examples/retrieval/tests/test_two_tower_retrieval.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import unittest
import torch
@@ -19,11 +21,12 @@ class InferTest(unittest.TestCase):
@skip_if_asan
# pyre-ignore[56]
@unittest.skipIf(
- not torch.cuda.is_available(),
- "this test requires a GPU",
+ torch.cuda.device_count() <= 1,
+ "Not enough GPUs, this test requires at least two GPUs",
)
def test_infer_function(self) -> None:
infer(
embedding_dim=16,
layer_sizes=[16],
+ world_size=2,
)
diff --git a/examples/retrieval/tests/test_two_tower_train.py b/examples/retrieval/tests/test_two_tower_train.py
index 2b78b8426..8db7684cf 100644
--- a/examples/retrieval/tests/test_two_tower_train.py
+++ b/examples/retrieval/tests/test_two_tower_train.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import os
import tempfile
import unittest
diff --git a/examples/retrieval/two_tower_retrieval.py b/examples/retrieval/two_tower_retrieval.py
index b1b4ccb49..08a2de4b6 100644
--- a/examples/retrieval/two_tower_retrieval.py
+++ b/examples/retrieval/two_tower_retrieval.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
from typing import List, Optional
import click
@@ -18,16 +20,13 @@
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingEnv, ShardingType
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
# OSS import
try:
- # pyre-ignore[21]
- # @manual=//torchrec/github/examples/retrieval/data:dataloader
- from data.dataloader import get_dataloader
# pyre-ignore[21]
# @manual=//torchrec/github/examples/retrieval:knn_index
@@ -78,6 +77,7 @@ def infer(
faiss_device_idx: int = 0,
batch_size: int = 32,
load_dir: Optional[str] = None,
+ world_size: int = 2,
) -> None:
"""
Loads the serialized model and FAISS index from `two_tower_train.py`.
@@ -116,6 +116,7 @@ def infer(
embedding_dim=embedding_dim,
num_embeddings=num_embeddings,
feature_names=[feature_name],
+ data_type=DataType.FP16,
)
ebcs.append(
EmbeddingBagCollection(
@@ -135,7 +136,7 @@ def infer(
# pyre-ignore[16]
faiss.read_index(f"{load_dir}/faiss.index"),
)
- two_tower_sd = torch.load(f"{load_dir}/model.pt")
+ two_tower_sd = torch.load(f"{load_dir}/model.pt", weights_only=True)
retrieval_sd = convert_TwoTower_to_TwoTowerRetrieval(
two_tower_sd,
[f"t_{two_tower_column_names[0]}"],
@@ -156,7 +157,9 @@ def infer(
index.train(embeddings)
index.add(embeddings)
- retrieval_model = TwoTowerRetrieval(index, ebcs[0], ebcs[1], layer_sizes, k, device)
+ retrieval_model = TwoTowerRetrieval(
+ index, ebcs[0], ebcs[1], layer_sizes, k, device, dtype=torch.float16
+ )
constraints = {}
for feature_name in two_tower_column_names:
@@ -166,13 +169,16 @@ def infer(
)
quant_model = trec_infer.modules.quantize_embeddings(
- retrieval_model, dtype=torch.qint8, inplace=True
+ retrieval_model,
+ dtype=torch.qint8,
+ inplace=True,
+ output_dtype=torch.float16,
)
dmp = DistributedModelParallel(
module=quant_model,
device=device,
- env=ShardingEnv.from_local(world_size=2, rank=model_device_idx),
+ env=ShardingEnv.from_local(world_size=world_size, rank=model_device_idx),
init_data_parallel=False,
)
if retrieval_sd is not None:
diff --git a/examples/retrieval/two_tower_train.py b/examples/retrieval/two_tower_train.py
index 141c1ca7d..5772a9069 100644
--- a/examples/retrieval/two_tower_train.py
+++ b/examples/retrieval/two_tower_train.py
@@ -5,25 +5,26 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
#!/usr/bin/env python3
import os
-from typing import cast, List, Optional
+from typing import List, Optional
import click
import faiss # @manual=//faiss/python:pyfaiss_gpu
import faiss.contrib.torch_utils # @manual=//faiss/contrib:faiss_contrib_gpu
import torch
-from fbgemm_gpu.split_embedding_configs import EmbOptimType
-from torch import distributed as dist, nn
+from torch import distributed as dist
+from torch.distributed.optim import (
+ _apply_optimizer_in_backward as apply_optimizer_in_backward,
+)
from torchrec import inference as trec_infer
from torchrec.datasets.movielens import DEFAULT_RATINGS_COLUMN_NAMES
from torchrec.distributed import TrainPipelineSparseDist
-from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
-from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
-from torchrec.distributed.types import ModuleSharder
from torchrec.inference.state_dict_transform import (
state_dict_gather,
state_dict_to_device,
@@ -31,6 +32,7 @@
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -108,7 +110,6 @@ def train(
layer_sizes = [128, 64]
rank = int(os.environ["LOCAL_RANK"])
- world_size = int(os.environ["WORLD_SIZE"])
if torch.cuda.is_available():
device: torch.device = torch.device(f"cuda:{rank}")
backend = "nccl"
@@ -138,36 +139,14 @@ def train(
device=device,
)
two_tower_train_task = TwoTowerTrainTask(two_tower_model)
-
- fused_params = {
- "learning_rate": learning_rate,
- "optimizer": EmbOptimType.ROWWISE_ADAGRAD,
- }
- sharders = cast(
- List[ModuleSharder[nn.Module]],
- [EmbeddingBagCollectionSharder(fused_params=fused_params)],
- )
-
- # TODO: move pg to the EmbeddingShardingPlanner (out of collective_plan) and make optional
- # TODO: make Topology optional argument to EmbeddingShardingPlanner
- # TODO: give collective_plan a default sharders
- # TODO: once this is done, move defaults out of DMP and just get from ShardingPlan (eg _sharding_map should not exist - just use the plan)
- plan = EmbeddingShardingPlanner(
- topology=Topology(
- world_size=world_size,
- compute_device=device.type,
- ),
- ).collective_plan(
- module=two_tower_model,
- sharders=sharders,
- # pyre-fixme[6]: For 3rd param expected `ProcessGroup` but got
- # `Optional[ProcessGroup]`.
- pg=dist.GroupMember.WORLD,
+ apply_optimizer_in_backward(
+ RowWiseAdagrad,
+ two_tower_train_task.two_tower.ebc.parameters(),
+ {"lr": learning_rate},
)
model = DistributedModelParallel(
module=two_tower_train_task,
device=device,
- plan=plan,
)
optimizer = KeyedOptimizerWrapper(
diff --git a/examples/sharding/sharding.ipynb b/examples/sharding/sharding.ipynb
index b61321a54..805190a20 100644
--- a/examples/sharding/sharding.ipynb
+++ b/examples/sharding/sharding.ipynb
@@ -8,168 +8,239 @@
"source": [
"## **Installation**\n",
"Requirements:\n",
- "- python >= 3.7\n",
+ "- python >= 3.9\n",
+ "- a device 2 GPUs\n",
"\n",
"We highly recommend CUDA when using torchRec. If using CUDA:\n",
- "- cuda >= 11.0\n"
+ "- cuda >= 12.0\n"
]
},
{
"cell_type": "code",
- "source": [
- "# install conda to make installying pytorch with cudatoolkit 11.3 easier. \n",
- "!sudo rm Miniconda3-py37_4.9.2-Linux-x86_64.sh Miniconda3-py37_4.9.2-Linux-x86_64.sh.*\n",
- "!sudo wget https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh\n",
- "!sudo chmod +x Miniconda3-py37_4.9.2-Linux-x86_64.sh\n",
- "!sudo bash ./Miniconda3-py37_4.9.2-Linux-x86_64.sh -b -f -p /usr/local"
- ],
- "metadata": {
- "id": "BB2K68OYUJ_t"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {
- "id": "sFYvP95xaAER"
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Akmt4viCo9dz",
+ "outputId": "e008352d-85b6-4713-827f-7a3eeb9ad09b"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Looking in indexes: https://download.pytorch.org/whl/cu121\n",
+ "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.5.1+cu121)\n",
+ "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n",
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.13.2)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.6)\n",
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.3.1)\n",
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/dist-packages (from torch) (9.1.0.70)\n",
+ "Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.0)\n",
+ "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.10/dist-packages (from torch) (2.21.5)\n",
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch) (11.0.2.54)\n",
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.0.106)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.13.1)\n",
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch) (10.3.2.106)\n",
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch) (11.4.5.107)\n",
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.1.105)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (3.0.2)\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0mLooking in indexes: https://download.pytorch.org/whl/cu121\n",
+ "Requirement already satisfied: fbgemm_gpu in /usr/local/lib/python3.10/dist-packages (1.0.0+cu121)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from fbgemm_gpu) (2.2.5)\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0mRequirement already satisfied: torchmetrics in /usr/local/lib/python3.10/dist-packages (1.0.3)\n",
+ "Requirement already satisfied: lightning-utilities>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (0.14.3)\n",
+ "Requirement already satisfied: torch>=1.8.1 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (2.5.1+cu121)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (25.0)\n",
+ "Requirement already satisfied: numpy>1.20.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (2.2.5)\n",
+ "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from lightning-utilities>=0.7.0->torchmetrics) (4.13.2)\n",
+ "Requirement already satisfied: setuptools in /usr/lib/python3/dist-packages (from lightning-utilities>=0.7.0->torchmetrics) (59.6.0)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (3.13.1)\n",
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (2.21.5)\n",
+ "Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (3.1.0)\n",
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (12.1.105)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (2024.6.1)\n",
+ "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (9.1.0.70)\n",
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (12.1.3.1)\n",
+ "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (1.13.1)\n",
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (10.3.2.106)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (3.1.6)\n",
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (11.0.2.54)\n",
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (11.4.5.107)\n",
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (12.1.0.106)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics) (3.3)\n",
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.8.1->torchmetrics) (12.1.105)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=1.8.1->torchmetrics) (1.3.0)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.8.1->torchmetrics) (3.0.2)\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0mLooking in indexes: https://download.pytorch.org/whl/cu121\n",
+ "Requirement already satisfied: torchrec in /usr/local/lib/python3.10/dist-packages (1.0.0+cu121)\n",
+ "Requirement already satisfied: fbgemm-gpu in /usr/local/lib/python3.10/dist-packages (from torchrec) (1.0.0+cu121)\n",
+ "Requirement already satisfied: torchmetrics==1.0.3 in /usr/local/lib/python3.10/dist-packages (from torchrec) (1.0.3)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from torchrec) (4.66.5)\n",
+ "Requirement already satisfied: pyre-extensions in /usr/local/lib/python3.10/dist-packages (from torchrec) (0.0.31)\n",
+ "Requirement already satisfied: iopath in /usr/local/lib/python3.10/dist-packages (from torchrec) (0.1.9)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from torchmetrics==1.0.3->torchrec) (25.0)\n",
+ "Requirement already satisfied: torch>=1.8.1 in /usr/local/lib/python3.10/dist-packages (from torchmetrics==1.0.3->torchrec) (2.5.1+cu121)\n",
+ "Requirement already satisfied: numpy>1.20.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics==1.0.3->torchrec) (2.2.5)\n",
+ "Requirement already satisfied: lightning-utilities>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics==1.0.3->torchrec) (0.14.3)\n",
+ "Requirement already satisfied: portalocker in /usr/local/lib/python3.10/dist-packages (from iopath->torchrec) (2.10.1)\n",
+ "Requirement already satisfied: typing-inspect in /usr/local/lib/python3.10/dist-packages (from pyre-extensions->torchrec) (0.9.0)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from pyre-extensions->torchrec) (4.13.2)\n",
+ "Requirement already satisfied: setuptools in /usr/lib/python3/dist-packages (from lightning-utilities>=0.7.0->torchmetrics==1.0.3->torchrec) (59.6.0)\n",
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (11.4.5.107)\n",
+ "Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (3.1.0)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (3.13.1)\n",
+ "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (1.13.1)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (3.3)\n",
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (9.1.0.70)\n",
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (12.1.0.106)\n",
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (12.1.3.1)\n",
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (11.0.2.54)\n",
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (10.3.2.106)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (2024.6.1)\n",
+ "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (2.21.5)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8.1->torchmetrics==1.0.3->torchrec) (3.1.6)\n",
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.8.1->torchmetrics==1.0.3->torchrec) (12.1.105)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=1.8.1->torchmetrics==1.0.3->torchrec) (1.3.0)\n",
+ "Requirement already satisfied: mypy-extensions>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from typing-inspect->pyre-extensions->torchrec) (1.0.0)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.8.1->torchmetrics==1.0.3->torchrec) (3.0.2)\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
"source": [
- "# install pytorch with cudatoolkit 11.3\n",
- "!sudo conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y"
+ "!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U\n",
+ "!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121\n",
+ "!pip3 install torchmetrics\n",
+ "!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121"
]
},
- {
- "cell_type": "markdown",
- "source": [
- "Installing torchRec will also install [FBGEMM](https://github.com/pytorch/fbgemm), a collection of CUDA kernels and GPU enabled operations to run "
- ],
- "metadata": {
- "id": "7iY7Uv11mJYK"
- }
- },
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"metadata": {
- "id": "tUnIw-ZREQJy"
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "pkUJVDAA2dER",
+ "outputId": "2ffbbf14-b6b4-4687-9b0d-4bba85d8c316"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (0.70.18)\n",
+ "Requirement already satisfied: dill>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from multiprocess) (0.4.0)\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
"source": [
- "# install torchrec\n",
- "!pip3 install torchrec-nightly"
+ "!pip3 install multiprocess"
]
},
{
"cell_type": "markdown",
- "source": [
- "Install multiprocess which works with ipython to for multi-processing programming within colab"
- ],
- "metadata": {
- "id": "0wLX94Lw_Lml"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "!pip3 install multiprocess"
- ],
"metadata": {
- "id": "HKoKRP-QzRCF"
+ "id": "HWBOrwVSnrNE"
},
- "execution_count": null,
- "outputs": []
+ "source": [
+ "## **Overview**\n",
+ "This tutorial will mainly cover the sharding schemes of embedding tables via `EmbeddingPlanner` and `DistributedModelParallel` API and explore the benefits of different sharding schemes for the embedding tables by explicitly configuring them."
+ ]
},
{
"cell_type": "markdown",
"metadata": {
- "id": "b6EHgotRXFQh"
+ "id": "udsN6PlUo1zF"
},
"source": [
- "The following steps are needed for the Colab runtime to detect the added shared libraries. The runtime searches for shared libraries in /usr/lib, so we copy over the libraries which were installed in /usr/local/lib/. **This is a very necessary step, only in the colab runtime**. "
+ "### Distributed Setup\n",
+ "Due to the notebook enviroment, we cannot run [`SPMD`](https://en.wikipedia.org/wiki/SPMD) program here but we can do multiprocessing inside the notebook to mimic the setup. Users should be responsible for setting up their own [`SPMD`](https://en.wikipedia.org/wiki/SPMD) launcher when using Torchrec.\n",
+ "We setup our environment so that torch distributed based communication backend can work."
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
"metadata": {
- "id": "_P45pDteRcWj"
+ "id": "4-v17rxkopQw"
},
"outputs": [],
"source": [
- "!sudo cp /usr/local/lib/lib* /usr/lib/"
+ "import os\n",
+ "import copy\n",
+ "import torch\n",
+ "import torchrec\n",
+ "import multiprocess\n",
+ "from torchrec.distributed.types import ShardingEnv\n",
+ "\n",
+ "os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
+ "os.environ[\"MASTER_PORT\"] = \"10000\""
]
},
{
"cell_type": "markdown",
"metadata": {
- "id": "n5_X2WOAYG3c"
+ "id": "td6bvF_KRbzx"
},
"source": [
- "\\**Restart your runtime at this point for the newly installed packages to be seen.** Run the step below immediately after restarting so that python knows where to look for packages. **Always run this step after restarting the runtime.**"
+ "Below are codes setup for one process at rank `0` and with WORLD_SIZE (number of processes) as `1`.\n",
+ "\n",
+ "In a distributed setup, we will repeat the following steps on each process to set up the distributed environment."
]
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 4,
"metadata": {
- "id": "8cktNrh8R9rC"
+ "id": "suuMmJs2RbjX"
},
"outputs": [],
"source": [
- "import sys\n",
- "sys.path = ['', '/env/python', '/usr/local/lib/python37.zip', '/usr/local/lib/python3.7', '/usr/local/lib/python3.7/lib-dynload', '/usr/local/lib/python3.7/site-packages', './.local/lib/python3.7/site-packages']"
+ "import torch.distributed as dist\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "rank = 0\n",
+ "world_size = 1\n",
+ "backend = 'nccl' if 'cuda' in device else 'gloo'\n",
+ "os.environ[\"RANK\"] = f\"{rank}\"\n",
+ "os.environ[\"WORLD_SIZE\"] = f\"{world_size}\""
]
},
{
"cell_type": "markdown",
- "source": [
- "## **Overview**\n",
- "This tutorial will mainly cover the sharding schemes of embedding tables via `EmbeddingPlanner` and `DistributedModelParallel` API and explore the benefits of different sharding schemes for the embedding tables by explicitly configuring them."
- ],
- "metadata": {
- "id": "HWBOrwVSnrNE"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Distributed Setup\n",
- "Due to the notebook enviroment, we cannot run [`SPMD`](https://en.wikipedia.org/wiki/SPMD) program here but we can do multiprocessing inside the notebook to mimic the setup. Users should be responsible for setting up their own [`SPMD`](https://en.wikipedia.org/wiki/SPMD) launcher when using Torchrec. \n",
- "We setup our environment so that torch distributed based communication backend can work."
- ],
"metadata": {
- "id": "udsN6PlUo1zF"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "import os\n",
- "import torch\n",
- "import torchrec\n",
- "\n",
- "os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
- "os.environ[\"MASTER_PORT\"] = \"29500\""
- ],
- "metadata": {
- "id": "4-v17rxkopQw"
+ "id": "ZdSUWBRxoP8R"
},
- "execution_count": 18,
- "outputs": []
- },
- {
- "cell_type": "markdown",
"source": [
"### Constructing our embedding model\n",
"Here we use TorchRec offering of [`EmbeddingBagCollection`](https://github.com/facebookresearch/torchrec/blob/main/torchrec/modules/embedding_modules.py#L59) to construct our embedding bag model with embedding tables.\n",
"\n",
- "Here, we create an EmbeddingBagCollection (EBC) with four embedding bags. We have two types of tables: large tables and small tables differentiated by their row size difference: 4096 vs 1024. Each table is still represented by 64 dimension embedding. \n",
+ "Here, we create an EmbeddingBagCollection (EBC) with four embedding bags. We have two types of tables: large tables and small tables differentiated by their row size difference: 4096 vs 1024. Each table is still represented by 64 dimension embedding.\n",
"\n",
"We configure the `ParameterConstraints` data structure for the tables, which provides hints for the model parallel API to help decide the sharding and placement strategy for the tables.\n",
- "In TorchRec, we support \n",
+ "In TorchRec, we support\n",
"* `table-wise`: place the entire table on one device;\n",
"* `row-wise`: shard the table evenly by row dimension and place one shard on each device of the communication world;\n",
"* `column-wise`: shard the table evenly by embedding dimension, and place one shard on each device of the communication world;\n",
@@ -177,17 +248,19 @@
"* `data_parallel`: replicate the tables for every device;\n",
"\n",
"Note how we initially allocate the EBC on device \"meta\". This will tell EBC to not allocate memory yet."
- ],
- "metadata": {
- "id": "ZdSUWBRxoP8R"
- }
+ ]
},
{
"cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "id": "e7UQBuG09hbj"
+ },
+ "outputs": [],
"source": [
"from torchrec.distributed.planner.types import ParameterConstraints\n",
"from torchrec.distributed.embedding_types import EmbeddingComputeKernel\n",
- "from torchrec.distributed.types import ShardingType\n",
+ "from torchrec.distributed.types import ShardingType, ShardingPlan\n",
"from typing import Dict\n",
"\n",
"large_table_cnt = 2\n",
@@ -224,56 +297,108 @@
" }\n",
" constraints = {**large_table_constraints, **small_table_constraints}\n",
" return constraints"
- ],
- "metadata": {
- "id": "e7UQBuG09hbj"
- },
- "execution_count": 19,
- "outputs": []
+ ]
},
{
"cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "U0ktiI29_FVW"
+ },
+ "outputs": [],
"source": [
"ebc = torchrec.EmbeddingBagCollection(\n",
- " device=\"cuda\",\n",
+ " device=torch.device(device),\n",
" tables=large_tables + small_tables\n",
")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "UMoH-xpV2S5w",
+ "outputId": "0aa9ca26-2940-479f-c3ee-fef10868258a"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "EmbeddingBagCollection(\n",
+ " (embedding_bags): ModuleDict(\n",
+ " (large_table_0): EmbeddingBag(4096, 64, mode='sum')\n",
+ " (large_table_1): EmbeddingBag(4096, 64, mode='sum')\n",
+ " (small_table_0): EmbeddingBag(1024, 64, mode='sum')\n",
+ " (small_table_1): EmbeddingBag(1024, 64, mode='sum')\n",
+ " )\n",
+ ")\n"
+ ]
+ }
],
+ "source": [
+ "print(ebc)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
"metadata": {
- "id": "Iz_GZDp_oQ19"
+ "id": "711VBygVHGJ6"
},
- "execution_count": 20,
- "outputs": []
+ "source": [
+ "For `table-row-wise`, unfortuately we cannot simulate it due to its nature of operating under multi-host setup. We will present a python [`SPMD`](https://en.wikipedia.org/wiki/SPMD) example in the future to train models with `table-row-wise`."
+ ]
},
{
"cell_type": "markdown",
+ "metadata": {
+ "id": "1G8aUfmeMA7m"
+ },
"source": [
- "### DistributedModelParallel in multiprocessing\n",
- "Now, we have a single process execution function for mimicking one rank's work during [`SPMD`](https://en.wikipedia.org/wiki/SPMD) execution.\n",
"\n",
- "This code will shard the model collectively with other processes and allocate memories accordingly. It first sets up process groups and do embedding table placement using planner and generate sharded model using `DistributedModelParallel`.\n"
- ],
+ "With data parallel, we will repeat the tables for all devices.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
"metadata": {
"id": "7m0_ssVLFQEH"
- }
+ },
+ "source": [
+ "### DistributedModelParallel in multiprocessing\n",
+ "If you have access for **2 GPUs**, we can work on multi-GPU multi-process sharding. Though due to the issue in \"Spawn\"-started multiprocess, the print may not have outputs on certain devices. But you can check if the assertion is passed.\n",
+ "\n",
+ "we have a single process execution function for mimicking one rank's work during [`SPMD`](https://en.wikipedia.org/wiki/SPMD) execution.\n",
+ "\n",
+ "This code will shard the model collectively with other processes and allocate memories accordingly. It first sets up process groups and do embedding table placement using planner and generate sharded model using `DistributedModelParallel`.\n"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "id": "PztCaGmLA85u"
+ },
+ "outputs": [],
"source": [
"def single_rank_execution(\n",
" rank: int,\n",
" world_size: int,\n",
" constraints: Dict[str, ParameterConstraints],\n",
" module: torch.nn.Module,\n",
- " backend: str,\n",
+ " backend: str\n",
") -> None:\n",
+ "\n",
" import os\n",
" import torch\n",
" import torch.distributed as dist\n",
" from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder\n",
" from torchrec.distributed.model_parallel import DistributedModelParallel\n",
" from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology\n",
- " from torchrec.distributed.types import ModuleSharder, ShardingEnv\n",
+ " from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan\n",
" from typing import cast\n",
"\n",
" def init_distributed_single_host(\n",
@@ -292,8 +417,9 @@
" torch.cuda.set_device(device)\n",
" else:\n",
" device = torch.device(\"cpu\")\n",
- " topology = Topology(world_size=world_size, compute_device=\"cuda\")\n",
+ " topology = Topology(world_size=world_size, compute_device=device.type)\n",
" pg = init_distributed_single_host(rank, world_size, backend)\n",
+ " # pg = dist.group.WORLD\n",
" planner = EmbeddingShardingPlanner(\n",
" topology=topology,\n",
" constraints=constraints,\n",
@@ -309,36 +435,34 @@
" device=device,\n",
" )\n",
" print(f\"rank:{rank},sharding plan: {plan}\")\n",
+ "\n",
" return sharded_model\n"
- ],
- "metadata": {
- "id": "PztCaGmLA85u"
- },
- "execution_count": 21,
- "outputs": []
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "### Multiprocessing Execution\n",
- "Now let's execute the code in multi-processes representing multiple GPU ranks.\n",
- "\n"
- ],
"metadata": {
"id": "3YvDnV_wz_An"
- }
+ },
+ "source": [
+ "### Multiprocessing Execution\n"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "id": "arW0Jf6qEl-h"
+ },
+ "outputs": [],
"source": [
- "import multiprocess\n",
- " \n",
"def spmd_sharing_simulation(\n",
" sharding_type: ShardingType = ShardingType.TABLE_WISE,\n",
- " world_size = 2,\n",
+ " world_size = 2, # Change this world size according to the number of GPUs available on your end.\n",
"):\n",
" ctx = multiprocess.get_context(\"spawn\")\n",
" processes = []\n",
+ "\n",
" for rank in range(world_size):\n",
" p = ctx.Process(\n",
" target=single_rank_execution,\n",
@@ -347,214 +471,370 @@
" world_size,\n",
" gen_constraints(sharding_type),\n",
" ebc,\n",
- " \"nccl\"\n",
+ " backend,\n",
" ),\n",
" )\n",
+ " print(f\"start for rank: {rank}\")\n",
" p.start()\n",
" processes.append(p)\n",
"\n",
" for p in processes:\n",
" p.join()\n",
+ " print(f\"exit code: {p.exitcode}\")\n",
" assert 0 == p.exitcode"
- ],
+ ]
+ },
+ {
+ "cell_type": "markdown",
"metadata": {
- "id": "arW0Jf6qEl-h"
+ "id": "iUFvR2AeQQYe"
},
- "execution_count": 22,
- "outputs": []
+ "source": [
+ "Now we can start the multiprocess sharding. There will not be any output in the terminal due to the Spawn method used for generating child processes, but you should see all the assertions will pass and processes exit with code 0."
+ ]
},
{
"cell_type": "markdown",
+ "metadata": {
+ "id": "ux63HCxRN5Xr"
+ },
"source": [
"### Table Wise Sharding\n",
"Now let's execute the code in two processes for 2 GPUs. We can see in the plan print that how our tables are sharded across GPUs. Each node will have one large table and one small which shows our planner tries for load balance for the embedding tables. Table-wise is the de-factor go-to sharding schemes for many small-medium size tables for load balancing over the devices."
- ],
- "metadata": {
- "id": "31UWMaymj7Pu"
- }
+ ]
},
{
"cell_type": "code",
- "source": [
- "spmd_sharing_simulation(ShardingType.TABLE_WISE)"
- ],
+ "execution_count": 10,
"metadata": {
"colab": {
- "base_uri": "https://localhost:8080/"
+ "base_uri": "https://localhost:8080/",
+ "height": 263
},
- "id": "Yb4v1HA3IJzU",
- "outputId": "b8f08b10-eb85-48f3-8705-b67efd4eba2c"
+ "id": "gZfZatL5QJqc",
+ "outputId": "ed6c86b8-4e8d-418d-e5be-089a12e98f89"
},
- "execution_count": 23,
"outputs": [
{
- "output_type": "stream",
"name": "stdout",
+ "output_type": "stream",
"text": [
- "rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)])), 'large_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:0/cuda:0)])), 'small_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:1/cuda:1)]))}}\n",
- "rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)])), 'large_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:0/cuda:0)])), 'small_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:1/cuda:1)]))}}\n"
+ "start for rank: 0\n",
+ "start for rank: 1\n",
+ "rank:0,sharding plan: module: \n",
+ "\n",
+ " param | sharding type | compute kernel | ranks\n",
+ "------------- | ------------- | -------------- | -----\n",
+ "large_table_0 | table_wise | fused | [0] \n",
+ "large_table_1 | table_wise | fused | [1] \n",
+ "small_table_0 | table_wise | fused | [0] \n",
+ "small_table_1 | table_wise | fused | [1] \n",
+ "\n",
+ " param | shard offsets | shard sizes | placement \n",
+ "------------- | ------------- | ----------- | -------------\n",
+ "large_table_0 | [0, 0] | [4096, 64] | rank:0/cuda:0\n",
+ "large_table_1 | [0, 0] | [4096, 64] | rank:1/cuda:1\n",
+ "small_table_0 | [0, 0] | [1024, 64] | rank:0/cuda:0\n",
+ "small_table_1 | [0, 0] | [1024, 64] | rank:1/cuda:1\n",
+ "rank:1,sharding plan: module: \n",
+ "\n",
+ " param | sharding type | compute kernel | ranks\n",
+ "------------- | ------------- | -------------- | -----\n",
+ "large_table_0 | table_wise | fused | [0] \n",
+ "large_table_1 | table_wise | fused | [1] \n",
+ "small_table_0 | table_wise | fused | [0] \n",
+ "small_table_1 | table_wise | fused | [1] \n",
+ "\n",
+ " param | shard offsets | shard sizes | placement \n",
+ "------------- | ------------- | ----------- | -------------\n",
+ "large_table_0 | [0, 0] | [4096, 64] | rank:0/cuda:0\n",
+ "large_table_1 | [0, 0] | [4096, 64] | rank:1/cuda:1\n",
+ "small_table_0 | [0, 0] | [1024, 64] | rank:0/cuda:0\n",
+ "small_table_1 | [0, 0] | [1024, 64] | rank:1/cuda:1\n"
]
},
{
- "output_type": "stream",
"name": "stderr",
+ "output_type": "stream",
"text": [
- "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py:1403: UserWarning: positional arguments and argument \"destination\" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n",
- " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n",
- "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py:1403: UserWarning: positional arguments and argument \"destination\" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n",
- " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n"
+ "[rank0]:[W523 02:45:21.010274288 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "exit code: 0\n",
+ "exit code: 0\n"
]
}
+ ],
+ "source": [
+ "spmd_sharing_simulation(ShardingType.TABLE_WISE)"
]
},
{
"cell_type": "markdown",
+ "metadata": {
+ "id": "kfiXoz-2NBj4"
+ },
"source": [
"### Explore other sharding modes\n",
- "We have initially explored what table-wise sharding would look like and how it balances the tables placement. Now we explore sharding modes with finer focus on load balance: row-wise. Row-wise is specifically addressing large tables which a single device cannot hold due to the memory size increase from large embedding row numbers. It can address the placement of the super large tables in your models. Users can see that in the `shard_sizes` section in the printed plan log, the tables are halved by row dimension to be distributed onto two GPUs.\n"
- ],
- "metadata": {
- "id": "5HkwxEwm4O8u"
- }
+ "We have initially explored what table-wise sharding would look like and how it balances the tables placement. Now we explore sharding modes with finer focus on load balance: row-wise.\n",
+ "\n",
+ "Row-wise is specifically addressing large tables which a single device cannot hold due to the memory size increase from large embedding row numbers. It can address the placement of the super large tables in your models. Users can see that in the `shard_sizes` section in the printed plan log, the tables are halved by row dimension to be distributed onto two GPUs.\n",
+ "\n",
+ "If you are on a CPU, the row-wise sharding will not be allowed.\n"
+ ]
},
{
"cell_type": "code",
- "source": [
- "spmd_sharing_simulation(ShardingType.ROW_WISE)"
- ],
+ "execution_count": 11,
"metadata": {
- "id": "pGBgReGx5VrB",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "6e22a2f0-7373-4dcc-ee69-67f3e95d78a7"
+ "id": "853VYV_0P3us"
},
- "execution_count": 24,
"outputs": [
{
- "output_type": "stream",
"name": "stdout",
+ "output_type": "stream",
"text": [
- "rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)]))}}\n",
- "rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)]))}}\n"
+ "start for rank: 0\n",
+ "start for rank: 1\n",
+ "rank:0,sharding plan: module: \n",
+ "\n",
+ " param | sharding type | compute kernel | ranks \n",
+ "------------- | ------------- | -------------- | ------\n",
+ "large_table_0 | row_wise | fused | [0, 1]\n",
+ "large_table_1 | row_wise | fused | [0, 1]\n",
+ "small_table_0 | row_wise | fused | [0, 1]\n",
+ "small_table_1 | row_wise | fused | [0, 1]\n",
+ "\n",
+ " param | shard offsets | shard sizes | placement \n",
+ "------------- | ------------- | ----------- | -------------\n",
+ "large_table_0 | [0, 0] | [2048, 64] | rank:0/cuda:0\n",
+ "large_table_0 | [2048, 0] | [2048, 64] | rank:1/cuda:1\n",
+ "large_table_1 | [0, 0] | [2048, 64] | rank:0/cuda:0\n",
+ "large_table_1 | [2048, 0] | [2048, 64] | rank:1/cuda:1\n",
+ "small_table_0 | [0, 0] | [512, 64] | rank:0/cuda:0\n",
+ "small_table_0 | [512, 0] | [512, 64] | rank:1/cuda:1\n",
+ "small_table_1 | [0, 0] | [512, 64] | rank:0/cuda:0\n",
+ "small_table_1 | [512, 0] | [512, 64] | rank:1/cuda:1\n",
+ "rank:1,sharding plan: module: \n",
+ "\n",
+ " param | sharding type | compute kernel | ranks \n",
+ "------------- | ------------- | -------------- | ------\n",
+ "large_table_0 | row_wise | fused | [0, 1]\n",
+ "large_table_1 | row_wise | fused | [0, 1]\n",
+ "small_table_0 | row_wise | fused | [0, 1]\n",
+ "small_table_1 | row_wise | fused | [0, 1]\n",
+ "\n",
+ " param | shard offsets | shard sizes | placement \n",
+ "------------- | ------------- | ----------- | -------------\n",
+ "large_table_0 | [0, 0] | [2048, 64] | rank:0/cuda:0\n",
+ "large_table_0 | [2048, 0] | [2048, 64] | rank:1/cuda:1\n",
+ "large_table_1 | [0, 0] | [2048, 64] | rank:0/cuda:0\n",
+ "large_table_1 | [2048, 0] | [2048, 64] | rank:1/cuda:1\n",
+ "small_table_0 | [0, 0] | [512, 64] | rank:0/cuda:0\n",
+ "small_table_0 | [512, 0] | [512, 64] | rank:1/cuda:1\n",
+ "small_table_1 | [0, 0] | [512, 64] | rank:0/cuda:0\n",
+ "small_table_1 | [512, 0] | [512, 64] | rank:1/cuda:1\n"
]
},
{
- "output_type": "stream",
"name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[rank0]:[W523 02:45:29.116195809 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
"text": [
- "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py:1403: UserWarning: positional arguments and argument \"destination\" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n",
- " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n",
- "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py:1403: UserWarning: positional arguments and argument \"destination\" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n",
- " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n"
+ "exit code: 0\n",
+ "exit code: 0\n"
]
}
+ ],
+ "source": [
+ "spmd_sharing_simulation(ShardingType.ROW_WISE)"
]
},
{
"cell_type": "markdown",
+ "metadata": {
+ "id": "TyJeUphwP3ut"
+ },
"source": [
"Column-wise on the other hand, address the load imbalance problems for tables with large embedding dimensions. We will split the table vertically. Users can see that in the `shard_sizes` section in the printed plan log, the tables are halved by embedding dimension to be distributed onto two GPUs.\n"
- ],
- "metadata": {
- "id": "mqnInw_uEjjY"
- }
+ ]
},
{
"cell_type": "code",
- "source": [
- "spmd_sharing_simulation(ShardingType.COLUMN_WISE)"
- ],
+ "execution_count": 12,
"metadata": {
- "id": "DWTyuV9I5afU",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "daaa95cd-f653-47fe-809f-5d1d63cc05d7"
+ "id": "vJ1lurhrP3ut"
},
- "execution_count": 25,
"outputs": [
{
- "output_type": "stream",
"name": "stdout",
+ "output_type": "stream",
"text": [
- "rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)]))}}\n",
- "rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)]))}}\n"
+ "start for rank: 0\n",
+ "start for rank: 1\n",
+ "rank:0,sharding plan: module: \n",
+ "\n",
+ " param | sharding type | compute kernel | ranks\n",
+ "------------- | ------------- | -------------- | -----\n",
+ "large_table_0 | column_wise | fused | [0] \n",
+ "large_table_1 | column_wise | fused | [1] \n",
+ "small_table_0 | column_wise | fused | [0] \n",
+ "small_table_1 | column_wise | fused | [1] \n",
+ "\n",
+ " param | shard offsets | shard sizes | placement \n",
+ "------------- | ------------- | ----------- | -------------\n",
+ "large_table_0 | [0, 0] | [4096, 64] | rank:0/cuda:0\n",
+ "large_table_1 | [0, 0] | [4096, 64] | rank:1/cuda:1\n",
+ "small_table_0 | [0, 0] | [1024, 64] | rank:0/cuda:0\n",
+ "small_table_1 | [0, 0] | [1024, 64] | rank:1/cuda:1\n",
+ "rank:1,sharding plan: module: \n",
+ "\n",
+ " param | sharding type | compute kernel | ranks\n",
+ "------------- | ------------- | -------------- | -----\n",
+ "large_table_0 | column_wise | fused | [0] \n",
+ "large_table_1 | column_wise | fused | [1] \n",
+ "small_table_0 | column_wise | fused | [0] \n",
+ "small_table_1 | column_wise | fused | [1] \n",
+ "\n",
+ " param | shard offsets | shard sizes | placement \n",
+ "------------- | ------------- | ----------- | -------------\n",
+ "large_table_0 | [0, 0] | [4096, 64] | rank:0/cuda:0\n",
+ "large_table_1 | [0, 0] | [4096, 64] | rank:1/cuda:1\n",
+ "small_table_0 | [0, 0] | [1024, 64] | rank:0/cuda:0\n",
+ "small_table_1 | [0, 0] | [1024, 64] | rank:1/cuda:1\n"
]
},
{
- "output_type": "stream",
"name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[rank0]:[W523 02:45:37.247471296 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
"text": [
- "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py:1403: UserWarning: positional arguments and argument \"destination\" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n",
- " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n",
- "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py:1403: UserWarning: positional arguments and argument \"destination\" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n",
- " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n"
+ "exit code: 0\n",
+ "exit code: 0\n"
]
}
+ ],
+ "source": [
+ "spmd_sharing_simulation(ShardingType.COLUMN_WISE)"
]
},
{
"cell_type": "markdown",
+ "metadata": {
+ "id": "0kapPlSFP3ut"
+ },
"source": [
"For `table-row-wise`, unfortuately we cannot simulate it due to its nature of operating under multi-host setup. We will present a python [`SPMD`](https://en.wikipedia.org/wiki/SPMD) example in the future to train models with `table-row-wise`."
- ],
- "metadata": {
- "id": "711VBygVHGJ6"
- }
+ ]
},
{
"cell_type": "markdown",
+ "metadata": {
+ "id": "7HOhFal4P3ut"
+ },
"source": [
"\n",
"With data parallel, we will repeat the tables for all devices.\n"
- ],
- "metadata": {
- "id": "1G8aUfmeMA7m"
- }
+ ]
},
{
"cell_type": "code",
- "source": [
- "spmd_sharing_simulation(ShardingType.DATA_PARALLEL)"
- ],
+ "execution_count": 13,
"metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "WFk-QLlRL-ST",
- "outputId": "662a6d6e-cb1b-440d-ff1b-4619076117a3"
+ "id": "ePzMsOp-P3uu"
},
- "execution_count": 26,
"outputs": [
{
- "output_type": "stream",
"name": "stdout",
+ "output_type": "stream",
"text": [
- "rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}\n",
- "rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}\n"
+ "start for rank: 0\n",
+ "start for rank: 1\n"
]
},
{
+ "name": "stderr",
"output_type": "stream",
+ "text": [
+ "Sharding Type is data_parallel, caching params will be ignored\n",
+ "Sharding Type is data_parallel, caching params will be ignored\n",
+ "Sharding Type is data_parallel, caching params will be ignored\n",
+ "Sharding Type is data_parallel, caching params will be ignored\n",
+ "Sharding Type is data_parallel, caching params will be ignored\n",
+ "Sharding Type is data_parallel, caching params will be ignored\n",
+ "Sharding Type is data_parallel, caching params will be ignored\n",
+ "Sharding Type is data_parallel, caching params will be ignored\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "rank:0,sharding plan: module: \n",
+ "\n",
+ " param | sharding type | compute kernel | ranks \n",
+ "------------- | ------------- | -------------- | ------\n",
+ "large_table_0 | data_parallel | dense | [0, 1]\n",
+ "large_table_1 | data_parallel | dense | [0, 1]\n",
+ "small_table_0 | data_parallel | dense | [0, 1]\n",
+ "small_table_1 | data_parallel | dense | [0, 1]\n",
+ "\n",
+ "\n",
+ "\n",
+ "rank:1,sharding plan: module: \n",
+ "\n",
+ " param | sharding type | compute kernel | ranks \n",
+ "------------- | ------------- | -------------- | ------\n",
+ "large_table_0 | data_parallel | dense | [0, 1]\n",
+ "large_table_1 | data_parallel | dense | [0, 1]\n",
+ "small_table_0 | data_parallel | dense | [0, 1]\n",
+ "small_table_1 | data_parallel | dense | [0, 1]\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
"name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[rank0]:[W523 02:45:45.339723036 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
"text": [
- "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py:1403: UserWarning: positional arguments and argument \"destination\" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n",
- " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n",
- "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py:1403: UserWarning: positional arguments and argument \"destination\" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n",
- " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n"
+ "exit code: 0\n",
+ "exit code: 0\n"
]
}
+ ],
+ "source": [
+ "spmd_sharing_simulation(ShardingType.DATA_PARALLEL)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
- "background_execution": "on",
- "collapsed_sections": [],
- "machine_shape": "hm",
- "name": "Torchrec Sharding Introduction.ipynb",
+ "gpuType": "T4",
"provenance": []
},
"kernelspec": {
@@ -562,7 +842,16 @@
"name": "python3"
},
"language_info": {
- "name": "python"
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
}
},
"nbformat": 4,
diff --git a/examples/torcharrow/README.md b/examples/torcharrow/README.md
deleted file mode 100644
index cc2613a38..000000000
--- a/examples/torcharrow/README.md
+++ /dev/null
@@ -1,33 +0,0 @@
-# Description
-
-This shows a prototype of integrating a TorchRec based training loop utilizing TorchArrow's on-the-fly preprocessing. The main motivation is to show the utilization of TorchArrow's specialized domain UDFs. Here we use `bucketize`, `firstx`, as well as `sigrid_hash` to do some last-mile preprocessing over the criteo dataset in parquet format. More recommendation domain functions can be found at [torcharrow.functional Doc](https://pytorch.org/torcharrow/beta/functional.html#recommendation-operations).
-
-These three UDFs are extensively used in Meta's RecSys preprocessing stack. Notably, these UDFs can be used to easily adjust the proprocessing script to any model changes. For example, if we wish to change the size of our embedding tables, without sigrid_hash, we would need to rerun a bulk offline preproc to ensure that all indicies are within bounds. Bucketize lets us easily convert dense features into sparse features, with flexibility of what the bucket borders are. firstx lets us easily prune sparse ids (note, that this doesn't provide any functional features, but is in the preproc script as demonstration).
-
-
-## Installations and Usage
-
-Download the criteo tsv files (see the README in the main DLRM example). Use the nvtabular script (in torchrec/datasets/scripts/nvt/) to convert the TSV files to parquet.
-
-To start, install torcharrow-nightly and torchdata:
-```
-pip install --pre torcharrow -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
-pip install torchdata
-```
-You can also build TorchArrow from source, following https://github.com/pytorch/torcharrow
-
-Usage
-
-```
-torchx run -s local_cwd dist.ddp -j 1x4 --script examples/torcharrow/run.py -- --parquet_directory /home/criteo_parquet
-```
-
-The preprocessing logic is in ```dataloader.py```
-
-## Extentions/Future work
-
-* We will eventually integrate with the up and coming DataLoader2, which will allow us to utilize a prebuilt solution to collate our dataframe batches to dense tensors, or TorchRec's KeyedJaggedTensors (rather than doing this by hand).
-* Building an easier solution/more performant to convert parquet -> IterableDataPipe[torcharrow.DataFrame] (aka ArrowDataPipe). Also currently batch sizes are not available.
-* Some functional abilities are not yet available (such as make_named_row, etc).
-* Support collation/conversion for ArrowDataPipe
-* More RecSys UDFs to come! Please let us know if you have any suggestions.
diff --git a/examples/torcharrow/dataloader.py b/examples/torcharrow/dataloader.py
deleted file mode 100644
index 3cca4666d..000000000
--- a/examples/torcharrow/dataloader.py
+++ /dev/null
@@ -1,116 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torcharrow as ta
-import torcharrow.dtypes as dt
-import torcharrow.pytorch as tap
-from torch.utils.data import DataLoader
-from torcharrow import functional
-from torchdata.datapipes.iter import FileLister
-from torchrec.datasets.criteo import (
- DEFAULT_CAT_NAMES,
- DEFAULT_INT_NAMES,
- DEFAULT_LABEL_NAME,
-)
-
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
-
-
-class _JaggedTensorConversion(tap.TensorConversion):
- # pyre-fixme[14]: `to_tensor` overrides method defined in `TensorConversion`
- # inconsistently.
- def to_tensor(self, df: ta.DataFrame):
-
- kjt_keys = df.columns
- kjt_values = []
- kjt_lengths = []
- for row in df:
- for idx, _column in enumerate(df.columns):
- value = row[idx]
- kjt_values.extend(value)
- kjt_lengths.append(len(value))
- kjt = KeyedJaggedTensor.from_lengths_sync(
- keys=kjt_keys,
- values=torch.tensor(kjt_values),
- lengths=torch.tensor(kjt_lengths),
- )
- return kjt
-
-
-class _Scalar(tap.TensorConversion):
- def to_tensor(self, df: ta.DataFrame):
- labels = torch.tensor(df)
- return labels
-
-
-def get_dataloader(
- parquet_directory, world_size, rank, num_embeddings=4096, salt=0, batch_size=16
-):
- source_dp = FileLister(parquet_directory, masks="*.parquet")
- # TODO support batch_size for load_parquet_as_df.
- # TODO use OSSArrowDataPipe once it is ready
- parquet_df_dp = source_dp.load_parquet_as_df()
-
- def preproc(df, max_idx=num_embeddings, salt=salt):
-
- for feature_name in DEFAULT_INT_NAMES:
- df[feature_name] = df[feature_name].fill_null(0)
- for feature_name in DEFAULT_CAT_NAMES:
- df[feature_name] = df[feature_name].fill_null(0)
- df[feature_name] = df[feature_name].cast(dt.int64)
-
- # construct a sprase index from a dense one
- df["bucketize_int_0"] = functional.bucketize(df["int_0"], [0.5, 1.0, 1.5]).cast(
- dt.int64
- )
-
- # flatten several columns into one
- df["dense_features"] = ta.dataframe(
- {int_name: df[int_name] for int_name in DEFAULT_INT_NAMES}
- )
-
- df["dense_features"] = (df["dense_features"] + 3).log()
-
- for cat_name in DEFAULT_CAT_NAMES + ["bucketize_int_0"]:
- # hash our embedding index into our embedding tables
- df[cat_name] = functional.sigrid_hash(df[cat_name], salt, max_idx)
- df[cat_name] = functional.array_constructor(df[cat_name])
- df[cat_name] = functional.firstx(df[cat_name], 1)
-
- df["sparse_features"] = ta.dataframe(
- {
- cat_name: df[cat_name]
- for cat_name in DEFAULT_CAT_NAMES + ["bucketize_int_0"]
- }
- )
-
- df = df[["dense_features", "sparse_features", DEFAULT_LABEL_NAME]]
-
- return df
-
- parquet_df_dp = parquet_df_dp.map(preproc).sharding_filter()
- parquet_df_dp.apply_sharding(world_size, rank)
-
- def criteo_collate(df):
- dense_features, kjt, labels = df.to_tensor(
- {
- "dense_features": tap.rec.Dense(batch_first=True),
- "sparse_features": _JaggedTensorConversion(),
- "label": _Scalar(),
- }
- )
-
- return dense_features, kjt, labels
-
- return DataLoader(
- parquet_df_dp,
- batch_size=None,
- collate_fn=criteo_collate,
- drop_last=False,
- pin_memory=True,
- )
diff --git a/examples/torcharrow/requirements.txt b/examples/torcharrow/requirements.txt
deleted file mode 100644
index 3e3c4c950..000000000
--- a/examples/torcharrow/requirements.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-click
-torcharrow
-torchdata
diff --git a/examples/torcharrow/run.py b/examples/torcharrow/run.py
deleted file mode 100644
index 1a3dca47d..000000000
--- a/examples/torcharrow/run.py
+++ /dev/null
@@ -1,126 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import os
-
-import click
-import dataloader as torcharrow_dataloader
-import torch
-import torch.distributed as dist
-from fbgemm_gpu.split_embedding_configs import EmbOptimType
-from torch.distributed.elastic.multiprocessing.errors import record
-from torchrec import EmbeddingBagCollection
-from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, INT_FEATURE_COUNT
-from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
-from torchrec.distributed.model_parallel import DistributedModelParallel
-from torchrec.models.dlrm import DLRM
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
-from torchrec.optim.keyed import KeyedOptimizerWrapper
-
-
-@record
-@click.command()
-@click.option("--batch_size", default=256)
-@click.option("--num_embeddings", default=2048)
-@click.option("--sigrid_hash_salt", default=0)
-@click.option("--parquet_directory", default="/data/criteo_preproc")
-def main(
- batch_size,
- num_embeddings,
- sigrid_hash_salt,
- parquet_directory,
-) -> None:
- rank = int(os.environ["LOCAL_RANK"])
- if torch.cuda.is_available():
- device = torch.device(f"cuda:{rank}")
- backend = "nccl"
- torch.cuda.set_device(device)
- else:
- device = torch.device("cpu")
- backend = "gloo"
- print(
- "\033[92m"
- + f"WARNING: Running in CPU mode. cuda availablility {torch.cuda.is_available()}."
- )
-
- dist.init_process_group(backend=backend)
-
- world_size = dist.get_world_size()
-
- dataloader = torcharrow_dataloader.get_dataloader(
- parquet_directory,
- world_size,
- rank,
- batch_size=batch_size,
- num_embeddings=num_embeddings,
- salt=sigrid_hash_salt,
- )
- it = iter(dataloader)
-
- model = DLRM(
- embedding_bag_collection=EmbeddingBagCollection(
- tables=[
- EmbeddingBagConfig(
- name=f"table_{cat_name}",
- embedding_dim=64,
- num_embeddings=num_embeddings,
- feature_names=[cat_name],
- )
- for cat_name in DEFAULT_CAT_NAMES + ["bucketize_int_0"]
- ],
- device=torch.device("meta"),
- ),
- dense_in_features=INT_FEATURE_COUNT,
- dense_arch_layer_sizes=[64],
- over_arch_layer_sizes=[32, 1],
- dense_device=device,
- )
-
- fused_params = {
- "learning_rate": 0.02,
- "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
- }
-
- sharded_model = DistributedModelParallel(
- module=model,
- device=device,
- sharders=[
- EmbeddingBagCollectionSharder(fused_params=fused_params),
- ],
- )
-
- optimizer = KeyedOptimizerWrapper(
- dict(model.named_parameters()),
- lambda params: torch.optim.SGD(params, lr=0.01),
- )
-
- loss_fn = torch.nn.BCEWithLogitsLoss()
-
- print_example = dist.get_rank() == 0
- for (dense_features, kjt, labels) in it:
- if print_example:
- print("Example dense_features", dense_features)
- print("Example KJT input", kjt)
- print_example = False
-
- dense_features = dense_features.to(device)
- kjt = kjt.to(device)
- labels = labels.to(device)
-
- optimizer.zero_grad()
-
- preds = sharded_model(dense_features, kjt)
- loss = loss_fn(preds.squeeze(), labels.squeeze())
- loss.sum().backward()
-
- optimizer.step()
-
- print("\033[92m" + "DLRM run with torcharrow last-mile preprocessing finished!")
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/transfer_learning/train_from_pretrained_embedding.py b/examples/transfer_learning/train_from_pretrained_embedding.py
index 6d659f67d..ac7db53b4 100644
--- a/examples/transfer_learning/train_from_pretrained_embedding.py
+++ b/examples/transfer_learning/train_from_pretrained_embedding.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import copyreg
import io
import os
@@ -222,7 +224,7 @@ def main() -> None:
dist.barrier()
-if __name__ == "__main__":
+def invoke_main() -> None:
lc = pet.LaunchConfig(
min_nodes=1,
max_nodes=1,
@@ -234,3 +236,7 @@ def main() -> None:
monitor_interval=1,
)
pet.elastic_launch(lc, entrypoint=main)()
+
+
+if __name__ == "__main__":
+ invoke_main() # pragma: no cover
diff --git a/install-requirements.txt b/install-requirements.txt
new file mode 100644
index 000000000..ed3c6aced
--- /dev/null
+++ b/install-requirements.txt
@@ -0,0 +1,6 @@
+fbgemm-gpu
+tensordict
+torchmetrics==1.0.3
+tqdm
+pyre-extensions
+iopath
diff --git a/requirements.txt b/requirements.txt
index daa72b622..6239d0d90 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,43 +1,22 @@
-arrow
-attrs
-certifi
-charset-normalizer
+black
+click
cmake
-Cython
-distro
-docker
-docstring-parser
-fbgemm-gpu-nightly
-filelock
-fsspec
-hypothesis
-idna
+fbgemm-gpu
+hypothesis==6.70.1
+importlib-metadata
iopath
-Jinja2
-MarkupSafe
-mypy-extensions
-ninja
numpy
-packaging
pandas
-portalocker
-pyarrow
-pyDeprecate
-pyparsing
-pyre-extensions==0.0.27
-python-dateutil
-pytz
-PyYAML
-requests
+pyre-extensions
scikit-build
-six
-sortedcontainers
-tabulate
-torchmetrics
+tensordict
+torchmetrics==1.0.3
torchx
tqdm
-typing-inspect
-typing_extensions
-urllib3
usort
-websocket-client
+parameterized
+PyYAML
+
+# for tests
+# https://github.com/pytorch/pytorch/blob/b96b1e8cff029bb0a73283e6e7f6cc240313f1dc/requirements.txt#L3
+expecttest
diff --git a/rfc/RFC-0002-Flexible-Collision-Free-Embedding-Table.md b/rfc/RFC-0002-Flexible-Collision-Free-Embedding-Table.md
new file mode 100644
index 000000000..09e9069fb
--- /dev/null
+++ b/rfc/RFC-0002-Flexible-Collision-Free-Embedding-Table.md
@@ -0,0 +1,146 @@
+# RFC: Flexible Collision Free Embedding Table
+
+| Status | Published |
+| :---- | :---- |
+| Author(s) | Emma Lin, Joe Wang, Kaustubh Vartak, Dennis van der Staay, Huanyu He|
+| Updated | 04-13-2025 |
+
+## Motivation
+
+PyTorch and FBGemm utilize fixed-size contiguous embedding memory to handle sparse features, employing a uniform hash function to map raw IDs to a limited index space. This approach has resulted in a high rate of hash collisions and inefficient storage.
+
+The Zero Collision Hash (ZCH) technique allows models to train individual IDs uniquely, leading to notable enhancements in model freshness and user engagement. When properly configured, we've seen freshness improvements in both late-stage and early-stage ranking models. However, the remapping-based ZCH solutions have presented several scalability challenges.
+
+## Objective
+
+This RFC proposes a new embedding table format that natively supports collision-free features, enhancing the scalability and usability of embedding tables.
+
+The approach involves utilizing an extremely large hash size (e.g., 2^63) to map raw IDs to a vast ID space, significantly reducing the likelihood of collisions during hashing.
+
+Instead of mapping IDs to a fixed table size (e.g., 1 billion rows), this format reserves memory spaces for each ID, effectively eliminating collisions and achieving native collision-free results.
+
+Notably, this design eliminates the need for a remapper to find available slots for colliding IDs, streamlining the process and improving overall efficiency, embedding scalability and usability.
+
+## Design Proposal
+
+### Bucket Aware Sharding and Resharding Algorithm
+
+To address the dynamic nature of sparse IDs and their distribution, we propose introducing a bucket concept. We will provide a large default bucket number configuration and an extremely large table size. The mapping from ID to bucket ID can be done in two ways:
+
+* Interleave-based: bucket\_id \= hash\_id % total\_bucket\_number
+ This approach is similar to the sharding solution used in MPZCH, allowing for seamless migration without requiring ID resharding.
+* Chunk-based:
+ * bucket\_size \= table\_size / total\_bucket\_number,
+ * bucket\_id \= id / bucket\_size
+
+
+Both options will be configurable.
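+
+As a minimal illustration of the two options (the helpers and values below are placeholders, not a proposed API):
+
+```python
+def interleave_bucket_id(hash_id: int, total_bucket_number: int) -> int:
+    # Interleave-based: neighboring ids scatter across buckets.
+    return hash_id % total_bucket_number
+
+
+def chunk_bucket_id(hash_id: int, table_size: int, total_bucket_number: int) -> int:
+    # Chunk-based: each bucket owns a contiguous range of the id space.
+    bucket_size = table_size // total_bucket_number
+    return hash_id // bucket_size
+
+
+# With table_size=1000 and 10 buckets (bucket_size=100), id 123 lands in
+# bucket 3 under interleaving and bucket 1 under chunking.
+assert interleave_bucket_id(123, 10) == 3
+assert chunk_bucket_id(123, 1000, 10) == 1
+```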
+
+After sharding IDs into buckets, we will distribute the buckets sequentially across trainers. For example, with 1000 buckets and 100 trainers, each trainer would own 10 consecutive buckets.
+
+T1: b0-b9
+
+T2: b10-b19
+
+...
+
+When resharding is necessary, we will move buckets around instead of individual rows. For instance, reducing the number of trainers from 100 to 50 would result in each trainer owning 20 consecutive buckets.
+
+T1: b0-b19
+
+T2: b20-b39
+
+...
+
+The row count within each bucket can vary from 0 to the maximum bucket size, depending on the ID distribution. However, using a larger bucket number should lead to a more even distribution of IDs.
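+
+The sequential assignment and bucket-level resharding can be sketched as follows (illustrative only; the helper below is not an existing API):
+
+```python
+def buckets_for_trainer(rank: int, total_buckets: int, num_trainers: int) -> range:
+    # Each trainer owns a contiguous run of buckets; assumes an even split.
+    per_trainer = total_buckets // num_trainers
+    return range(rank * per_trainer, (rank + 1) * per_trainer)
+
+
+# 1000 buckets over 100 trainers: trainer index 1 (T2 above) owns b10-b19.
+assert buckets_for_trainer(1, 1000, 100) == range(10, 20)
+# After resharding to 50 trainers it owns b20-b39; only whole buckets move,
+# no per-id remapping is required.
+assert buckets_for_trainer(1, 1000, 50) == range(20, 40)
+```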
+
+#### Benefit
+
+The bucket number remains unchanged when scaling the model or adjusting the number of trainers, making it easier to move buckets around without introducing the additional overhead of recomputing every ID's new location.
+
+Resharding every ID can be an expensive operation, especially for large tables (e.g., over 1 billion rows).
+
+### Bucketized Torchrec Sharding and Input Distribution
+
+Based on the proposed sharding definition, the TorchRec sharder needs to be aware of the bucket configuration from the embedding table.
+
+Input distribution needs to take into account the bucket configuration, and then distribute input to the corresponding trainer.
+
+Here is the code [reference](https://github.com/pytorch/torchrec/blob/f36d26db4432bd7335f6df9e7e75d8643b7ffb04/torchrec/distributed/sharding/rw_sequence_sharding.py#L129C16-L129C36).
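+
+A rough sketch of what bucket-aware input distribution amounts to (the production logic lives in the linked rw_sequence_sharding code; this is only an illustration, not the actual implementation):
+
+```python
+from typing import Callable, List
+
+
+def route_ids(
+    global_ids: List[int],
+    total_buckets: int,
+    num_trainers: int,
+    bucket_fn: Callable[[int], int],
+) -> List[List[int]]:
+    # Map each input id to its bucket, then to the trainer that owns the bucket.
+    per_trainer = total_buckets // num_trainers
+    routed: List[List[int]] = [[] for _ in range(num_trainers)]
+    for gid in global_ids:
+        routed[bucket_fn(gid) // per_trainer].append(gid)
+    return routed  # routed[r] is what gets sent to trainer r in the all-to-all
+```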
+
+### FBGemm Operators Optimization for Collision Free EmbeddingTable
+
+FBGEMM\_GPU (FBGEMM GPU Kernels Library) is highly optimized for fixed-size tables with contiguous memory space, whether in HBM, UVM, or CPU memory.
+However, the collision-free approach breaks several of FBGEMM's assumptions:
+
+* Table size is not fixed. It could grow over training iterations or shrink after eviction.
+* Embedding lookup input is no longer an embedding offset, so we need to maintain an explicit mapping from input ID to the embedding value.
+* Table size could exceed the memory limit, but the number of actually trained IDs is finite, so we cannot reserve memory based on the table configuration alone.
+
+We’re looking for an optimized K/V FBGemm version to support flexible memory management.
+
+#### Training Operator (from [Joe Wang](mailto:wangj@meta.com))
+
+* Optimized CPU memory management with K/V format
+ * Reduce memory fragmentation
+ * Efficient memory utilization
+ * Fast lookup performance
+ * Flexible eviction policy
+* Collision free LXU cache to avoid extra memory copy from CPU to UVM and UVM memory read during embedding lookup.
+  * The current LXU cache used by FBGemm can cause ID collisions. When a collision happens, prefetch cannot load the embedding value into HBM, and embedding lookup falls back to a UVM memory read. This can impact training QPS in two ways:
+    * It introduces one extra CPU memory copy, because data needs to be copied from CPU to UVM: the CPU embedding data in k/v format might not be accessible from the GPU card.
+    * It introduces an extra H2D data copy during embedding lookup.
+
+[Here](https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py) is the code reference of a k/v format SSD offloading operator, and provides a backend interface to hook up other k/v store implementations.
+We propose to implement a new k/v store backend to decouple the SSD and RocksDB dependency; the existing SSD backend operator can still be used for extra-large embeddings that do not fit into host memory.
+
+#### Inference Operator
+
+On top of the training operator's functionality, the inference operator needs to support dequantization of n-bit int values after embeddings are queried from the embedding store. We'd like a fast inference operator with these additional requirements:
+
+* Optimized CPU memory management with k/v format
+* Collision free LXU cache
+* Keep the fast n-bit int data format support, with pooling and dequantization features.
+* Support decoupled large tensor loading and reset, to allow in-place updates of model state.
+  [Here](https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py) is the current inference operator, which only supports offset-based access for now.
+
+### Enhancing TorchRec's Prefetch Pipeline for Synchronized Training
+
+TorchRec offers multiple training [pipelines](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/train_pipeline/train_pipelines.py) that help overlap communication and computation, reducing embedding lookup latency and improving training QPS. Specifically, PrefetchTrainPipelineSparseDist supports synchronized training, while TrainPipelineSemiSync supports asynchronous training.
+
+Our goal is to further enhance the prefetch pipeline to enable prefetching multiple training data batches ahead of time while maintaining synchronized training and avoiding step misalignment.
+
+**Design Principles:**
+
+* Zero Collision Cache: Enable zero collision cache in GPU to cache embeddings for multiple batches without collision-induced cache misses.
+* Forward Pass: Perform embedding lookups only in HBM during forward passes to improve performance.
+* Backward Pass: Update embeddings in HBM synchronously during backward passes to ensure all embedding lookup results are up-to-date.
+* Asynchronous UVM Embedding Update: Update UVM embeddings asynchronously after embedding updates in HBM.
+
+**End Goal:**
+
+Achieve on-par performance with GPU HBM-based training while scaling up sparse embedding tables in CPU memory.
+
+### Warm Start and Transfer Learning with Collision-Free Embedding Tables
+
+Warm start, or transfer learning, is a widely used technique in industry to facilitate model iteration while maintaining on-par topline metrics. However, the introduction of collision-free embedding tables poses challenges to this process.
+
+With the ability to increase the table size and feature hash size to \~2^60, collision-free embedding tables offer improved efficiency. However, since the ID hash size changes and the sharding solution differs, resuming training from a collision-prone table into a zero-collision table makes the redistribution of IDs across trainers computationally expensive.
+
+#### Solution: Backfilling Embedding Values
+
+To address this challenge, we propose the following solution:
+
+* Create Two Embedding Tables: One table is copied from the old model, and the other is the new table.
+
+* Freeze Old Embedding Table: The old embedding table is set to read-only mode in the new model.
+
+* Training Forward Loop: During the forward pass, if an embedding value is not found in the new table, the underlying embedding lookup operator searches the old embedding table for a pre-trained value.
+
+ * This requires an additional all-to-all call using TorchRec to retrieve the old embedding value.
+
+ * We need to leverage the prefetch process to hide the extra latency.
+
+* Stop Backfilling Process: Once the new table is sufficiently populated, the backfilling process can be stopped.
+
+This approach enables efficient warm start and transfer learning with collision-free embedding tables, reducing the computational overhead associated with ID redistribution.
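+
+A pseudocode-level sketch of the backfill lookup during the forward pass (the table objects and the all-to-all plumbing are stand-ins, not TorchRec APIs):
+
+```python
+import torch
+
+
+def lookup_with_backfill(new_table: dict, frozen_old_table: dict, ids, dim: int):
+    rows = []
+    for gid in ids:
+        if gid not in new_table:
+            if gid in frozen_old_table:
+                # Backfill: copy the pre-trained value from the read-only old table.
+                new_table[gid] = frozen_old_table[gid].clone()
+            else:
+                # Brand-new id: initialize as usual.
+                new_table[gid] = torch.zeros(dim)
+        rows.append(new_table[gid])
+    return torch.stack(rows)
+```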
diff --git a/rfc/RFC-0002-assets/KV_storage_extension_Design.md b/rfc/RFC-0002-assets/KV_storage_extension_Design.md
new file mode 100644
index 000000000..d12444b16
--- /dev/null
+++ b/rfc/RFC-0002-assets/KV_storage_extension_Design.md
@@ -0,0 +1,94 @@
+# KV storage extension Design for TBE; SSD and PS Examples
+
+Sarunya Pumma, Emma Lin, Ehsan K. Ardestani, Joe Wang
+
+# Design Principles
+
+1) Extend the current TBE: Considerable effort and expertise have gone into enabling a performance-optimized TBE for accessing HBM as well as host DRAM. We want to leverage those capabilities and extend on top of TBE.
+2) Abstract out the details of the backend memory: The memory we use could be SSD, remote memory tiers through the backend, or remote memory through the frontend. We want to enable all such capabilities without adding backend-specific logic to the TBE code.
+
+# High Level Design
+
+Considering the design principles listed above, we have opted for a Key-Value API. TBE will offer a software-managed cache in HBM, as we do when leveraging host-side memory. However, unlike the extension to host-side memory, where we leverage UVA to prefetch, we currently opt for a copy-based API, mandated by design rule \#2, to separate the implementation of the backend KV store from TBE. In the future we might adopt a UVA-based command-queue interface with the KV store if the cost of the copy-based semantics proves prohibitive.
+[![image1]](./kv_tbe_training_high_level.png)
+
+Figure 1: High-level architecture of the TBE KV Store based extension. The blocks with an orange outline are implemented by TBE.
+
+The prefetch access to the KV store happens through the [*get\_cuda*](https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py#L1455) and [*set\_cuda*](https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py#L1012) methods (APIs), which in turn use *cudaStreamAddCallback* to maintain the execution order and call into the backend implementation of get and set.
+
+We only allow access to the KV Store by TBE through *prefetch* API. The TBE forward and backward methods only operate on HBM or DRAM through UVA.
+
+The Auxiliary buffers indicated in Figure 1 provide a scratch pad to stage the data from KV store to GPU, and from GPU to KV Store (eviction).
+
+# Implementation
+
+We expect a training pipeline similar to EMO-DRAM to allow overlapping prefetch(i+1) with train(i), where (i) denotes the training iteration number. This requires some extra work on the train pipeline to enable the prefetch pipeline on top of the SSD pipeline. The high-level workflow of pipeline prefetching is shown in the figure below.
+
+[![image2]](./kv_tbe_pipeline_prefetching.png)
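+
+A minimal sketch of the prefetch-distance-1 overlap, assuming an abstract `tbe` object with prefetch/forward hooks (a stand-in, not the real TBE interface):
+
+```python
+def train_loop(batches, tbe):
+    it = iter(batches)
+    cur = next(it, None)
+    if cur is not None:
+        tbe.prefetch(cur)          # prefetch(0) must run before forward(0)
+    while cur is not None:
+        nxt = next(it, None)
+        if nxt is not None:
+            tbe.prefetch(nxt)      # prefetch(i+1), overlapped with train(i)
+        loss = tbe.forward(cur)    # forward(i): reads HBM / scratch pad only
+        loss.backward()            # backward(i); SP(i) is evicted afterwards
+        cur = nxt
+```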
+
+# Handling Conflict Miss
+
+We introduce a UVA scratch pad buffer (accessible by both device and host) for storing L1 conflict-missed rows (the missing rows that cannot fit in L1 due to insufficient associativity). Moreover, we use the scratch buffer for staging data when fetching it from the backend storage (SSD or L2) into L1 for the rows that are missed but can still be inserted into L1. The TBE forward and backward kernels access L1 if rows are in the cache, and otherwise access the scratch pad. This is similar to handling conflict misses in EMO-DRAM (UVM\_Caching).
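+
+In effect, row resolution in the forward/backward kernels follows this simple rule (a toy sketch, not kernel code):
+
+```python
+def resolve_row(gid, l1_cache: dict, scratch_pad: dict):
+    if gid in l1_cache:
+        return l1_cache[gid]   # hit in the HBM (L1) cache
+    return scratch_pad[gid]    # conflict miss: served from the UVA scratch pad
+```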
+
+# Ensuring Coherence
+
+The training stream only accesses HBM and scratch pad (UVA) buffers. The access to the KV Store only happens during prefetch. This simplifies maintaining coherence among the training stream and prefetch stream.
+
+Similar to EMO+DRAM:
+
+1. We do enforce ordering between prefetch(i+1) and TBE backward(i), as the prefetch(i+1) can occur after prefetch(i) and needs to finish before the TBE backward(i). As a result, we have to ensure Read After Write (RAW) is maintained for *Batch i* rows that are present in the scratch pad (UVA) buffer (due to conflict), and the *Batch i+1* rows that are prefetched. This means the L1 cache updates due to prefetch need to be visible to TBE backward.
+2. Currently, a scratch pad only holds data for each iteration. It is allocated/populated during prefetch and evicted/deallocated after the backward pass of TBE is complete. To avoid evicting and re-fetching the rows that are used by two consecutive iterations, TBE looks up the SP(i) (the scratch pad of iteration i) during prefetch(i+1). It ensures that overlapping rows are moved to SP(i+1) safely by updating the locations (pointers) of rows in iteration i.
+3. To guarantee that L1 cache lines are not being prematurely evicted, we lock them when they are being used (as soon as they are being inserted) and unlock them after their usage is complete (after each TBE backward iteration). When a cache line is locked, it cannot be evicted.
+4. Only a prefetch distance of 1 is supported.
+5. The user must ensure that prefetch(i) is invoked before forward(i).
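+
+Point 2 above (moving overlapping rows from SP(i) to SP(i+1) instead of evicting and re-fetching them) can be sketched as follows; the dict-based scratch pads are a toy stand-in for the real UVA buffers:
+
+```python
+def build_next_scratch_pad(sp_i: dict, ids_next, fetch_from_backend):
+    sp_next = {}
+    for gid in ids_next:
+        if gid in sp_i:
+            sp_next[gid] = sp_i.pop(gid)        # reuse: no backend round trip
+        else:
+            sp_next[gid] = fetch_from_backend(gid)
+    # Whatever remains in sp_i is evicted back to the KV store after backward(i).
+    return sp_next
+```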
+
+The detailed prefetch workflow is demonstrated in the figure below.
+
+[![image3]](./kv_tbe_prefetch_workflow.png)
+
+TBE will ensure a unified UVA buffer across prefetch and eviction flows.
+
+# Concurrency of Prefetch and Training stages
+
+None of the functions in the prefetch stream block the GPU training stream. This includes the memory copies and the get/set function calls that access rows in the KV store.
+
+- The ssd\_cache\_populate kernel does consume SMs, impacting the performance of kernels on the training stream; however, we expect these kernels to be quick.
+- The D2H copy that transfers the inserted\_indices (keys) to the host side for the get\_cuda call is non-blocking. The ordering between the host call and the kernels on the prefetch stream is ensured through the cudaStreamAddCallback semantics, which guarantee the callback into the backend implementation is made once all previously submitted kernels in the stream have finished.
+- get\_cuda delegates the fetching of rows to the backend get function through a host callback (running on another CPU thread) and does not block the GPU. We expect get to be the longest-running function; if it is not fully overlapped with training (for example, dense FWD and BWD in sync training), it will impact end-to-end performance.
+- set\_cuda delegates eviction of the rows to the backend set function through a host callback and does not block the GPU. Moreover, the memory copy of the evicted\_indices is done on the evict stream, so it does not block the training stream.
+- The masked\_index\_put/masked\_index\_select functions need only a limited number of SMs to saturate the GPU-to-CPU bandwidth as they move embedding rows from scratch pad buffers to the embedding L1 cache.
+
+# SSD backend
+
+One of the backends to extend memory beyond host DRAM is the SSDs our AI hardware offers. Using SSDs introduces a range of constraints:
+
+- Access to SSDs is not byte-level. Many SSDs are designed around a 4K access granularity, which can result in read or write amplification when the accessed data size is smaller than 4K.
+- SSD bandwidth is 1-2 orders of magnitude lower than host DRAM, which itself is typically 1-2 orders of magnitude slower than HBM.
+- SSDs wear as data is written to them. Typically, SSDs are guaranteed for 3-5 pDWPD (physical drive writes per day) over 3-5 years. For example, a 4TB SSD rated at 5 pDWPD over 5 years can sustain daily writes of 20TB (4TB × 5) for those 5 years.
+
+
+
+
+For the KV store implementation, we opted to leverage RocksDB instead of building the stack from scratch. The main RocksDB features behind this decision are:
+
+- RocksDB is a KV store
+- RocksDB addresses write amplification by consolidating writes in a contiguous space and updating the mapping.
+- RocksDB provides in-memory structures to accelerate access to the SSD.
+
+[RocksDB](https://github.com/facebook/rocksdb/wiki/Basic-Operations) provides the basic KV store APIs for read (Get, MultiGet) and write (Put, WriteBatch), which are leveraged by the [EmbeddingRocksDB](https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h#L76) class.
+
+## Support for raw ID indexing
+
+With the KV API, we can support using raw IDs when the embeddings are mapped to the KV store. This means the solution can natively provide the benefit of Zero Collision Hashing (ZCH)-like approaches.
+
+# Persisting
+
+One main challenge when tensors are stored on SSD is that we cannot simply return a tensor in response to the *state\_dict* API. Given the KV store backend, we need to materialize the tensor in memory chunk by chunk.
+
+## Partial Tensor copy Support
+
+The current proposal is to leverage *tensor.narrow* to get a subset of the tensor, allowing very large tensors to be handled with small persisting buffers. For embeddings with a KV backend, under the hood, TBE will need to provide a handle implementing at least the following APIs:
+
+- *Narrow*: This API will provide a chunk of the underlying embedding, given the offset and size. The corresponding keys need to be read, copied into a tensor, and passed to the caller.
+- *View*: This is to ensure we have a linear view of the tensor. Note that the underlying tensor we support has dim \[hash\_size, emb\_dim\]. The override can only implement view for bytes, numel, or emb\_dim.
+- *Element\_size, Numel, nbytes*: to provide the metadata on the size of the underlying tensor so the caller can perform bookkeeping.
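+
+A sketch of chunk-by-chunk persisting through *narrow*; `kv_tensor_handle` and its `emb_dim()`/`numel()` accessors are hypothetical, standing in for the handle described above:
+
+```python
+def persist_in_chunks(kv_tensor_handle, writer, rows_per_chunk: int = 65536):
+    total_rows = kv_tensor_handle.numel() // kv_tensor_handle.emb_dim()
+    for start in range(0, total_rows, rows_per_chunk):
+        length = min(rows_per_chunk, total_rows - start)
+        # narrow() reads the corresponding keys from the KV store and
+        # materializes only this chunk in host memory.
+        chunk = kv_tensor_handle.narrow(0, start, length)
+        writer.write(chunk)
+```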
diff --git a/rfc/RFC-0002-assets/kv_tbe_pipeline_prefetching.png b/rfc/RFC-0002-assets/kv_tbe_pipeline_prefetching.png
new file mode 100644
index 000000000..cb6ac6225
Binary files /dev/null and b/rfc/RFC-0002-assets/kv_tbe_pipeline_prefetching.png differ
diff --git a/rfc/RFC-0002-assets/kv_tbe_prefetch_workflow.png b/rfc/RFC-0002-assets/kv_tbe_prefetch_workflow.png
new file mode 100644
index 000000000..51ed2a933
Binary files /dev/null and b/rfc/RFC-0002-assets/kv_tbe_prefetch_workflow.png differ
diff --git a/rfc/RFC-0002-assets/kv_tbe_training_high_level.png b/rfc/RFC-0002-assets/kv_tbe_training_high_level.png
new file mode 100644
index 000000000..fd1b95da1
Binary files /dev/null and b/rfc/RFC-0002-assets/kv_tbe_training_high_level.png differ
diff --git a/setup.py b/setup.py
index 63d3aa30d..b0e3b2477 100644
--- a/setup.py
+++ b/setup.py
@@ -7,97 +7,115 @@
import argparse
import os
-import random
-import re
+import subprocess
import sys
-from datetime import date
+from pathlib import Path
from typing import List
from setuptools import find_packages, setup
+ROOT_DIR = Path(__file__).parent.resolve()
-def get_version():
- # get version string from version.py
- # TODO: ideally the version.py should be generated when setup is run
- version_file = os.path.join(os.path.dirname(__file__), "version.py")
- version_regex = r"__version__ = ['\"]([^'\"]*)['\"]"
- with open(version_file, "r") as f:
- version = re.search(version_regex, f.read(), re.M).group(1)
- return version
+def _get_version():
+ try:
+ cmd = ["git", "rev-parse", "HEAD"]
+ sha = subprocess.check_output(cmd, cwd=str(ROOT_DIR)).decode("ascii").strip()
+ except Exception:
+ sha = None
-def get_nightly_version():
- today = date.today()
- return f"{today.year}.{today.month}.{today.day}"
+ if "BUILD_VERSION" in os.environ:
+ version = os.environ["BUILD_VERSION"]
+ else:
+ with open(os.path.join(ROOT_DIR, "version.txt"), "r") as f:
+ version = f.readline().strip()
+ if sha is not None and "OFFICIAL_RELEASE" not in os.environ:
+ version += "+" + sha[:7]
+
+ if sha is None:
+ sha = "Unknown"
+ return version, sha
+
+
+def _export_version(version, sha):
+ version_path = ROOT_DIR / "torchrec" / "version.py"
+ with open(version_path, "w") as fileobj:
+ fileobj.write("__version__ = '{}'\n".format(version))
+ fileobj.write("git_version = {}\n".format(repr(sha)))
def parse_args(argv: List[str]) -> argparse.Namespace:
parser = argparse.ArgumentParser(description="torchrec setup")
- parser.add_argument(
- "--package_name",
- type=str,
- default="torchrec",
- help="the name of this output wheel",
- )
return parser.parse_known_args(argv)
def main(argv: List[str]) -> None:
args, unknown = parse_args(argv)
- # Set up package name and version
- name = args.package_name
- is_nightly = "nightly" in name
- is_test = "test" in name
-
with open(
os.path.join(os.path.dirname(__file__), "README.MD"), encoding="utf8"
) as f:
readme = f.read()
with open(
- os.path.join(os.path.dirname(__file__), "requirements.txt"), encoding="utf8"
+ os.path.join(os.path.dirname(__file__), "install-requirements.txt"),
+ encoding="utf8",
) as f:
reqs = f.read()
install_requires = reqs.strip().split("\n")
- version = get_nightly_version() if is_nightly else get_version()
-
- if not is_nightly:
- if "fbgemm-gpu-nightly" in install_requires:
- install_requires.remove("fbgemm-gpu-nightly")
- install_requires.append("fbgemm-gpu")
-
- if is_test:
- version = (f"0.0.{random.randint(0, 1000)}",)
- print(f"-- {name} building version: {version}")
-
- packages = find_packages(exclude=("*tests",))
+ version, sha = _get_version()
+ _export_version(version, sha)
+
+ print(f"-- torchrec building version: {version}")
+
+ packages = find_packages(
+ exclude=(
+ "*tests",
+ "*test",
+ "examples",
+ "*examples.*",
+ "*benchmarks",
+ "*build",
+ "*rfc",
+ )
+ )
sys.argv = [sys.argv[0]] + unknown
setup(
# Metadata
- name=name,
+ name="torchrec",
version=version,
author="TorchRec Team",
author_email="packages@pytorch.org",
- description="Pytorch domain library for recommendation systems",
+ maintainer="TroyGarden",
+ maintainer_email="hhy@meta.com",
+ description="TorchRec: Pytorch library for recommendation systems",
long_description=readme,
long_description_content_type="text/markdown",
url="https://github.com/pytorch/torchrec",
license="BSD-3",
- keywords=["pytorch", "recommendation systems", "sharding"],
- python_requires=">=3.7",
+ keywords=[
+ "pytorch",
+ "recommendation systems",
+ "sharding",
+ "distributed training",
+ ],
+ python_requires=">=3.9",
install_requires=install_requires,
packages=packages,
zip_safe=False,
# PyPI package information.
classifiers=[
- "Development Status :: 4 - Beta",
+ "Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)
diff --git a/test/cpp/dynamic_embedding/CMakeLists.txt b/test/cpp/dynamic_embedding/CMakeLists.txt
index 056bfd985..53d2c7c02 100644
--- a/test/cpp/dynamic_embedding/CMakeLists.txt
+++ b/test/cpp/dynamic_embedding/CMakeLists.txt
@@ -14,3 +14,8 @@ add_tde_test(bits_op_test bits_op_test.cpp)
add_tde_test(naive_id_transformer_test naive_id_transformer_test.cpp)
add_tde_test(random_bits_generator_test random_bits_generator_test.cpp)
add_tde_test(mixed_lfu_lru_strategy_test mixed_lfu_lru_strategy_test.cpp)
+add_tde_test(notification_test notification_test.cpp)
+
+if (BUILD_REDIS_IO)
+ add_subdirectory(redis)
+endif()
diff --git a/test/cpp/dynamic_embedding/mixed_lfu_lru_strategy_test.cpp b/test/cpp/dynamic_embedding/mixed_lfu_lru_strategy_test.cpp
index f8526e575..456d07e1e 100644
--- a/test/cpp/dynamic_embedding/mixed_lfu_lru_strategy_test.cpp
+++ b/test/cpp/dynamic_embedding/mixed_lfu_lru_strategy_test.cpp
@@ -12,8 +12,8 @@
namespace torchrec {
TEST(TDE, order) {
MixedLFULRUStrategy::Record a;
- a.time_ = 1;
- a.freq_power_ = 31;
+ a.time = 1;
+ a.freq_power = 31;
uint32_t i32 = a.ToUint32();
ASSERT_EQ(0xF8000001, i32);
}
@@ -23,42 +23,41 @@ TEST(TDE, MixedLFULRUStrategy_Evict) {
{
records.emplace_back();
records.back().first = 1;
- records.back().second.time_ = 100;
- records.back().second.freq_power_ = 2;
+ records.back().second.time = 100;
+ records.back().second.freq_power = 2;
}
{
records.emplace_back();
records.back().first = 2;
- records.back().second.time_ = 10;
- records.back().second.freq_power_ = 2;
+ records.back().second.time = 10;
+ records.back().second.freq_power = 2;
}
{
records.emplace_back();
records.back().first = 3;
- records.back().second.time_ = 100;
- records.back().second.freq_power_ = 1;
+ records.back().second.time = 100;
+ records.back().second.freq_power = 1;
}
{
records.emplace_back();
records.back().first = 4;
- records.back().second.time_ = 150;
- records.back().second.freq_power_ = 2;
+ records.back().second.time = 150;
+ records.back().second.freq_power = 2;
}
size_t offset_{0};
- auto ids = MixedLFULRUStrategy::evict(
- [&offset_,
- &records]() -> std::optional {
+ MixedLFULRUStrategy strategy;
+ auto ids = strategy.evict(
+ [&offset_, &records]() -> std::optional<record_t> {
if (offset_ == records.size()) {
return std::nullopt;
}
auto record = records[offset_++];
- MixedLFULRUStrategy::lxu_record_t ext_type =
- *reinterpret_cast(
- &record.second);
- return MixedLFULRUStrategy::transformer_record_t{
- .global_id_ = record.first,
- .cache_id_ = 0,
- .lxu_record_ = ext_type,
+ lxu_record_t ext_type =
+ *reinterpret_cast<lxu_record_t*>(&record.second);
+ return record_t{
+ .global_id = record.first,
+ .cache_id = 0,
+ .lxu_record = ext_type,
};
},
3);
@@ -73,12 +72,12 @@ TEST(TDE, MixedLFULRUStrategy_Transform) {
constexpr static size_t n_iter = 1000000;
MixedLFULRUStrategy strategy;
strategy.update_time(10);
- MixedLFULRUStrategy::lxu_record_t val;
+ lxu_record_t val;
{
val = strategy.update(0, 0, std::nullopt);
auto record = reinterpret_cast(&val);
- ASSERT_EQ(record->freq_power_, 5);
- ASSERT_EQ(record->time_, 10);
+ ASSERT_EQ(record->freq_power, 5);
+ ASSERT_EQ(record->time, 10);
}
uint32_t freq_power_5_cnt = 0;
@@ -87,13 +86,13 @@ TEST(TDE, MixedLFULRUStrategy_Transform) {
for (size_t i = 0; i < n_iter; ++i) {
auto tmp = strategy.update(0, 0, val);
auto record = reinterpret_cast(&tmp);
- ASSERT_EQ(record->time_, 10);
- if (record->freq_power_ == 5) {
+ ASSERT_EQ(record->time, 10);
+ if (record->freq_power == 5) {
++freq_power_5_cnt;
- } else if (record->freq_power_ == 6) {
+ } else if (record->freq_power == 6) {
++freq_power_6_cnt;
} else {
- ASSERT_TRUE(record->freq_power_ == 5 || record->freq_power_ == 6);
+ ASSERT_TRUE(record->freq_power == 5 || record->freq_power == 6);
}
}
diff --git a/test/cpp/dynamic_embedding/naive_id_transformer_test.cpp b/test/cpp/dynamic_embedding/naive_id_transformer_test.cpp
index a037e7b89..f8ba6e82a 100644
--- a/test/cpp/dynamic_embedding/naive_id_transformer_test.cpp
+++ b/test/cpp/dynamic_embedding/naive_id_transformer_test.cpp
@@ -12,8 +12,7 @@
namespace torchrec {
TEST(tde, NaiveThreadedIDTransformer_NoFilter) {
- using Tag = int32_t;
- NaiveIDTransformer> transformer(16);
+ NaiveIDTransformer> transformer(16);
const int64_t global_ids[5] = {100, 101, 100, 102, 101};
int64_t cache_ids[5];
int64_t expected_cache_ids[5] = {0, 1, 0, 2, 1};
@@ -24,8 +23,7 @@ TEST(tde, NaiveThreadedIDTransformer_NoFilter) {
}
TEST(tde, NaiveThreadedIDTransformer_Full) {
- using Tag = int32_t;
- NaiveIDTransformer> transformer(4);
+ NaiveIDTransformer> transformer(4);
const int64_t global_ids[5] = {100, 101, 102, 103, 104};
int64_t cache_ids[5];
int64_t expected_cache_ids[5] = {0, 1, 2, 3, -1};
@@ -37,8 +35,7 @@ TEST(tde, NaiveThreadedIDTransformer_Full) {
}
TEST(tde, NaiveThreadedIDTransformer_Evict) {
- using Tag = int32_t;
- NaiveIDTransformer> transformer(4);
+ NaiveIDTransformer> transformer(4);
const int64_t global_ids[5] = {100, 101, 102, 103, 104};
int64_t cache_ids[5];
@@ -60,8 +57,7 @@ TEST(tde, NaiveThreadedIDTransformer_Evict) {
}
TEST(tde, NaiveThreadedIDTransformer_Iterator) {
- using Tag = int32_t;
- NaiveIDTransformer> transformer(16);
+ NaiveIDTransformer> transformer(16);
const int64_t global_ids[5] = {100, 101, 100, 102, 101};
int64_t cache_ids[5];
int64_t expected_cache_ids[5] = {3, 4, 3, 5, 4};
diff --git a/test/cpp/dynamic_embedding/notification_test.cpp b/test/cpp/dynamic_embedding/notification_test.cpp
new file mode 100644
index 000000000..f916678d8
--- /dev/null
+++ b/test/cpp/dynamic_embedding/notification_test.cpp
@@ -0,0 +1,20 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+
+namespace torchrec {
+TEST(TDE, notification) {
+ Notification notification;
+ std::thread th([&] { notification.done(); });
+ notification.wait();
+ th.join();
+}
+} // namespace torchrec
diff --git a/test/cpp/dynamic_embedding/redis/CMakeLists.txt b/test/cpp/dynamic_embedding/redis/CMakeLists.txt
new file mode 100644
index 000000000..b1c2f0c7b
--- /dev/null
+++ b/test/cpp/dynamic_embedding/redis/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+function(add_redis_test NAME)
+ add_executable(${NAME} ${ARGN})
+ target_link_libraries(${NAME} redis_io tde_cpp_objs gtest gtest_main)
+ add_test(NAME ${NAME} COMMAND ${NAME})
+endfunction()
+
+# TODO: Need to start an empty redis-server
+# on 127.0.0.1:6379 before running *redis*_test.
+add_redis_test(redis_io_test redis_io_test.cpp)
+add_redis_test(url_test url_test.cpp)
diff --git a/test/cpp/dynamic_embedding/redis/redis_io_test.cpp b/test/cpp/dynamic_embedding/redis/redis_io_test.cpp
new file mode 100644
index 000000000..f5c77c257
--- /dev/null
+++ b/test/cpp/dynamic_embedding/redis/redis_io_test.cpp
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+
+namespace torchrec::redis {
+
+TEST(TDE, redis_Option) {
+ auto opt = parse_option(
+ "192.168.3.1:3948/?db=3&&num_threads=2&&timeout=3s&&chunk_size=3000");
+ ASSERT_EQ(opt.host, "192.168.3.1");
+ ASSERT_EQ(opt.port, 3948);
+ ASSERT_EQ(opt.db, 3);
+ ASSERT_EQ(opt.num_io_threads, 2);
+ ASSERT_EQ(opt.chunk_size, 3000);
+ ASSERT_EQ(opt.timeout_ms, 3000);
+ ASSERT_TRUE(opt.prefix.empty());
+}
+
+TEST(TDE, redis_Option_ParseError) {
+ ASSERT_ANY_THROW(
+ parse_option("192.168.3.1:3948/?db=3&&no_opt=3000&&num_threads=2"));
+ ASSERT_ANY_THROW(parse_option("192.168.3.1:3948/?timeout=3d"));
+}
+
+struct FetchContext {
+ Notification* notification_;
+ std::function on_data_;
+};
+
+TEST(TDE, redis_push_fetch) {
+ auto opt = parse_option("127.0.0.1:6379");
+ Redis redis(opt);
+
+ constexpr static int64_t global_ids[] = {1, 3, 4};
+ constexpr static uint32_t os_ids[] = {0};
+ constexpr static float params[] = {1, 2, 3, 4, 5, 9, 8, 1};
+ constexpr static uint64_t offsets[] = {
+ 0 * sizeof(float),
+ 2 * sizeof(float),
+ 4 * sizeof(float),
+ 6 * sizeof(float),
+ 8 * sizeof(float)};
+
+ Notification notification;
+
+ IOPushParameter push{
+ .table_name = "table",
+ .num_global_ids = sizeof(global_ids) / sizeof(global_ids[0]),
+ .global_ids = global_ids,
+ .num_optimizer_states = sizeof(os_ids) / sizeof(os_ids[0]),
+ .optimizer_state_ids = os_ids,
+ .num_offsets = sizeof(offsets) / sizeof(offsets[0]),
+ .offsets = offsets,
+ .data = params,
+ .on_complete_context = ¬ification,
+ .on_push_complete =
+ +[](void* ctx) {
+ auto* notification = reinterpret_cast(ctx);
+ notification->done();
+ },
+ };
+ redis.push(push);
+
+ notification.wait();
+
+ notification.clear();
+
+ FetchContext ctx{
+ .notification_ = ¬ification,
+ .on_data_ =
+ [&](uint32_t offset, uint32_t os_id, void* data, uint32_t len) {
+ ASSERT_EQ(os_id, 0);
+ uint32_t param_len = 2;
+ ASSERT_EQ(len, sizeof(float) * param_len);
+ auto actual =
+ std::span(reinterpret_cast(data), 2);
+
+ auto expect = std::span(
+ reinterpret_cast(¶ms[offset * param_len]), 2);
+
+ ASSERT_EQ(expect[0], actual[0]);
+ ASSERT_EQ(expect[1], actual[1]);
+ }};
+
+ IOFetchParameter fetch{
+ .table_name = "table",
+ .num_global_ids = sizeof(global_ids) / sizeof(global_ids[0]),
+ .global_ids = global_ids,
+ .num_optimizer_states = sizeof(os_ids) / sizeof(os_ids[0]),
+ .on_complete_context = &ctx,
+ .on_global_id_fetched =
+ +[](void* ctx,
+ uint32_t offset,
+ uint32_t os_id,
+ void* data,
+ uint32_t len) {
+ auto c = reinterpret_cast(ctx);
+ c->on_data_(offset, os_id, data, len);
+ },
+ .on_all_fetched =
+ +[](void* ctx) {
+ auto c = reinterpret_cast(ctx);
+ c->notification_->done();
+ }};
+ redis.fetch(fetch);
+ notification.wait();
+}
+} // namespace torchrec::redis
diff --git a/test/cpp/dynamic_embedding/redis/url_test.cpp b/test/cpp/dynamic_embedding/redis/url_test.cpp
new file mode 100644
index 000000000..6f37744c7
--- /dev/null
+++ b/test/cpp/dynamic_embedding/redis/url_test.cpp
@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+
+namespace torchrec::url_parser::rules {
+
+TEST(TDE, url) {
+ auto url = parse_url("http://github.com/www.qq.com/?a=b&&c=d");
+ ASSERT_EQ(url.host, "www.qq.com");
+ ASSERT_TRUE(url.param.has_value());
+ ASSERT_EQ("a=b&&c=d", url.param.value());
+}
+
+} // namespace torchrec::url_parser::rules
diff --git a/test_installation.py b/test_installation.py
index cbef42766..48328a389 100644
--- a/test_installation.py
+++ b/test_installation.py
@@ -18,6 +18,7 @@
from torchrec.models.dlrm import DLRM
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
if sys.platform not in ["linux", "linux2"]:
raise EnvironmentError(
@@ -118,7 +119,7 @@ def main(argv: List[str]) -> None:
device=device,
)
optimizer = KeyedOptimizerWrapper(
- dict(model.named_parameters()),
+ dict(in_backward_optimizer_filter(model.named_parameters())),
lambda params: torch.optim.SGD(params, lr=0.01),
)
diff --git a/tools/lint/black_linter.py b/tools/lint/black_linter.py
index cfdc3d4e8..7c9a75f9c 100644
--- a/tools/lint/black_linter.py
+++ b/tools/lint/black_linter.py
@@ -176,11 +176,11 @@ def main() -> None:
logging.basicConfig(
format="<%(threadName)s:%(levelname)s> %(message)s",
- level=logging.NOTSET
- if args.verbose
- else logging.DEBUG
- if len(args.filenames) < 1000
- else logging.INFO,
+ level=(
+ logging.NOTSET
+ if args.verbose
+ else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO
+ ),
stream=sys.stderr,
)
diff --git a/torchrec/__init__.py b/torchrec/__init__.py
index ddead68c9..29258f6f0 100644
--- a/torchrec/__init__.py
+++ b/torchrec/__init__.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
import torchrec.distributed # noqa
import torchrec.quant # noqa
from torchrec.fx import tracer # noqa
@@ -25,3 +27,10 @@
KeyedTensor,
)
from torchrec.streamable import Multistreamable, Pipelineable # noqa
+
+try:
+ # pyre-ignore[21]
+ # @manual=//torchrec/fb:version
+ from .version import __version__, github_version # noqa
+except ImportError:
+ pass
diff --git a/torchrec/csrc/dynamic_embedding/CMakeLists.txt b/torchrec/csrc/dynamic_embedding/CMakeLists.txt
index ca6fa20b5..f3bd73ad5 100644
--- a/torchrec/csrc/dynamic_embedding/CMakeLists.txt
+++ b/torchrec/csrc/dynamic_embedding/CMakeLists.txt
@@ -6,12 +6,19 @@
add_library(tde_cpp_objs
OBJECT
+ bind.cpp
+ id_transformer_wrapper.cpp
+ ps.cpp
details/clz_impl.cpp
details/ctz_impl.cpp
details/random_bits_generator.cpp
- details/mixed_lfu_lru_strategy.cpp
details/io_registry.cpp
- details/io.cpp)
+ details/io.cpp
+ details/notification.cpp)
+
+if (BUILD_REDIS_IO)
+ add_subdirectory(details/redis)
+endif()
target_include_directories(tde_cpp_objs PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../../)
target_include_directories(tde_cpp_objs PUBLIC ${TORCH_INCLUDE_DIRS})
diff --git a/torchrec/csrc/dynamic_embedding/bind.cpp b/torchrec/csrc/dynamic_embedding/bind.cpp
new file mode 100644
index 000000000..e2faa9bbd
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/bind.cpp
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+#include
+#include
+#include
+
+namespace torchrec {
+TORCH_LIBRARY(tde, m) {
+ m.def("register_io", [](const std::string& name) {
+ IORegistry::Instance().register_plugin(name.c_str());
+ });
+
+ m.class_("TransformResult")
+ .def_readonly("success", &TransformResult::success)
+ .def_readonly("ids_to_fetch", &TransformResult::ids_to_fetch);
+
+ m.class_("IDTransformer")
+ .def(torch::init())
+ .def("transform", &IDTransformerWrapper::transform)
+ .def("evict", &IDTransformerWrapper::evict)
+ .def("save", &IDTransformerWrapper::save);
+
+ m.class_("LocalShardList")
+ .def(torch::init([]() { return c10::make_intrusive(); }))
+ .def("append", &LocalShardList::emplace_back);
+
+ m.class_("FetchHandle").def("wait", &FetchHandle::wait);
+
+ m.class_("PS")
+ .def(torch::init<
+ std::string,
+ c10::intrusive_ptr,
+ int64_t,
+ int64_t,
+ std::string,
+ int64_t>())
+ .def("fetch", &PS::fetch)
+ .def("evict", &PS::evict);
+}
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/id_transformer.h b/torchrec/csrc/dynamic_embedding/details/id_transformer.h
new file mode 100644
index 000000000..3c826c840
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/details/id_transformer.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include
+#include
+#include
+#include
+
+namespace torchrec {
+
+namespace transform_default {
+inline lxu_record_t no_update(
+ int64_t global_id,
+ int64_t cache_id,
+ std::optional record) {
+ return record.value_or(lxu_record_t{});
+};
+
+inline void no_fetch(int64_t global_id, int64_t cache_id) {}
+} // namespace transform_default
+
+class IDTransformer {
+ public:
+ /**
+ * Transform global ids to cache ids
+ *
+ * @tparam Update Update the eviction strategy tag type. Update LXU Record
+ * @tparam Fetch Fetch the not existing global-id/cache-id pair. It is used
+ * by dynamic embedding parameter server.
+ *
+ * @param global_ids Global ID vector
+ * @param cache_ids [out] Cache ID vector
+ * @param update update lambda. See `Update` doc.
+ * @param fetch fetch lambda. See `Fetch` doc.
+ * @return true if all transformed, otherwise need eviction.
+ */
+ virtual bool transform(
+ std::span global_ids,
+ std::span cache_ids,
+ update_t update = transform_default::no_update,
+ fetch_t fetch = transform_default::no_fetch) = 0;
+
+ /**
+ * Evict global ids from the transformer
+ *
+ * @param global_ids Global IDs to evict.
+ */
+ virtual void evict(std::span global_ids) = 0;
+
+ /**
+ * Create an iterator over the id transformer. A possible use case is:
+ *
+ * auto iterator = transformer.iterator();
+ * auto record = iterator();
+ * while (record.has_value()) {
+ * // do sth with the record
+ * // ...
+ * // get next record
+ * record = iterator();
+ * }
+ *
+ * @return the iterator created.
+ */
+ virtual iterator_t iterator() const = 0;
+};
+
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/io.cpp b/torchrec/csrc/dynamic_embedding/details/io.cpp
index cf9556d7e..1c205254b 100644
--- a/torchrec/csrc/dynamic_embedding/details/io.cpp
+++ b/torchrec/csrc/dynamic_embedding/details/io.cpp
@@ -8,7 +8,7 @@
#include
-namespace tde::details {
+namespace torchrec {
static constexpr std::string_view k_schema_separator = "://";
@@ -77,7 +77,7 @@ static void on_all_fetched(void* ctx) {
delete c;
}
-void IO::pull(
+void IO::fetch(
const std::string& table_name,
std::span global_ids,
std::span col_ids,
@@ -93,7 +93,7 @@ void IO::pull(
ctx->tensors.resize(
global_ids.size() * std::max(col_ids.size(), static_cast(1)));
- IOPullParameter param{
+ IOFetchParameter param{
.table_name = table_name.c_str(),
.num_cols = static_cast(col_ids.size()),
.num_global_ids = static_cast(global_ids.size()),
@@ -104,7 +104,7 @@ void IO::pull(
.on_all_fetched = on_all_fetched,
};
param.on_complete_context = ctx.release();
- provider_.pull(instance_, param);
+ provider_.fetch(instance_, param);
}
struct PushContext {
@@ -135,7 +135,7 @@ void IO::push(
.col_ids = col_ids.data(),
.global_ids = global_ids.data(),
.num_optimizer_states = static_cast(os_ids.size()),
- .optimizer_stats_ids = os_ids.data(),
+ .optimizer_state_ids = os_ids.data(),
.num_offsets = static_cast(offsets.size()),
.offsets = offsets.data(),
.data = data.data(),
@@ -145,4 +145,4 @@ void IO::push(
provider_.push(instance_, param);
}
-} // namespace tde::details
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/io.h b/torchrec/csrc/dynamic_embedding/details/io.h
index ae2d19195..2f4ec64ac 100644
--- a/torchrec/csrc/dynamic_embedding/details/io.h
+++ b/torchrec/csrc/dynamic_embedding/details/io.h
@@ -12,7 +12,7 @@
#include
#include
-namespace tde::details {
+namespace torchrec {
class IO {
public:
@@ -44,7 +44,7 @@ class IO {
* copied inside, so it is safe to free `col_ids`/`global_ids` before
* `on_fetch_complete`.
*/
- void pull(
+ void fetch(
const std::string& table_name,
std::span global_ids,
std::span col_ids,
@@ -89,4 +89,4 @@ class IO {
void* instance_{};
};
-} // namespace tde::details
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/io_parameter.h b/torchrec/csrc/dynamic_embedding/details/io_parameter.h
new file mode 100644
index 000000000..ab92db9cd
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/details/io_parameter.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+namespace torchrec {
+
+using GlobalIDFetchCallback = void (*)(
+ void* ctx,
+ uint32_t offset,
+ uint32_t optimizer_state,
+ void* data,
+ uint32_t data_len);
+
+struct IOFetchParameter {
+ const char* table_name;
+ uint32_t num_cols;
+ uint32_t num_global_ids;
+ const int64_t* col_ids;
+ const int64_t* global_ids;
+ uint32_t num_optimizer_states;
+ void* on_complete_context;
+ GlobalIDFetchCallback on_global_id_fetched;
+ void (*on_all_fetched)(void* ctx);
+};
+
+struct IOPushParameter {
+ const char* table_name;
+ uint32_t num_cols;
+ uint32_t num_global_ids;
+ const int64_t* col_ids;
+ const int64_t* global_ids;
+ uint32_t num_optimizer_states;
+ const uint32_t* optimizer_state_ids;
+ uint32_t num_offsets;
+ const uint64_t* offsets;
+ const void* data;
+ void* on_complete_context;
+ void (*on_push_complete)(void* ctx);
+};
+
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/io_registry.cpp b/torchrec/csrc/dynamic_embedding/details/io_registry.cpp
index 614da6d2d..b3744dedf 100644
--- a/torchrec/csrc/dynamic_embedding/details/io_registry.cpp
+++ b/torchrec/csrc/dynamic_embedding/details/io_registry.cpp
@@ -10,7 +10,7 @@
#include
#include
-namespace tde::details {
+namespace torchrec {
void IORegistry::register_provider(IOProvider provider) {
std::string type = provider.type;
@@ -41,9 +41,9 @@ void IORegistry::register_plugin(const char* filename) {
provider.finalize =
reinterpret_cast(finalize_ptr);
- auto pull_ptr = dlsym(ptr.get(), "IO_Pull");
- TORCH_CHECK(pull_ptr != nullptr, "cannot find IO_Pull symbol");
- provider.pull = reinterpret_cast(pull_ptr);
+ auto fetch_ptr = dlsym(ptr.get(), "IO_Fetch");
+ TORCH_CHECK(fetch_ptr != nullptr, "cannot find IO_Fetch symbol");
+ provider.fetch = reinterpret_cast(fetch_ptr);
auto push_ptr = dlsym(ptr.get(), "IO_Push");
TORCH_CHECK(push_ptr != nullptr, "cannot find IO_Push symbol");
@@ -72,4 +72,4 @@ IORegistry& IORegistry::Instance() {
return instance;
}
-} // namespace tde::details
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/io_registry.h b/torchrec/csrc/dynamic_embedding/details/io_registry.h
index af727e84a..3feb39f4a 100644
--- a/torchrec/csrc/dynamic_embedding/details/io_registry.h
+++ b/torchrec/csrc/dynamic_embedding/details/io_registry.h
@@ -10,48 +10,17 @@
#include
#include
#include
+#include
#include
#include
#include
-namespace tde::details {
-
-struct IOPullParameter {
- const char* table_name;
- uint32_t num_cols;
- uint32_t num_global_ids;
- const int64_t* col_ids;
- const int64_t* global_ids;
- uint32_t num_optimizer_states;
- void* on_complete_context;
- void (*on_global_id_fetched)(
- void* ctx,
- uint32_t offset,
- uint32_t optimizer_state,
- void* data,
- uint32_t data_len);
- void (*on_all_fetched)(void* ctx);
-};
-
-struct IOPushParameter {
- const char* table_name;
- uint32_t num_cols;
- uint32_t num_global_ids;
- const int64_t* col_ids;
- const int64_t* global_ids;
- uint32_t num_optimizer_states;
- const uint32_t* optimizer_stats_ids;
- uint32_t num_offsets;
- const uint64_t* offsets;
- const void* data;
- void* on_complete_context;
- void (*on_push_complete)(void* ctx);
-};
+namespace torchrec {
struct IOProvider {
const char* type;
void* (*initialize)(const char* cfg);
- void (*pull)(void* instance, IOPullParameter cfg);
+ void (*fetch)(void* instance, IOFetchParameter cfg);
void (*push)(void* instance, IOPushParameter cfg);
void (*finalize)(void*);
};
@@ -75,4 +44,4 @@ class IORegistry {
std::vector dls_;
};
-} // namespace tde::details
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/lxu_strategy.h b/torchrec/csrc/dynamic_embedding/details/lxu_strategy.h
new file mode 100644
index 000000000..e18107770
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/details/lxu_strategy.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include
+#include
+
+namespace torchrec {
+
+class LXUStrategy {
+ public:
+ LXUStrategy() = default;
+ LXUStrategy(const LXUStrategy&) = delete;
+ LXUStrategy(LXUStrategy&& o) noexcept = default;
+
+ virtual void update_time(uint32_t time) = 0;
+ virtual int64_t time(lxu_record_t record) = 0;
+
+ virtual lxu_record_t update(
+ int64_t global_id,
+ int64_t cache_id,
+ std::optional<lxu_record_t> val) = 0;
+
+ /**
+ * Analyze all ids and return the ids that most need to be evicted.
+ * @param iterator Returns each global_id / record pair; returns nullopt
+ * at the end.
+ * @param num_to_evict The maximum number of ids to evict.
+ * @return The global ids to evict.
+ */
+ virtual std::vector<int64_t> evict(
+ iterator_t iterator,
+ uint64_t num_to_evict) = 0;
+};
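+
+// Illustrative usage sketch (an assumption; `transformer` is a hypothetical
+// IDTransformer instance and `strategy` a concrete LXUStrategy):
+//
+//   strategy.update_time(current_time);
+//   std::vector<int64_t> to_evict =
+//       strategy.evict(transformer.iterator(), /*num_to_evict=*/1024);
+//   transformer.evict(to_evict);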
+
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.cpp b/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.cpp
deleted file mode 100644
index f6c486726..000000000
--- a/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.cpp
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include
-#include
-#include
-
-namespace torchrec {
-MixedLFULRUStrategy::MixedLFULRUStrategy(uint16_t min_used_freq_power)
- : min_lfu_power_(min_used_freq_power), time_(new std::atomic()) {}
-
-void MixedLFULRUStrategy::update_time(uint32_t time) {
- time_->store(time);
-}
-
-MixedLFULRUStrategy::lxu_record_t MixedLFULRUStrategy::update(
- int64_t global_id,
- int64_t cache_id,
- std::optional val) {
- Record r{};
- r.time_ = time_->load();
-
- if (C10_UNLIKELY(!val.has_value())) {
- r.freq_power_ = min_lfu_power_;
- } else {
- auto freq_power = reinterpret_cast(&val.value())->freq_power_;
- bool should_carry = generator_.is_next_n_bits_all_zero(freq_power);
- if (should_carry) {
- ++freq_power;
- }
- r.freq_power_ = freq_power;
- }
- return *reinterpret_cast(&r);
-}
-
-} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.h b/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.h
index 85f7bbd75..beb015c18 100644
--- a/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.h
+++ b/torchrec/csrc/dynamic_embedding/details/mixed_lfu_lru_strategy.h
@@ -7,7 +7,8 @@
*/
#pragma once
-#include
+#include
+#include
#include
#include
#include
@@ -37,49 +38,56 @@ namespace torchrec {
*
* Use `update` to update the extended value each time a global id is used.
*/
-class MixedLFULRUStrategy {
+class MixedLFULRUStrategy : public LXUStrategy {
public:
- using lxu_record_t = uint32_t;
- using transformer_record_t = TransformerRecord;
-
- static constexpr std::string_view type_ = "mixed_lru_lfu";
-
/**
* @param min_used_freq_power Minimum usage count is 2^min_used_freq_power. Set
* this to avoid recently added values being evicted too quickly.
*/
- explicit MixedLFULRUStrategy(uint16_t min_used_freq_power = 5);
+ explicit MixedLFULRUStrategy(uint16_t min_used_freq_power = 5)
+ : min_lfu_power_(min_used_freq_power),
+ time_(new std::atomic<uint32_t>()) {}
MixedLFULRUStrategy(const MixedLFULRUStrategy&) = delete;
MixedLFULRUStrategy(MixedLFULRUStrategy&& o) noexcept = default;
- void update_time(uint32_t time);
- template
- static int64_t time(T record) {
- static_assert(sizeof(T) == sizeof(Record));
- return static_cast(reinterpret_cast(&record)->time_);
+ void update_time(lxu_record_t time) override {
+ time_->store(time);
+ }
+
+ int64_t time(lxu_record_t record) override {
+ return static_cast<int64_t>(reinterpret_cast<Record*>(&record)->time);
}
- lxu_record_t
- update(int64_t global_id, int64_t cache_id, std::optional val);
+ lxu_record_t update(
+ int64_t global_id,
+ int64_t cache_id,
+ std::optional<lxu_record_t> val) override {
+ Record r{};
+ r.time = time_->load();
+
+ if (!val.has_value()) [[unlikely]] {
+ r.freq_power = min_lfu_power_;
+ } else {
+ auto freq_power = reinterpret_cast<Record*>(&val.value())->freq_power;
+ bool should_carry = generator_.is_next_n_bits_all_zero(freq_power);
+ if (should_carry) {
+ ++freq_power;
+ }
+ r.freq_power = freq_power;
+ }
+ return *reinterpret_cast<lxu_record_t*>(&r);
+ }
struct EvictItem {
- int64_t global_id_;
- lxu_record_t record_;
+ int64_t global_id;
+ lxu_record_t record;
bool operator<(const EvictItem& item) const {
- return record_ < item.record_;
+ return record < item.record;
}
};
- /**
- * Analysis all ids and returns the num_elems that are most need to evict.
- * @param iterator Returns each global_id to ExtValue pair. Returns nullopt
- * when at ends.
- * @param num_to_evict
- * @return
- */
- template
- static std::vector evict(Iterator iterator, uint64_t num_to_evict) {
+ std::vector<int64_t> evict(iterator_t iterator, uint64_t num_to_evict) override {
std::priority_queue<EvictItem> items;
while (true) {
auto val = iterator();
@@ -87,8 +95,8 @@ class MixedLFULRUStrategy {
break;
}
EvictItem item{
- .global_id_ = val->global_id_,
- .record_ = reinterpret_cast<Record*>(&val->lxu_record_)->ToUint32(),
+ .global_id = val->global_id,
+ .record = reinterpret_cast<Record*>(&val->lxu_record)->ToUint32(),
};
if (items.size() == num_to_evict) {
if (!(item < items.top())) {
@@ -105,7 +113,7 @@ class MixedLFULRUStrategy {
result.reserve(items.size());
while (!items.empty()) {
auto item = items.top();
- result.emplace_back(item.global_id_);
+ result.emplace_back(item.global_id);
items.pop();
}
std::reverse(result.begin(), result.end());
@@ -114,17 +122,16 @@ class MixedLFULRUStrategy {
// Record should only be used in unittest or internally.
struct Record {
- uint32_t time_ : 27;
- uint16_t freq_power_ : 5;
+ uint32_t time : 27;
+ uint16_t freq_power : 5;
[[nodiscard]] uint32_t ToUint32() const {
- return time_ | (freq_power_ << (32 - 5));
+ return time | (freq_power << (32 - 5));
}
};
-
- private:
static_assert(sizeof(Record) == sizeof(lxu_record_t));
+ private:
RandomBitsGenerator generator_;
uint16_t min_lfu_power_;
std::unique_ptr<std::atomic<uint32_t>> time_;
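+
+ // Illustrative example (not upstream code): with the Record layout above,
+ // eviction priority is dominated by freq_power (the top 5 bits of ToUint32)
+ // and ties are broken by time (the low 27 bits). For instance,
+ // Record{.time = 100, .freq_power = 3}.ToUint32() == 100 | (3 << 27)
+ // == 402653284, so this record is kept in preference to any record with
+ // freq_power <= 2, regardless of how recently that record was used.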
diff --git a/torchrec/csrc/dynamic_embedding/details/naive_id_transformer.h b/torchrec/csrc/dynamic_embedding/details/naive_id_transformer.h
index 9cf2e694d..8f7718340 100644
--- a/torchrec/csrc/dynamic_embedding/details/naive_id_transformer.h
+++ b/torchrec/csrc/dynamic_embedding/details/naive_id_transformer.h
@@ -9,33 +9,13 @@
#pragma once
#include
#include
+#include
#include
#include
#include
namespace torchrec {
-namespace transform_default {
-
-template
-inline LXURecord no_update(
- std::optional record,
- int64_t global_id,
- int64_t cache_id) {
- return record.value_or(LXURecord{});
-};
-
-inline void no_fetch(int64_t global_id, int64_t cache_id) {}
-
-} // namespace transform_default
-
-template
-struct TransformerRecord {
- int64_t global_id_;
- int64_t cache_id_;
- LXURecord lxu_record_;
-};
-
/**
* NaiveIDTransformer
*
@@ -43,67 +23,27 @@ struct TransformerRecord {
* @tparam LXURecord The extension type used for eviction strategy.
* @tparam Bitmap The bitmap class to record the free cache ids.
*/
-template <typename LXURecord, typename Bitmap = Bitmap<uint32_t>>
-class NaiveIDTransformer {
+template <typename Bitmap = Bitmap<uint32_t>>
+class NaiveIDTransformer : public IDTransformer {
public:
- using lxu_record_t = LXURecord;
- using record_t = TransformerRecord;
- static constexpr std::string_view type_ = "naive";
-
explicit NaiveIDTransformer(int64_t num_embedding);
- NaiveIDTransformer(const NaiveIDTransformer&) = delete;
- NaiveIDTransformer(NaiveIDTransformer&&) noexcept =
- default;
+ NaiveIDTransformer(const NaiveIDTransformer&) = delete;
+ NaiveIDTransformer(NaiveIDTransformer&&) noexcept = default;
- /**
- * Transform global ids to cache ids
- *
- * @tparam Update Update the eviction strategy tag type. Update LXU Record
- * @tparam Fetch Fetch the not existing global-id/cache-id pair. It is used
- * by dynamic embedding parameter server.
- *
- * @param global_ids Global ID vector
- * @param cache_ids [out] Cache ID vector
- * @param update update lambda. See `Update` doc.
- * @param fetch fetch lambda. See `Fetch` doc.
- * @return true if all transformed, otherwise need eviction.
- */
- template <
- typename Update = decltype(transform_default::no_update),
- typename Fetch = decltype(transform_default::no_fetch)>
bool transform(
std::span<const int64_t> global_ids,
std::span<int64_t> cache_ids,
- Update update = transform_default::no_update,
- Fetch fetch = transform_default::no_fetch);
+ update_t update = transform_default::no_update,
+ fetch_t fetch = transform_default::no_fetch) override;
- /**
- * Evict global ids from the transformer
- *
- * @param global_ids Global IDs to evict.
- */
- void evict(std::span<const int64_t> global_ids);
+ void evict(std::span<const int64_t> global_ids) override;
- /**
- * Create an iterator of the id transformer, a possible usecase is:
- *
- * auto iterator = transformer.iterator();
- * auto record = iterator();
- * while (record.has_value()) {
- * // do sth with the record
- * // ...
- * // get next record
- * auto record = iterator();
- * }
- *
- * @return the iterator created.
- */
- std::function()> iterator() const;
+ iterator_t iterator() const override;
private:
struct CacheValue {
- int64_t cache_id_;
- LXURecord lxu_record_;
+ int64_t cache_id;
+ lxu_record_t lxu_record;
};
ska::flat_hash_map<int64_t, CacheValue> global_id2cache_value_;
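+
+ // Illustrative usage sketch (an assumption, not upstream code):
+ //
+ //   NaiveIDTransformer<Bitmap<uint32_t>> transformer(/*num_embedding=*/1000);
+ //   std::vector<int64_t> global_ids = {100, 2000, 30000};
+ //   std::vector<int64_t> cache_ids(global_ids.size());
+ //   bool ok = transformer.transform(global_ids, cache_ids);
+ //   // ok == false would mean the table is full and eviction is required.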
diff --git a/torchrec/csrc/dynamic_embedding/details/naive_id_transformer_impl.h b/torchrec/csrc/dynamic_embedding/details/naive_id_transformer_impl.h
index cd40f87d7..e0e38b2f7 100644
--- a/torchrec/csrc/dynamic_embedding/details/naive_id_transformer_impl.h
+++ b/torchrec/csrc/dynamic_embedding/details/naive_id_transformer_impl.h
@@ -6,43 +6,40 @@
* LICENSE file in the root directory of this source tree.
*/
-#pragma once
#include
#include
namespace torchrec {
-template <typename LXURecord, typename Bitmap>
-inline NaiveIDTransformer<LXURecord, Bitmap>::NaiveIDTransformer(
- int64_t num_embedding)
+template <typename Bitmap>
+NaiveIDTransformer<Bitmap>::NaiveIDTransformer(int64_t num_embedding)
: bitmap_(num_embedding) {
global_id2cache_value_.reserve(num_embedding);
}
-template <typename LXURecord, typename Bitmap>
-template <typename Update, typename Fetch>
-inline bool NaiveIDTransformer<LXURecord, Bitmap>::transform(
+template <typename Bitmap>
+bool NaiveIDTransformer<Bitmap>::transform(
std::span<const int64_t> global_ids,
std::span<int64_t> cache_ids,
- Update update,
- Fetch fetch) {
+ update_t update,
+ fetch_t fetch) {
for (size_t i = 0; i < global_ids.size(); ++i) {
int64_t global_id = global_ids[i];
auto iter = global_id2cache_value_.find(global_id);
// cache_id is in [0, num_embedding)
int64_t cache_id;
if (iter != global_id2cache_value_.end()) {
- cache_id = iter->second.cache_id_;
- iter->second.lxu_record_ =
- update(iter->second.lxu_record_, global_id, cache_id);
+ cache_id = iter->second.cache_id;
+ iter->second.lxu_record =
+ update(global_id, cache_id, iter->second.lxu_record);
} else {
// The transformer is full.
- if (C10_UNLIKELY(bitmap_.full())) {
+ if (bitmap_.full()) [[unlikely]] {
return false;
}
auto stored_cache_id = bitmap_.next_free_bit();
cache_id = stored_cache_id;
- LXURecord record = update(std::nullopt, global_id, cache_id);
+ lxu_record_t record = update(global_id, cache_id, std::nullopt);
global_id2cache_value_.emplace(
global_id, CacheValue{stored_cache_id, record});
fetch(global_id, cache_id);
@@ -52,30 +49,28 @@ inline bool NaiveIDTransformer::transform(
return true;
}
-template <typename LXURecord, typename Bitmap>
-inline void NaiveIDTransformer<LXURecord, Bitmap>::evict(
- std::span<const int64_t> global_ids) {
+template <typename Bitmap>
+void NaiveIDTransformer<Bitmap>::evict(std::span<const int64_t> global_ids) {
for (const int64_t global_id : global_ids) {
auto iter = global_id2cache_value_.find(global_id);
if (iter == global_id2cache_value_.end()) {
continue;
}
- int64_t cache_id = iter->second.cache_id_;
+ int64_t cache_id = iter->second.cache_id;
global_id2cache_value_.erase(iter);
bitmap_.free_bit(cache_id);
}
}
-template <typename LXURecord, typename Bitmap>
-inline auto NaiveIDTransformer<LXURecord, Bitmap>::iterator() const
- -> std::function<std::optional<record_t>()> {
+template <typename Bitmap>
+iterator_t NaiveIDTransformer<Bitmap>::iterator() const {
auto iter = global_id2cache_value_.begin();
return [iter, this]() mutable -> std::optional<record_t> {
if (iter != global_id2cache_value_.end()) {
auto record = record_t{
- .global_id_ = iter->first,
- .cache_id_ = iter->second.cache_id_,
- .lxu_record_ = iter->second.lxu_record_,
+ .global_id = iter->first,
+ .cache_id = iter->second.cache_id,
+ .lxu_record = iter->second.lxu_record,
};
iter++;
return record;
diff --git a/torchrec/csrc/dynamic_embedding/details/notification.cpp b/torchrec/csrc/dynamic_embedding/details/notification.cpp
new file mode 100644
index 000000000..297b05393
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/details/notification.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "notification.h"
+
+namespace torchrec {
+void Notification::done() {
+ {
+ std::lock_guard guard(mu_);
+ set_ = true;
+ }
+ cv_.notify_all();
+}
+void Notification::wait() {
+ std::unique_lock lock(mu_);
+ cv_.wait(lock, [this] { return set_; });
+}
+
+void Notification::clear() {
+ set_ = false;
+}
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/notification.h b/torchrec/csrc/dynamic_embedding/details/notification.h
new file mode 100644
index 000000000..5d6c04f03
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/details/notification.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include <condition_variable>
+#include <mutex>
+#include <torch/custom_class.h>
+
+namespace torchrec {
+
+/**
+ * Multi-thread notification
+ */
+class Notification : public torch::CustomClassHolder {
+ public:
+ Notification() = default;
+
+ void done();
+ void wait();
+
+ /**
+ * Clear the set status.
+ *
+ * NOTE: Clear is not thread-safe.
+ */
+ void clear();
+
+ private:
+ bool set_{false};
+ std::mutex mu_;
+ std::condition_variable cv_;
+};
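+
+// Illustrative usage sketch (an assumption, not upstream code):
+//
+//   torchrec::Notification note;
+//   std::thread worker([&] {
+//     do_background_work();  // hypothetical
+//     note.done();
+//   });
+//   note.wait();             // blocks until done() has been called
+//   worker.join();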
+
+} // namespace torchrec
diff --git a/torchrec/csrc/dynamic_embedding/details/random_bits_generator.cpp b/torchrec/csrc/dynamic_embedding/details/random_bits_generator.cpp
index 4b307842d..0fada452d 100644
--- a/torchrec/csrc/dynamic_embedding/details/random_bits_generator.cpp
+++ b/torchrec/csrc/dynamic_embedding/details/random_bits_generator.cpp
@@ -6,17 +6,16 @@
* LICENSE file in the root directory of this source tree.
*/
-#include
#include
#include
namespace torchrec {
bool BitScanner::is_next_n_bits_all_zero(uint16_t& n_bits) {
- if (C10_UNLIKELY((n_bits == 0))) {
+ if ((n_bits == 0)) [[unlikely]] {
return true;
}
- if (C10_UNLIKELY(array_idx_ == size_)) {
+ if (array_idx_ == size_) [[unlikely]] {
return true;
}
@@ -88,7 +87,7 @@ bool RandomBitsGenerator::is_next_n_bits_all_zero(uint16_t n_bits) {
return false;
}
- if (C10_UNLIKELY(n_bits != 0)) {
+ if (n_bits != 0) [[unlikely]] {
return is_next_n_bits_all_zero(n_bits);
} else {
return true;
diff --git a/torchrec/csrc/dynamic_embedding/details/redis/CMakeLists.txt b/torchrec/csrc/dynamic_embedding/details/redis/CMakeLists.txt
new file mode 100644
index 000000000..a9771391c
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/details/redis/CMakeLists.txt
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+FetchContent_Declare(
+ hiredis
+ GIT_REPOSITORY https://github.com/redis/hiredis.git
+ GIT_TAG 06be7ff312a78f69237e5963cc7d24bc84104d3b
+)
+
+FetchContent_GetProperties(hiredis)
+if(NOT hiredis_POPULATED)
+ # Do not include hiredis in install targets
+ FetchContent_Populate(hiredis)
+ set(DISABLE_TESTS ON CACHE BOOL "Disable tests for hiredis")
+ add_subdirectory(
+ ${hiredis_SOURCE_DIR} ${hiredis_BINARY_DIR} EXCLUDE_FROM_ALL)
+endif()
+
+add_library(redis_io SHARED redis_io.cpp)
+target_include_directories(
+ redis_io PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../)
+target_include_directories(redis_io PUBLIC ${TORCH_INCLUDE_DIRS})
+target_compile_options(redis_io PUBLIC -fPIC)
+target_link_libraries(redis_io PUBLIC hiredis::hiredis_static)
diff --git a/torchrec/csrc/dynamic_embedding/details/redis/redis_io.cpp b/torchrec/csrc/dynamic_embedding/details/redis/redis_io.cpp
new file mode 100644
index 000000000..71f04fd08
--- /dev/null
+++ b/torchrec/csrc/dynamic_embedding/details/redis/redis_io.cpp
@@ -0,0 +1,505 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+#include
+
+namespace torchrec::redis {
+
+int parse_integer(std::string_view param_str, std::string_view param_key) {
+ return std::stoi(std::string(param_str.substr(param_key.size())));
+}
+
+uint32_t parse_duration(
+ std::string_view param_str,
+ std::string_view param_key) {
+ auto param_value = param_str.substr(param_key.size());
+ if (param_value.empty()) {
+ throw std::invalid_argument("no value for " + std::string(param_str));
+ }
+ double duration;
+ if (param_value.ends_with("ms")) {
+ duration =
+ std::stod(std::string(param_value.substr(0, param_value.size() - 2)));
+ } else if (param_value.ends_with("s")) {
+ duration =
+ std::stod(std::string(param_value.substr(0, param_value.size() - 1))) *
+ 1000;
+ } else if (param_value.ends_with("m")) {
+ duration =
+ std::stod(std::string(param_value.substr(0, param_value.size() - 1))) *
+ 1000 * 60;
+ } else {
+ throw std::invalid_argument(
+ "no supported time unit (ms, s, m) in " + std::string(param_str));
+ }
+ return static_cast<uint32_t>(duration);
+}
+
+Option parse_option(std::string_view config_str) {
+ Option option;
+ url_parser::Url url = url_parser::parse_url(http://github.com/config_str);
+
+ if (url.authority.has_value()) {
+ option.username = std::move(url.authority->username);
+ option.password = std::move(url.authority->password);
+ }
+
+ option.host = std::move(url.host);
+
+ if (url.port.has_value()) {
+ option.port = url.port.value();
+ }
+
+ if (url.param.has_value()) {
+ std::string_view param_str = url.param.value();
+ while (!param_str.empty()) {
+ auto and_pos = param_str.find("&&");
+ std::string_view single_param_str;
+ if (and_pos != std::string_view::npos) {
+ single_param_str = param_str.substr(0, and_pos);
+ param_str = param_str.substr(and_pos + 2);
+ } else {
+ single_param_str = param_str;
+ param_str = "";
+ }
+
+ if (single_param_str.starts_with("num_threads=")) {
+ option.num_io_threads = parse_integer(single_param_str, "num_threads=");
+ } else if (single_param_str.starts_with("db=")) {
+ option.db = parse_integer(single_param_str, "db=");
+ } else if (single_param_str.starts_with("prefix=")) {
+ option.prefix = single_param_str.substr(std::string("prefix=").size());
+ } else if (single_param_str.starts_with("timeout=")) {
+ option.timeout_ms = parse_duration(single_param_str, "timeout=");
+ } else if (single_param_str.starts_with("heartbeat=")) {
+ option.heart_beat_interval_ms =
+ parse_duration(single_param_str, "heartbeat=");
+ } else if (single_param_str.starts_with("retry_limit=")) {
+ option.retry_limit = parse_integer(single_param_str, "retry_limit=");
+ } else if (single_param_str.starts_with("chunk_size=")) {
+ option.chunk_size = parse_integer(single_param_str, "chunk_size=");
+ } else {
+ throw std::invalid_argument(
+ "unknown parameter: " + std::string(single_param_str));
+ }
+ }
+ }
+
+ return option;
+}
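+
+// Illustrative config string (an assumption; the exact URL scheme accepted by
+// url_parser is not shown here):
+//
+//   "redis://user:pass@127.0.0.1:6379/?db=1&&num_threads=4&&prefix=emb&&timeout=500ms&&chunk_size=8192"
+//
+// Parameters are separated by "&&"; the recognized keys are num_threads, db,
+// prefix, timeout, heartbeat, retry_limit and chunk_size, and durations accept
+// ms/s/m suffixes.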
+
+Redis::Redis(Option opt) : opt_(std::move(opt)) {
+ TORCH_CHECK(opt_.num_io_threads != 0, "num_io_threads must not be zero");
+ TORCH_CHECK(
+ opt_.heart_beat_interval_ms != 0,
+ "heart beat interval must not be zero.");
+ for (size_t i = 0; i < opt_.num_io_threads; ++i) {
+ start_thread();
+ }
+}
+
+void Redis::start_thread() {
+ auto connection = connect();
+ heartbeat(connection);
+
+ io_threads_.emplace_back(
+ [connection = std::move(connection), this]() mutable {
+ std::chrono::milliseconds heart_beat(opt_.heart_beat_interval_ms);
+ while (true) {
+ std::function<void(helper::ContextPtr&)> todo;
+ bool heartbeat_timeout;
+ {
+ std::unique_lock lock(this->jobs_mutex_);
+ heartbeat_timeout = !jobs_not_empty_.wait_for(
+ lock, heart_beat, [this] { return !jobs_.empty(); });
+ if (!heartbeat_timeout) {
+ todo = std::move(jobs_.front());
+ jobs_.pop_front();
+ }
+ }
+
+ if (heartbeat_timeout) {
+ heartbeat(connection);
+ continue;
+ }
+
+ if (!todo) {
+ break;
+ }
+ todo(connection);
+ }
+ });
+}
+
+void Redis::heartbeat(helper::ContextPtr& connection) {
+ for (uint32_t retry = 0; retry < opt_.retry_limit; ++retry) {
+ try {
+ auto reply = helper::ReplyPtr(reinterpret_cast<redisReply*>(
+ redisCommand(connection.get(), "PING")));
+ TORCH_CHECK(
+ reply && reply->type == REDIS_REPLY_STRING,
+ "Ping should return string");
+ auto rsp = std::string_view(reply->str, reply->len);
+ TORCH_CHECK(rsp == "PONG", "ping/pong error");
+ } catch (...) {
+ // reconnect if heart beat error
+ connection = connect();
+ }
+ }
+}
+
+helper::ContextPtr Redis::connect() const {
+ helper::ContextPtr connection;
+ if (opt_.timeout_ms == 0) {
+ connection = helper::ContextPtr(redisConnect(opt_.host.c_str(), opt_.port));
+ } else {
+ struct timeval interval {};
+ interval.tv_sec = opt_.timeout_ms / 1000;
+ interval.tv_usec = opt_.timeout_ms % 1000 * 1000;
+ connection = helper::ContextPtr(
+ redisConnectWithTimeout(opt_.host.c_str(), opt_.port, interval));
+ }
+ TORCH_CHECK(
+ !connection->err,
+ "connect to %s:%d error occurred %s",
+ opt_.host,
+ opt_.port,
+ connection->errstr);
+
+ if (!opt_.password.empty()) {
+ helper::ReplyPtr reply;
+ if (opt_.username.empty()) {
+ reply = helper::ReplyPtr(reinterpret_cast<redisReply*>(
+ redisCommand(connection.get(), "AUTH %s", opt_.password.c_str())));
+ } else {
+ reply = helper::ReplyPtr(reinterpret_cast<redisReply*>(redisCommand(
+ connection.get(),
+ "AUTH %s %s",
+ opt_.username.c_str(),
+ opt_.password.c_str())));
+ }
+ check_status("auth error", connection, reply);
+ }
+
+ if (opt_.db != 0) {
+ auto reply = helper::ReplyPtr(reinterpret_cast<redisReply*>(
+ redisCommand(connection.get(), "SELECT %d", opt_.db)));
+ check_status("select db error", connection, reply);
+ }
+
+ return connection;
+}
+
+Redis::~Redis() {
+ for (uint32_t i = 0; i < opt_.num_io_threads; ++i) {
+ jobs_.emplace_back();
+ }
+ jobs_not_empty_.notify_all();
+ for (auto& th : io_threads_) {
+ th.join();
+ }
+}
+
+static uint32_t CalculateChunkSizeByGlobalIDs(
+ uint32_t chunk_size,
+ uint32_t num_cols,
+ uint32_t num_os) {
+ static constexpr uint32_t low = 1;
+ return std::max(
+ chunk_size / std::max(num_cols, low) / std::max(num_os, low), low);
+}
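+
+// Illustrative example (not upstream code): with chunk_size = 8192,
+// num_cols = 4 and num_os = 2, each job covers at most 8192 / 4 / 2 = 1024
+// global ids; the result is clamped to a minimum of 1.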
+
+struct RedisFetchContext {
+ std::atomic<uint32_t> num_complete_ids{0};
+ uint32_t chunk_size;
+ std::string table_name;
+ std::vector<int64_t> global_ids;
+ std::vector<int64_t> col_ids;
+ uint32_t num_optimizer_states;
+ void* on_complete_context;
+ void (*on_global_id_fetched)(
+ void* ctx,
+ uint32_t gid_offset,
+ uint32_t optimizer_state,
+ void* data,
+ uint32_t data_len);
+ void (*on_all_fetched)(void* ctx);
+
+ explicit RedisFetchContext(uint32_t chunk_size, IOFetchParameter param)
+ : chunk_size(CalculateChunkSizeByGlobalIDs(
+ chunk_size,
+ param.num_cols,
+ param.num_optimizer_states)),
+ table_name(param.table_name),
+ global_ids(param.global_ids, param.global_ids + param.num_global_ids),
+ num_optimizer_states(param.num_optimizer_states),
+ on_complete_context(param.on_complete_context),
+ on_global_id_fetched(param.on_global_id_fetched),
+ on_all_fetched(param.on_all_fetched) {
+ if (param.num_cols == 0) {
+ col_ids.emplace_back(-1);
+ } else {
+ col_ids =
+ std::vector<int64_t>(param.col_ids, param.col_ids + param.num_cols);
+ }
+ }
+};
+
+void Redis::fetch(IOFetchParameter param) {
+ auto* fetch_param = new RedisFetchContext(opt_.chunk_size, param);
+ {
+ std::lock_guard guard(this->jobs_mutex_);
+ for (uint32_t i = 0; i < param.num_global_ids;
+ i += fetch_param->chunk_size) {
+ jobs_.emplace_back(
+ [i, fetch_param, this](helper::ContextPtr& connection) {
+ do_fetch(i, fetch_param, connection);
+ });
+ }
+ }
+ jobs_not_empty_.notify_all();
+}
+
+void Redis::do_fetch(
+ uint32_t gid_offset,
+ void* fetch_param_void,
+ helper::ContextPtr& connection) const {
+ auto& fetch_param = *reinterpret_cast<RedisFetchContext*>(fetch_param_void);
+
+ uint32_t end = std::min(
+ gid_offset + fetch_param.chunk_size,
+ static_cast<uint32_t>(fetch_param.global_ids.size()));
+
+ auto loop = [&](auto&& callback) {
+ for (uint32_t i = gid_offset; i < end; ++i) {
+ int64_t gid = fetch_param.global_ids[i];
+ for (uint32_t j = 0; j < fetch_param.col_ids.size(); ++j) {
+ auto& col_id = fetch_param.col_ids[j];
+ for (uint32_t os_id = 0; os_id < fetch_param.num_optimizer_states;
+ ++os_id) {
+ callback(i * fetch_param.col_ids.size() + j, gid, col_id, os_id);
+ }
+ }
+ }
+ };
+
+ loop([&](uint32_t offset, int64_t gid, uint32_t col_id, uint32_t os_id) {
+ redisAppendCommand(
+ connection.get(),
+ "GET %s_table_%s_gid_%d_cid_%d_osid_%d",
+ opt_.prefix.c_str(),
+ fetch_param.table_name.c_str(),
+ gid,
+ col_id,
+ os_id);
+ });
+
+ void* reply;
+ loop([&](uint32_t offset, int64_t gid, uint32_t col_id, uint32_t os_id) {
+ int status = redisGetReply(connection.get(), &reply);
+ TORCH_CHECK(
+ status != REDIS_ERR,
+ "get reply error: %s, from redis %s, %d",
+ connection->errstr,
+ opt_.host,
+ opt_.port);
+ auto reply_ptr = helper::ReplyPtr(reinterpret_cast<redisReply*>(reply));
+
+ if (reply_ptr->type == REDIS_REPLY_NIL) {
+ fetch_param.on_global_id_fetched(
+ fetch_param.on_complete_context, offset, os_id, nullptr, 0);
+ } else {
+ fetch_param.on_global_id_fetched(
+ fetch_param.on_complete_context,
+ offset,
+ os_id,
+ reply_ptr->str,
+ reply_ptr->len);
+ }
+ });
+
+ uint32_t n = end - gid_offset;
+ uint32_t target = fetch_param.global_ids.size();
+
+ if (fetch_param.num_complete_ids.fetch_add(n) + n ==
+ target) { // last fetch complete
+ fetch_param.on_all_fetched(fetch_param.on_complete_context);
+ delete &fetch_param;
+ }
+}
+
+struct RedisPushContext {
+ std::atomic<uint32_t> num_complete_ids{0};
+ uint32_t chunk_size;
+ std::string table_name;
+ std::span<const int64_t> global_ids;
+ std::vector<int64_t> col_ids;
+ std::span<const uint32_t> os_ids;
+ std::span