From 419afeb1f9e0e179c5e5d1d38696617cf8c7a3c5 Mon Sep 17 00:00:00 2001 From: hhy Date: Wed, 22 Jan 2025 16:04:14 -0800 Subject: [PATCH 01/11] [OSS] fix fbgemm LB_LIBRARY_PATH --- .github/scripts/validate_binaries.sh | 68 ++++++++++++++++++---------- 1 file changed, 44 insertions(+), 24 deletions(-) diff --git a/.github/scripts/validate_binaries.sh b/.github/scripts/validate_binaries.sh index 85ad0de47..6750b5b74 100755 --- a/.github/scripts/validate_binaries.sh +++ b/.github/scripts/validate_binaries.sh @@ -7,8 +7,9 @@ export PYTORCH_CUDA_PKG="" +export CONDA_ENV="build_binary" -conda create -y -n build_binary python="${MATRIX_PYTHON_VERSION}" +conda create -y -n "${CONDA_ENV}" python="${MATRIX_PYTHON_VERSION}" conda run -n build_binary python --version @@ -49,41 +50,60 @@ elif [[ ${MATRIX_CHANNEL} = 'release' ]]; then export PYTORCH_URL="https://download.pytorch.org/whl/$%7BCUDA_VERSION%7D" fi + +echo "CU_VERSION: ${CUDA_VERSION}" +echo "MATRIX_CHANNEL: ${MATRIX_CHANNEL}" +echo "CONDA_ENV: ${CONDA_ENV}" + +# shellcheck disable=SC2155 +export CONDA_PREFIX=$(conda run -n "${CONDA_ENV}" printenv CONDA_PREFIX) + +find / -name *cuda* + +if [[ $CUDA_VERSION = cu* ]]; then + # Setting LD_LIBRARY_PATH fixes the runtime error with fbgemm_gpu not + # being able to locate libnvrtc.so + echo "[NOVA] Setting LD_LIBRARY_PATH ..." + conda env config vars set -n ${CONDA_ENV} \ + LD_LIBRARY_PATH="/usr/local/lib:/usr/lib64:${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH}" +fi + + # install pytorch # switch back to conda once torch nightly is fixed # if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then # export PYTORCH_CUDA_PKG="pytorch-cuda=${MATRIX_GPU_ARCH_VERSION}" # fi -conda run -n build_binary pip install torch --index-url "$PYTORCH_URL" +conda run -n "${CONDA_ENV}" pip install torch --index-url "$PYTORCH_URL" # install fbgemm -conda run -n build_binary pip install fbgemm-gpu --index-url "$PYTORCH_URL" +conda run -n "${CONDA_ENV}" pip install fbgemm-gpu --index-url "$PYTORCH_URL" # install requirements from pypi -conda run -n build_binary pip install torchmetrics==1.0.3 +conda run -n "${CONDA_ENV}" pip install torchmetrics==1.0.3 # install torchrec -conda run -n build_binary pip install torchrec --index-url "$PYTORCH_URL" +conda run -n "${CONDA_ENV}" pip install torchrec --index-url "$PYTORCH_URL" # Run small import test -conda run -n build_binary python -c "import torch; import fbgemm_gpu; import torchrec" +conda run -n "${CONDA_ENV}" python -c "import torch; import fbgemm_gpu; import torchrec" # check directory ls -R # check if cuda available -conda run -n build_binary python -c "import torch; print(torch.cuda.is_available())" +conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.cuda.is_available())" # check cuda version -conda run -n build_binary python -c "import torch; print(torch.version.cuda)" +conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.version.cuda)" # Finally run smoke test # python 3.11 needs torchx-nightly -conda run -n build_binary pip install torchx-nightly iopath +conda run -n "${CONDA_ENV}" pip install torchx-nightly iopath if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then - conda run -n build_binary torchx run -s local_cwd dist.ddp -j 1 --gpu 2 --script test_installation.py + conda run -n "${CONDA_ENV}" torchx run -s local_cwd dist.ddp -j 1 --gpu 2 --script test_installation.py else - conda run -n build_binary torchx run -s local_cwd dist.ddp -j 1 --script test_installation.py -- --cpu_only + conda run -n "${CONDA_ENV}" torchx run -s local_cwd dist.ddp -j 1 --script test_installation.py -- --cpu_only fi @@ -93,8 +113,8 @@ if [[ ${MATRIX_CHANNEL} != 'release' ]]; then exit 0 else # Check version matches only for release binaries - torchrec_version=$(conda run -n build_binary pip show torchrec | grep Version | cut -d' ' -f2) - fbgemm_version=$(conda run -n build_binary pip show fbgemm_gpu | grep Version | cut -d' ' -f2) + torchrec_version=$(conda run -n "${CONDA_ENV}" pip show torchrec | grep Version | cut -d' ' -f2) + fbgemm_version=$(conda run -n "${CONDA_ENV}" pip show fbgemm_gpu | grep Version | cut -d' ' -f2) if [ "$torchrec_version" != "$fbgemm_version" ]; then echo "Error: TorchRec package version does not match FBGEMM package version" @@ -102,22 +122,22 @@ else fi fi -conda create -y -n build_binary python="${MATRIX_PYTHON_VERSION}" +conda create -y -n "${CONDA_ENV}" python="${MATRIX_PYTHON_VERSION}" -conda run -n build_binary python --version +conda run -n "${CONDA_ENV}" python --version if [[ ${MATRIX_GPU_ARCH_VERSION} != '12.4' ]]; then exit 0 fi echo "checking pypi release" -conda run -n build_binary pip install torch -conda run -n build_binary pip install fbgemm-gpu -conda run -n build_binary pip install torchrec +conda run -n "${CONDA_ENV}" pip install torch +conda run -n "${CONDA_ENV}" pip install fbgemm-gpu +conda run -n "${CONDA_ENV}" pip install torchrec # Check version matching again for PyPI -torchrec_version=$(conda run -n build_binary pip show torchrec | grep Version | cut -d' ' -f2) -fbgemm_version=$(conda run -n build_binary pip show fbgemm_gpu | grep Version | cut -d' ' -f2) +torchrec_version=$(conda run -n "${CONDA_ENV}" pip show torchrec | grep Version | cut -d' ' -f2) +fbgemm_version=$(conda run -n "${CONDA_ENV}" pip show fbgemm_gpu | grep Version | cut -d' ' -f2) if [ "$torchrec_version" != "$fbgemm_version" ]; then echo "Error: TorchRec package version does not match FBGEMM package version" @@ -128,13 +148,13 @@ fi ls -R # check if cuda available -conda run -n build_binary python -c "import torch; print(torch.cuda.is_available())" +conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.cuda.is_available())" # check cuda version -conda run -n build_binary python -c "import torch; print(torch.version.cuda)" +conda run -n "${CONDA_ENV}" python -c "import torch; print(torch.version.cuda)" # python 3.11 needs torchx-nightly -conda run -n build_binary pip install torchx-nightly iopath +conda run -n "${CONDA_ENV}" pip install torchx-nightly iopath # Finally run smoke test -conda run -n build_binary torchx run -s local_cwd dist.ddp -j 1 --gpu 2 --script test_installation.py +conda run -n "${CONDA_ENV}" torchx run -s local_cwd dist.ddp -j 1 --gpu 2 --script test_installation.py From 69a3b528be2f759c4147c934adda676a56999be6 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 23 Jan 2025 20:04:11 +0000 Subject: [PATCH 02/11] Remove tensordict --- torchrec/sparse/tensor_dict.py | 45 ---------------------------------- 1 file changed, 45 deletions(-) delete mode 100644 torchrec/sparse/tensor_dict.py diff --git a/torchrec/sparse/tensor_dict.py b/torchrec/sparse/tensor_dict.py deleted file mode 100644 index 3f00d5275..000000000 --- a/torchrec/sparse/tensor_dict.py +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -from typing import List, Optional - -import torch -from tensordict import TensorDict - -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor - - -def maybe_td_to_kjt( - features: KeyedJaggedTensor, keys: Optional[List[str]] = None -) -> KeyedJaggedTensor: - if torch.jit.is_scripting(): - assert isinstance(features, KeyedJaggedTensor) - return features - if isinstance(features, TensorDict): - if keys is None: - keys = list(features.keys()) - values = torch.cat([features[key]._values for key in keys], dim=0) - lengths = torch.cat( - [ - ( - (features[key]._lengths) - if features[key]._lengths is not None - else torch.diff(features[key]._offsets) - ) - for key in keys - ], - dim=0, - ) - return KeyedJaggedTensor( - keys=keys, - values=values, - lengths=lengths, - ) - else: - return features From ebd64f3629cb22e5ab18aee9c229a83566b83d62 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 23 Jan 2025 20:04:46 +0000 Subject: [PATCH 03/11] Remove tensordict test --- torchrec/sparse/tests/test_tensor_dict.py | 72 ----------------------- 1 file changed, 72 deletions(-) delete mode 100644 torchrec/sparse/tests/test_tensor_dict.py diff --git a/torchrec/sparse/tests/test_tensor_dict.py b/torchrec/sparse/tests/test_tensor_dict.py deleted file mode 100644 index 2fbcc0a66..000000000 --- a/torchrec/sparse/tests/test_tensor_dict.py +++ /dev/null @@ -1,72 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - - -import unittest - -import torch -from hypothesis import given, settings, strategies as st, Verbosity -from tensordict import TensorDict -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor -from torchrec.sparse.tensor_dict import maybe_td_to_kjt - - -class TestTensorDict(unittest.TestCase): - # pyre-ignore[56] - @given( - device_str=st.sampled_from( - ["cpu", "meta"] + (["cuda"] if torch.cuda.device_count() > 0 else []) - ) - ) - @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) - def test_kjt_input(self, device_str: str) -> None: - device = torch.device(device_str) - values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device) - kjt = KeyedJaggedTensor.from_offsets_sync( - keys=["f1", "f2", "f3"], - values=values, - offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7], device=device), - ) - features = maybe_td_to_kjt(kjt) - self.assertEqual(features, kjt) - - # pyre-ignore[56] - @given( - device_str=st.sampled_from( - ["cpu", "meta"] + (["cuda"] if torch.cuda.device_count() > 0 else []) - ) - ) - @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) - def test_td_kjt(self, device_str: str) -> None: - device = torch.device(device_str) - values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device) - lengths = torch.tensor([2, 0, 1, 1, 1, 2], device=device) - data = { - "f2": torch.nested.nested_tensor_from_jagged( - torch.tensor([2, 3], device=device), - lengths=torch.tensor([1, 1], device=device), - ), - "f1": torch.nested.nested_tensor_from_jagged( - torch.arange(2, device=device), - offsets=torch.tensor([0, 2, 2], device=device), - ), - "f3": torch.nested.nested_tensor_from_jagged( - torch.tensor([2, 3, 4], device=device), - lengths=torch.tensor([1, 2], device=device), - ), - } - td = TensorDict( - data, # type: ignore[arg-type] - device=device, - batch_size=[2], - ) - - features = maybe_td_to_kjt(td, ["f1", "f2", "f3"]) # pyre-ignore[6] - torch.testing.assert_close(features.values(), values) - torch.testing.assert_close(features.lengths(), lengths) From e5e25650679a4c9d1655390ddf8be28169dd5981 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 23 Jan 2025 20:05:45 +0000 Subject: [PATCH 04/11] Revert "add NJT/TD support in test data generator (#2528)" This reverts commit e35119dfd5007bae6793a192f6b65f7da9b50e6f. --- install-requirements.txt | 1 - requirements.txt | 1 - ...enchmark_split_table_batched_embeddings.py | 9 +- .../distributed/benchmark/benchmark_utils.py | 5 +- .../distributed/test_utils/infer_utils.py | 4 +- torchrec/distributed/test_utils/test_model.py | 123 +++++------------- .../distributed/tests/test_infer_shardings.py | 3 - .../tests/pipeline_benchmarks.py | 12 +- .../tests/test_train_pipelines.py | 6 +- .../keyed_jagged_tensor_benchmark_lib.py | 1 - 10 files changed, 44 insertions(+), 121 deletions(-) diff --git a/install-requirements.txt b/install-requirements.txt index ed3c6aced..ab2736d78 100644 --- a/install-requirements.txt +++ b/install-requirements.txt @@ -1,5 +1,4 @@ fbgemm-gpu -tensordict torchmetrics==1.0.3 tqdm pyre-extensions diff --git a/requirements.txt b/requirements.txt index 6d63107dd..b60a348f4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,6 @@ numpy pandas pyre-extensions scikit-build -tensordict torchmetrics==1.0.3 torchx tqdm diff --git a/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py b/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py index 8af1f9a46..b03e7b417 100644 --- a/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py +++ b/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py @@ -9,8 +9,6 @@ #!/usr/bin/env python3 -from typing import Dict, List - import click import torch @@ -84,10 +82,9 @@ def op_bench( ) def _func_to_benchmark( - kjts: List[Dict[str, KeyedJaggedTensor]], + kjt: KeyedJaggedTensor, model: torch.nn.Module, ) -> torch.Tensor: - kjt = kjts[0]["feature"] return model.forward(kjt.values(), kjt.offsets()) # breakpoint() # import fbvscode; fbvscode.set_trace() @@ -111,8 +108,8 @@ def _func_to_benchmark( result = benchmark_func( name=f"SplitTableBatchedEmbeddingBagsCodegen-{num_embeddings}-{embedding_dim}-{num_tables}-{batch_size}-{bag_size}", - bench_inputs=[{"feature": inputs}], - prof_inputs=[{"feature": inputs}], + bench_inputs=inputs, # pyre-ignore + prof_inputs=inputs, # pyre-ignore num_benchmarks=10, num_profiles=10, profile_dir=".", diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index 22af274d6..1878fdd1f 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -374,14 +374,11 @@ def get_inputs( if train: sparse_features_by_rank = [ - model_input.idlist_features - for model_input in model_input_by_rank - if isinstance(model_input.idlist_features, KeyedJaggedTensor) + model_input.idlist_features for model_input in model_input_by_rank ] inputs_batch.append(sparse_features_by_rank) else: sparse_features = model_input_by_rank[0].idlist_features - assert isinstance(sparse_features, KeyedJaggedTensor) inputs_batch.append([sparse_features]) # Transpose if train, as inputs_by_rank is currently in [B X R] format diff --git a/torchrec/distributed/test_utils/infer_utils.py b/torchrec/distributed/test_utils/infer_utils.py index 478e01bb2..0604f1c29 100644 --- a/torchrec/distributed/test_utils/infer_utils.py +++ b/torchrec/distributed/test_utils/infer_utils.py @@ -264,7 +264,6 @@ def model_input_to_forward_args_kjt( Optional[torch.Tensor], ]: kjt = mi.idlist_features - assert isinstance(kjt, KeyedJaggedTensor) return ( kjt._keys, kjt._values, @@ -292,8 +291,7 @@ def model_input_to_forward_args( ]: idlist_kjt = mi.idlist_features idscore_kjt = mi.idscore_features - assert isinstance(idlist_kjt, KeyedJaggedTensor) - assert isinstance(idscore_kjt, KeyedJaggedTensor) + assert idscore_kjt is not None return ( mi.float_features, idlist_kjt._keys, diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index 010abb459..3442b5dd3 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -14,7 +14,6 @@ import torch import torch.nn as nn -from tensordict import TensorDict from torchrec.distributed.embedding_tower_sharding import ( EmbeddingTowerCollectionSharder, EmbeddingTowerSharder, @@ -47,8 +46,8 @@ @dataclass class ModelInput(Pipelineable): float_features: torch.Tensor - idlist_features: Union[KeyedJaggedTensor, TensorDict] - idscore_features: Optional[Union[KeyedJaggedTensor, TensorDict]] + idlist_features: KeyedJaggedTensor + idscore_features: Optional[KeyedJaggedTensor] label: torch.Tensor @staticmethod @@ -77,13 +76,11 @@ def generate( randomize_indices: bool = True, device: Optional[torch.device] = None, max_feature_lengths: Optional[List[int]] = None, - input_type: str = "kjt", ) -> Tuple["ModelInput", List["ModelInput"]]: """ Returns a global (single-rank training) batch and a list of local (multi-rank training) batches of world_size. """ - batch_size_by_rank = [batch_size] * world_size if variable_batch_size: batch_size_by_rank = [ @@ -202,26 +199,11 @@ def _validate_pooling_factor( ) global_idlist_lengths.append(lengths) global_idlist_indices.append(indices) - - if input_type == "kjt": - global_idlist_input = KeyedJaggedTensor( - keys=idlist_features, - values=torch.cat(global_idlist_indices), - lengths=torch.cat(global_idlist_lengths), - ) - elif input_type == "td": - dict_of_nt = { - k: torch.nested.nested_tensor_from_jagged( - values=values, - lengths=lengths, - ) - for k, values, lengths in zip( - idlist_features, global_idlist_indices, global_idlist_lengths - ) - } - global_idlist_input = TensorDict(source=dict_of_nt) - else: - raise ValueError(f"For IdList features, unknown input type {input_type}") + global_idlist_kjt = KeyedJaggedTensor( + keys=idlist_features, + values=torch.cat(global_idlist_indices), + lengths=torch.cat(global_idlist_lengths), + ) for idx in range(len(idscore_ind_ranges)): ind_range = idscore_ind_ranges[idx] @@ -263,25 +245,16 @@ def _validate_pooling_factor( global_idscore_lengths.append(lengths) global_idscore_indices.append(indices) global_idscore_weights.append(weights) - - if input_type == "kjt": - global_idscore_input = ( - KeyedJaggedTensor( - keys=idscore_features, - values=torch.cat(global_idscore_indices), - lengths=torch.cat(global_idscore_lengths), - weights=torch.cat(global_idscore_weights), - ) - if global_idscore_indices - else None + global_idscore_kjt = ( + KeyedJaggedTensor( + keys=idscore_features, + values=torch.cat(global_idscore_indices), + lengths=torch.cat(global_idscore_lengths), + weights=torch.cat(global_idscore_weights), ) - elif input_type == "td": - assert ( - len(idscore_features) == 0 - ), "TensorDict does not support weighted features" - global_idscore_input = None - else: - raise ValueError(f"For weighted features, unknown input type {input_type}") + if global_idscore_indices + else None + ) if randomize_indices: global_float = torch.rand( @@ -330,48 +303,27 @@ def _validate_pooling_factor( weights[lengths_cumsum[r] : lengths_cumsum[r + 1]] ) - if input_type == "kjt": - local_idlist_input = KeyedJaggedTensor( - keys=idlist_features, - values=torch.cat(local_idlist_indices), - lengths=torch.cat(local_idlist_lengths), - ) - - local_idscore_input = ( - KeyedJaggedTensor( - keys=idscore_features, - values=torch.cat(local_idscore_indices), - lengths=torch.cat(local_idscore_lengths), - weights=torch.cat(local_idscore_weights), - ) - if local_idscore_indices - else None - ) - elif input_type == "td": - dict_of_nt = { - k: torch.nested.nested_tensor_from_jagged( - values=values, - lengths=lengths, - ) - for k, values, lengths in zip( - idlist_features, local_idlist_indices, local_idlist_lengths - ) - } - local_idlist_input = TensorDict(source=dict_of_nt) - assert ( - len(idscore_features) == 0 - ), "TensorDict does not support weighted features" - local_idscore_input = None + local_idlist_kjt = KeyedJaggedTensor( + keys=idlist_features, + values=torch.cat(local_idlist_indices), + lengths=torch.cat(local_idlist_lengths), + ) - else: - raise ValueError( - f"For weighted features, unknown input type {input_type}" + local_idscore_kjt = ( + KeyedJaggedTensor( + keys=idscore_features, + values=torch.cat(local_idscore_indices), + lengths=torch.cat(local_idscore_lengths), + weights=torch.cat(local_idscore_weights), ) + if local_idscore_indices + else None + ) local_input = ModelInput( float_features=global_float[r * batch_size : (r + 1) * batch_size], - idlist_features=local_idlist_input, - idscore_features=local_idscore_input, + idlist_features=local_idlist_kjt, + idscore_features=local_idscore_kjt, label=global_label[r * batch_size : (r + 1) * batch_size], ) local_inputs.append(local_input) @@ -379,8 +331,8 @@ def _validate_pooling_factor( return ( ModelInput( float_features=global_float, - idlist_features=global_idlist_input, - idscore_features=global_idscore_input, + idlist_features=global_idlist_kjt, + idscore_features=global_idscore_kjt, label=global_label, ), local_inputs, @@ -671,9 +623,8 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput": def record_stream(self, stream: torch.Stream) -> None: self.float_features.record_stream(stream) - if isinstance(self.idlist_features, KeyedJaggedTensor): - self.idlist_features.record_stream(stream) - if isinstance(self.idscore_features, KeyedJaggedTensor): + self.idlist_features.record_stream(stream) + if self.idscore_features is not None: self.idscore_features.record_stream(stream) self.label.record_stream(stream) @@ -1880,8 +1831,6 @@ def forward(self, input: ModelInput) -> ModelInput: ) # stride will be same but features will be joined - assert isinstance(modified_input.idlist_features, KeyedJaggedTensor) - assert isinstance(self._extra_input.idlist_features, KeyedJaggedTensor) modified_input.idlist_features = KeyedJaggedTensor.concat( [modified_input.idlist_features, self._extra_input.idlist_features] ) diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py index 83b4649ee..c7c6ef180 100755 --- a/torchrec/distributed/tests/test_infer_shardings.py +++ b/torchrec/distributed/tests/test_infer_shardings.py @@ -1987,7 +1987,6 @@ def test_sharded_quant_fp_ebc_tw( inputs = [] for model_input in model_inputs: kjt = model_input.idlist_features - assert isinstance(kjt, KeyedJaggedTensor) kjt = kjt.to(local_device) weights = torch.rand( kjt._values.size(0), dtype=torch.float, device=local_device @@ -2167,7 +2166,6 @@ def test_sharded_quant_mc_ec_rw( inputs = [] for model_input in model_inputs: kjt = model_input.idlist_features - assert isinstance(kjt, KeyedJaggedTensor) kjt = kjt.to(local_device) weights = None inputs.append( @@ -2303,7 +2301,6 @@ def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None: ) inputs = [] kjt = model_inputs[0].idlist_features - assert isinstance(kjt, KeyedJaggedTensor) kjt = kjt.to(local_device) weights = torch.rand( kjt._values.size(0), dtype=torch.float, device=local_device diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py index e8dc5eccb..538264c04 100644 --- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py +++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py @@ -75,11 +75,6 @@ def _gen_pipelines( default=100, help="Total number of sparse embeddings to be used.", ) -@click.option( - "--ratio_features_weighted", - default=0.4, - help="percentage of features weighted vs unweighted", -) @click.option( "--dim_emb", type=int, @@ -137,7 +132,6 @@ def _gen_pipelines( def main( world_size: int, n_features: int, - ratio_features_weighted: float, dim_emb: int, n_batches: int, batch_size: int, @@ -155,9 +149,8 @@ def main( os.environ["MASTER_ADDR"] = str("localhost") os.environ["MASTER_PORT"] = str(get_free_port()) - num_weighted_features = int(n_features * ratio_features_weighted) - num_features = n_features - num_weighted_features - + num_features = n_features // 2 + num_weighted_features = n_features // 2 tables = [ EmbeddingBagConfig( num_embeddings=(i + 1) * 1000, @@ -264,7 +257,6 @@ def _generate_data( world_size=world_size, num_float_features=num_float_features, pooling_avg=pooling_factor, - input_type=input_type, )[1] for i in range(num_batches) ] diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index bf708b1f5..9c39b5384 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -306,11 +306,7 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None: # `parameters`. optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01) - data = [ - i.idlist_features - for i in local_model_inputs - if isinstance(i.idlist_features, KeyedJaggedTensor) - ] + data = [i.idlist_features for i in local_model_inputs] dataloader = iter(data) pipeline = TrainPipelinePT2( model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing diff --git a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py index 1c409fcf2..235495494 100644 --- a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py +++ b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py @@ -169,7 +169,6 @@ def generate_kjt( randomize_indices=True, device=device, )[0] - assert isinstance(global_input.idlist_features, KeyedJaggedTensor) return global_input.idlist_features From 3441ac333f69c7acfb175a0f0d260c335d91831e Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 23 Jan 2025 20:21:34 +0000 Subject: [PATCH 05/11] Revert "Revert "add NJT/TD support in test data generator (#2528)"" This reverts commit e5e25650679a4c9d1655390ddf8be28169dd5981. --- install-requirements.txt | 1 + requirements.txt | 1 + ...enchmark_split_table_batched_embeddings.py | 9 +- .../distributed/benchmark/benchmark_utils.py | 5 +- .../distributed/test_utils/infer_utils.py | 4 +- torchrec/distributed/test_utils/test_model.py | 123 +++++++++++++----- .../distributed/tests/test_infer_shardings.py | 3 + .../tests/pipeline_benchmarks.py | 12 +- .../tests/test_train_pipelines.py | 6 +- .../keyed_jagged_tensor_benchmark_lib.py | 1 + 10 files changed, 121 insertions(+), 44 deletions(-) diff --git a/install-requirements.txt b/install-requirements.txt index ab2736d78..ed3c6aced 100644 --- a/install-requirements.txt +++ b/install-requirements.txt @@ -1,4 +1,5 @@ fbgemm-gpu +tensordict torchmetrics==1.0.3 tqdm pyre-extensions diff --git a/requirements.txt b/requirements.txt index b60a348f4..6d63107dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ numpy pandas pyre-extensions scikit-build +tensordict torchmetrics==1.0.3 torchx tqdm diff --git a/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py b/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py index b03e7b417..8af1f9a46 100644 --- a/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py +++ b/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py @@ -9,6 +9,8 @@ #!/usr/bin/env python3 +from typing import Dict, List + import click import torch @@ -82,9 +84,10 @@ def op_bench( ) def _func_to_benchmark( - kjt: KeyedJaggedTensor, + kjts: List[Dict[str, KeyedJaggedTensor]], model: torch.nn.Module, ) -> torch.Tensor: + kjt = kjts[0]["feature"] return model.forward(kjt.values(), kjt.offsets()) # breakpoint() # import fbvscode; fbvscode.set_trace() @@ -108,8 +111,8 @@ def _func_to_benchmark( result = benchmark_func( name=f"SplitTableBatchedEmbeddingBagsCodegen-{num_embeddings}-{embedding_dim}-{num_tables}-{batch_size}-{bag_size}", - bench_inputs=inputs, # pyre-ignore - prof_inputs=inputs, # pyre-ignore + bench_inputs=[{"feature": inputs}], + prof_inputs=[{"feature": inputs}], num_benchmarks=10, num_profiles=10, profile_dir=".", diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index 1878fdd1f..22af274d6 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -374,11 +374,14 @@ def get_inputs( if train: sparse_features_by_rank = [ - model_input.idlist_features for model_input in model_input_by_rank + model_input.idlist_features + for model_input in model_input_by_rank + if isinstance(model_input.idlist_features, KeyedJaggedTensor) ] inputs_batch.append(sparse_features_by_rank) else: sparse_features = model_input_by_rank[0].idlist_features + assert isinstance(sparse_features, KeyedJaggedTensor) inputs_batch.append([sparse_features]) # Transpose if train, as inputs_by_rank is currently in [B X R] format diff --git a/torchrec/distributed/test_utils/infer_utils.py b/torchrec/distributed/test_utils/infer_utils.py index 0604f1c29..478e01bb2 100644 --- a/torchrec/distributed/test_utils/infer_utils.py +++ b/torchrec/distributed/test_utils/infer_utils.py @@ -264,6 +264,7 @@ def model_input_to_forward_args_kjt( Optional[torch.Tensor], ]: kjt = mi.idlist_features + assert isinstance(kjt, KeyedJaggedTensor) return ( kjt._keys, kjt._values, @@ -291,7 +292,8 @@ def model_input_to_forward_args( ]: idlist_kjt = mi.idlist_features idscore_kjt = mi.idscore_features - assert idscore_kjt is not None + assert isinstance(idlist_kjt, KeyedJaggedTensor) + assert isinstance(idscore_kjt, KeyedJaggedTensor) return ( mi.float_features, idlist_kjt._keys, diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index 3442b5dd3..010abb459 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn +from tensordict import TensorDict from torchrec.distributed.embedding_tower_sharding import ( EmbeddingTowerCollectionSharder, EmbeddingTowerSharder, @@ -46,8 +47,8 @@ @dataclass class ModelInput(Pipelineable): float_features: torch.Tensor - idlist_features: KeyedJaggedTensor - idscore_features: Optional[KeyedJaggedTensor] + idlist_features: Union[KeyedJaggedTensor, TensorDict] + idscore_features: Optional[Union[KeyedJaggedTensor, TensorDict]] label: torch.Tensor @staticmethod @@ -76,11 +77,13 @@ def generate( randomize_indices: bool = True, device: Optional[torch.device] = None, max_feature_lengths: Optional[List[int]] = None, + input_type: str = "kjt", ) -> Tuple["ModelInput", List["ModelInput"]]: """ Returns a global (single-rank training) batch and a list of local (multi-rank training) batches of world_size. """ + batch_size_by_rank = [batch_size] * world_size if variable_batch_size: batch_size_by_rank = [ @@ -199,11 +202,26 @@ def _validate_pooling_factor( ) global_idlist_lengths.append(lengths) global_idlist_indices.append(indices) - global_idlist_kjt = KeyedJaggedTensor( - keys=idlist_features, - values=torch.cat(global_idlist_indices), - lengths=torch.cat(global_idlist_lengths), - ) + + if input_type == "kjt": + global_idlist_input = KeyedJaggedTensor( + keys=idlist_features, + values=torch.cat(global_idlist_indices), + lengths=torch.cat(global_idlist_lengths), + ) + elif input_type == "td": + dict_of_nt = { + k: torch.nested.nested_tensor_from_jagged( + values=values, + lengths=lengths, + ) + for k, values, lengths in zip( + idlist_features, global_idlist_indices, global_idlist_lengths + ) + } + global_idlist_input = TensorDict(source=dict_of_nt) + else: + raise ValueError(f"For IdList features, unknown input type {input_type}") for idx in range(len(idscore_ind_ranges)): ind_range = idscore_ind_ranges[idx] @@ -245,16 +263,25 @@ def _validate_pooling_factor( global_idscore_lengths.append(lengths) global_idscore_indices.append(indices) global_idscore_weights.append(weights) - global_idscore_kjt = ( - KeyedJaggedTensor( - keys=idscore_features, - values=torch.cat(global_idscore_indices), - lengths=torch.cat(global_idscore_lengths), - weights=torch.cat(global_idscore_weights), + + if input_type == "kjt": + global_idscore_input = ( + KeyedJaggedTensor( + keys=idscore_features, + values=torch.cat(global_idscore_indices), + lengths=torch.cat(global_idscore_lengths), + weights=torch.cat(global_idscore_weights), + ) + if global_idscore_indices + else None ) - if global_idscore_indices - else None - ) + elif input_type == "td": + assert ( + len(idscore_features) == 0 + ), "TensorDict does not support weighted features" + global_idscore_input = None + else: + raise ValueError(f"For weighted features, unknown input type {input_type}") if randomize_indices: global_float = torch.rand( @@ -303,27 +330,48 @@ def _validate_pooling_factor( weights[lengths_cumsum[r] : lengths_cumsum[r + 1]] ) - local_idlist_kjt = KeyedJaggedTensor( - keys=idlist_features, - values=torch.cat(local_idlist_indices), - lengths=torch.cat(local_idlist_lengths), - ) + if input_type == "kjt": + local_idlist_input = KeyedJaggedTensor( + keys=idlist_features, + values=torch.cat(local_idlist_indices), + lengths=torch.cat(local_idlist_lengths), + ) - local_idscore_kjt = ( - KeyedJaggedTensor( - keys=idscore_features, - values=torch.cat(local_idscore_indices), - lengths=torch.cat(local_idscore_lengths), - weights=torch.cat(local_idscore_weights), + local_idscore_input = ( + KeyedJaggedTensor( + keys=idscore_features, + values=torch.cat(local_idscore_indices), + lengths=torch.cat(local_idscore_lengths), + weights=torch.cat(local_idscore_weights), + ) + if local_idscore_indices + else None + ) + elif input_type == "td": + dict_of_nt = { + k: torch.nested.nested_tensor_from_jagged( + values=values, + lengths=lengths, + ) + for k, values, lengths in zip( + idlist_features, local_idlist_indices, local_idlist_lengths + ) + } + local_idlist_input = TensorDict(source=dict_of_nt) + assert ( + len(idscore_features) == 0 + ), "TensorDict does not support weighted features" + local_idscore_input = None + + else: + raise ValueError( + f"For weighted features, unknown input type {input_type}" ) - if local_idscore_indices - else None - ) local_input = ModelInput( float_features=global_float[r * batch_size : (r + 1) * batch_size], - idlist_features=local_idlist_kjt, - idscore_features=local_idscore_kjt, + idlist_features=local_idlist_input, + idscore_features=local_idscore_input, label=global_label[r * batch_size : (r + 1) * batch_size], ) local_inputs.append(local_input) @@ -331,8 +379,8 @@ def _validate_pooling_factor( return ( ModelInput( float_features=global_float, - idlist_features=global_idlist_kjt, - idscore_features=global_idscore_kjt, + idlist_features=global_idlist_input, + idscore_features=global_idscore_input, label=global_label, ), local_inputs, @@ -623,8 +671,9 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput": def record_stream(self, stream: torch.Stream) -> None: self.float_features.record_stream(stream) - self.idlist_features.record_stream(stream) - if self.idscore_features is not None: + if isinstance(self.idlist_features, KeyedJaggedTensor): + self.idlist_features.record_stream(stream) + if isinstance(self.idscore_features, KeyedJaggedTensor): self.idscore_features.record_stream(stream) self.label.record_stream(stream) @@ -1831,6 +1880,8 @@ def forward(self, input: ModelInput) -> ModelInput: ) # stride will be same but features will be joined + assert isinstance(modified_input.idlist_features, KeyedJaggedTensor) + assert isinstance(self._extra_input.idlist_features, KeyedJaggedTensor) modified_input.idlist_features = KeyedJaggedTensor.concat( [modified_input.idlist_features, self._extra_input.idlist_features] ) diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py index c7c6ef180..83b4649ee 100755 --- a/torchrec/distributed/tests/test_infer_shardings.py +++ b/torchrec/distributed/tests/test_infer_shardings.py @@ -1987,6 +1987,7 @@ def test_sharded_quant_fp_ebc_tw( inputs = [] for model_input in model_inputs: kjt = model_input.idlist_features + assert isinstance(kjt, KeyedJaggedTensor) kjt = kjt.to(local_device) weights = torch.rand( kjt._values.size(0), dtype=torch.float, device=local_device @@ -2166,6 +2167,7 @@ def test_sharded_quant_mc_ec_rw( inputs = [] for model_input in model_inputs: kjt = model_input.idlist_features + assert isinstance(kjt, KeyedJaggedTensor) kjt = kjt.to(local_device) weights = None inputs.append( @@ -2301,6 +2303,7 @@ def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None: ) inputs = [] kjt = model_inputs[0].idlist_features + assert isinstance(kjt, KeyedJaggedTensor) kjt = kjt.to(local_device) weights = torch.rand( kjt._values.size(0), dtype=torch.float, device=local_device diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py index 538264c04..e8dc5eccb 100644 --- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py +++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py @@ -75,6 +75,11 @@ def _gen_pipelines( default=100, help="Total number of sparse embeddings to be used.", ) +@click.option( + "--ratio_features_weighted", + default=0.4, + help="percentage of features weighted vs unweighted", +) @click.option( "--dim_emb", type=int, @@ -132,6 +137,7 @@ def _gen_pipelines( def main( world_size: int, n_features: int, + ratio_features_weighted: float, dim_emb: int, n_batches: int, batch_size: int, @@ -149,8 +155,9 @@ def main( os.environ["MASTER_ADDR"] = str("localhost") os.environ["MASTER_PORT"] = str(get_free_port()) - num_features = n_features // 2 - num_weighted_features = n_features // 2 + num_weighted_features = int(n_features * ratio_features_weighted) + num_features = n_features - num_weighted_features + tables = [ EmbeddingBagConfig( num_embeddings=(i + 1) * 1000, @@ -257,6 +264,7 @@ def _generate_data( world_size=world_size, num_float_features=num_float_features, pooling_avg=pooling_factor, + input_type=input_type, )[1] for i in range(num_batches) ] diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index 9c39b5384..bf708b1f5 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -306,7 +306,11 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None: # `parameters`. optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01) - data = [i.idlist_features for i in local_model_inputs] + data = [ + i.idlist_features + for i in local_model_inputs + if isinstance(i.idlist_features, KeyedJaggedTensor) + ] dataloader = iter(data) pipeline = TrainPipelinePT2( model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing diff --git a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py index 235495494..1c409fcf2 100644 --- a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py +++ b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py @@ -169,6 +169,7 @@ def generate_kjt( randomize_indices=True, device=device, )[0] + assert isinstance(global_input.idlist_features, KeyedJaggedTensor) return global_input.idlist_features From 0ce7cc62d155e9ec5fbae0120421733a76cd645c Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Fri, 24 Jan 2025 17:16:40 +0000 Subject: [PATCH 06/11] Remove TensorDict from requirements --- install-requirements.txt | 1 - requirements.txt | 1 - 2 files changed, 2 deletions(-) diff --git a/install-requirements.txt b/install-requirements.txt index ed3c6aced..ab2736d78 100644 --- a/install-requirements.txt +++ b/install-requirements.txt @@ -1,5 +1,4 @@ fbgemm-gpu -tensordict torchmetrics==1.0.3 tqdm pyre-extensions diff --git a/requirements.txt b/requirements.txt index 6d63107dd..b60a348f4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,6 @@ numpy pandas pyre-extensions scikit-build -tensordict torchmetrics==1.0.3 torchx tqdm From e6d5560f86c6a5e469d09a57f6b26da441261def Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Fri, 24 Jan 2025 18:47:55 +0000 Subject: [PATCH 07/11] Revert "Revert "Revert "add NJT/TD support in test data generator (#2528)""" This reverts commit 3441ac333f69c7acfb175a0f0d260c335d91831e. --- ...enchmark_split_table_batched_embeddings.py | 9 +- .../distributed/benchmark/benchmark_utils.py | 5 +- .../distributed/test_utils/infer_utils.py | 4 +- torchrec/distributed/test_utils/test_model.py | 123 +++++------------- .../distributed/tests/test_infer_shardings.py | 3 - .../tests/pipeline_benchmarks.py | 12 +- .../tests/test_train_pipelines.py | 6 +- .../keyed_jagged_tensor_benchmark_lib.py | 1 - 8 files changed, 44 insertions(+), 119 deletions(-) diff --git a/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py b/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py index 8af1f9a46..b03e7b417 100644 --- a/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py +++ b/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py @@ -9,8 +9,6 @@ #!/usr/bin/env python3 -from typing import Dict, List - import click import torch @@ -84,10 +82,9 @@ def op_bench( ) def _func_to_benchmark( - kjts: List[Dict[str, KeyedJaggedTensor]], + kjt: KeyedJaggedTensor, model: torch.nn.Module, ) -> torch.Tensor: - kjt = kjts[0]["feature"] return model.forward(kjt.values(), kjt.offsets()) # breakpoint() # import fbvscode; fbvscode.set_trace() @@ -111,8 +108,8 @@ def _func_to_benchmark( result = benchmark_func( name=f"SplitTableBatchedEmbeddingBagsCodegen-{num_embeddings}-{embedding_dim}-{num_tables}-{batch_size}-{bag_size}", - bench_inputs=[{"feature": inputs}], - prof_inputs=[{"feature": inputs}], + bench_inputs=inputs, # pyre-ignore + prof_inputs=inputs, # pyre-ignore num_benchmarks=10, num_profiles=10, profile_dir=".", diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index 22af274d6..1878fdd1f 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -374,14 +374,11 @@ def get_inputs( if train: sparse_features_by_rank = [ - model_input.idlist_features - for model_input in model_input_by_rank - if isinstance(model_input.idlist_features, KeyedJaggedTensor) + model_input.idlist_features for model_input in model_input_by_rank ] inputs_batch.append(sparse_features_by_rank) else: sparse_features = model_input_by_rank[0].idlist_features - assert isinstance(sparse_features, KeyedJaggedTensor) inputs_batch.append([sparse_features]) # Transpose if train, as inputs_by_rank is currently in [B X R] format diff --git a/torchrec/distributed/test_utils/infer_utils.py b/torchrec/distributed/test_utils/infer_utils.py index 478e01bb2..0604f1c29 100644 --- a/torchrec/distributed/test_utils/infer_utils.py +++ b/torchrec/distributed/test_utils/infer_utils.py @@ -264,7 +264,6 @@ def model_input_to_forward_args_kjt( Optional[torch.Tensor], ]: kjt = mi.idlist_features - assert isinstance(kjt, KeyedJaggedTensor) return ( kjt._keys, kjt._values, @@ -292,8 +291,7 @@ def model_input_to_forward_args( ]: idlist_kjt = mi.idlist_features idscore_kjt = mi.idscore_features - assert isinstance(idlist_kjt, KeyedJaggedTensor) - assert isinstance(idscore_kjt, KeyedJaggedTensor) + assert idscore_kjt is not None return ( mi.float_features, idlist_kjt._keys, diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index 010abb459..3442b5dd3 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -14,7 +14,6 @@ import torch import torch.nn as nn -from tensordict import TensorDict from torchrec.distributed.embedding_tower_sharding import ( EmbeddingTowerCollectionSharder, EmbeddingTowerSharder, @@ -47,8 +46,8 @@ @dataclass class ModelInput(Pipelineable): float_features: torch.Tensor - idlist_features: Union[KeyedJaggedTensor, TensorDict] - idscore_features: Optional[Union[KeyedJaggedTensor, TensorDict]] + idlist_features: KeyedJaggedTensor + idscore_features: Optional[KeyedJaggedTensor] label: torch.Tensor @staticmethod @@ -77,13 +76,11 @@ def generate( randomize_indices: bool = True, device: Optional[torch.device] = None, max_feature_lengths: Optional[List[int]] = None, - input_type: str = "kjt", ) -> Tuple["ModelInput", List["ModelInput"]]: """ Returns a global (single-rank training) batch and a list of local (multi-rank training) batches of world_size. """ - batch_size_by_rank = [batch_size] * world_size if variable_batch_size: batch_size_by_rank = [ @@ -202,26 +199,11 @@ def _validate_pooling_factor( ) global_idlist_lengths.append(lengths) global_idlist_indices.append(indices) - - if input_type == "kjt": - global_idlist_input = KeyedJaggedTensor( - keys=idlist_features, - values=torch.cat(global_idlist_indices), - lengths=torch.cat(global_idlist_lengths), - ) - elif input_type == "td": - dict_of_nt = { - k: torch.nested.nested_tensor_from_jagged( - values=values, - lengths=lengths, - ) - for k, values, lengths in zip( - idlist_features, global_idlist_indices, global_idlist_lengths - ) - } - global_idlist_input = TensorDict(source=dict_of_nt) - else: - raise ValueError(f"For IdList features, unknown input type {input_type}") + global_idlist_kjt = KeyedJaggedTensor( + keys=idlist_features, + values=torch.cat(global_idlist_indices), + lengths=torch.cat(global_idlist_lengths), + ) for idx in range(len(idscore_ind_ranges)): ind_range = idscore_ind_ranges[idx] @@ -263,25 +245,16 @@ def _validate_pooling_factor( global_idscore_lengths.append(lengths) global_idscore_indices.append(indices) global_idscore_weights.append(weights) - - if input_type == "kjt": - global_idscore_input = ( - KeyedJaggedTensor( - keys=idscore_features, - values=torch.cat(global_idscore_indices), - lengths=torch.cat(global_idscore_lengths), - weights=torch.cat(global_idscore_weights), - ) - if global_idscore_indices - else None + global_idscore_kjt = ( + KeyedJaggedTensor( + keys=idscore_features, + values=torch.cat(global_idscore_indices), + lengths=torch.cat(global_idscore_lengths), + weights=torch.cat(global_idscore_weights), ) - elif input_type == "td": - assert ( - len(idscore_features) == 0 - ), "TensorDict does not support weighted features" - global_idscore_input = None - else: - raise ValueError(f"For weighted features, unknown input type {input_type}") + if global_idscore_indices + else None + ) if randomize_indices: global_float = torch.rand( @@ -330,48 +303,27 @@ def _validate_pooling_factor( weights[lengths_cumsum[r] : lengths_cumsum[r + 1]] ) - if input_type == "kjt": - local_idlist_input = KeyedJaggedTensor( - keys=idlist_features, - values=torch.cat(local_idlist_indices), - lengths=torch.cat(local_idlist_lengths), - ) - - local_idscore_input = ( - KeyedJaggedTensor( - keys=idscore_features, - values=torch.cat(local_idscore_indices), - lengths=torch.cat(local_idscore_lengths), - weights=torch.cat(local_idscore_weights), - ) - if local_idscore_indices - else None - ) - elif input_type == "td": - dict_of_nt = { - k: torch.nested.nested_tensor_from_jagged( - values=values, - lengths=lengths, - ) - for k, values, lengths in zip( - idlist_features, local_idlist_indices, local_idlist_lengths - ) - } - local_idlist_input = TensorDict(source=dict_of_nt) - assert ( - len(idscore_features) == 0 - ), "TensorDict does not support weighted features" - local_idscore_input = None + local_idlist_kjt = KeyedJaggedTensor( + keys=idlist_features, + values=torch.cat(local_idlist_indices), + lengths=torch.cat(local_idlist_lengths), + ) - else: - raise ValueError( - f"For weighted features, unknown input type {input_type}" + local_idscore_kjt = ( + KeyedJaggedTensor( + keys=idscore_features, + values=torch.cat(local_idscore_indices), + lengths=torch.cat(local_idscore_lengths), + weights=torch.cat(local_idscore_weights), ) + if local_idscore_indices + else None + ) local_input = ModelInput( float_features=global_float[r * batch_size : (r + 1) * batch_size], - idlist_features=local_idlist_input, - idscore_features=local_idscore_input, + idlist_features=local_idlist_kjt, + idscore_features=local_idscore_kjt, label=global_label[r * batch_size : (r + 1) * batch_size], ) local_inputs.append(local_input) @@ -379,8 +331,8 @@ def _validate_pooling_factor( return ( ModelInput( float_features=global_float, - idlist_features=global_idlist_input, - idscore_features=global_idscore_input, + idlist_features=global_idlist_kjt, + idscore_features=global_idscore_kjt, label=global_label, ), local_inputs, @@ -671,9 +623,8 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput": def record_stream(self, stream: torch.Stream) -> None: self.float_features.record_stream(stream) - if isinstance(self.idlist_features, KeyedJaggedTensor): - self.idlist_features.record_stream(stream) - if isinstance(self.idscore_features, KeyedJaggedTensor): + self.idlist_features.record_stream(stream) + if self.idscore_features is not None: self.idscore_features.record_stream(stream) self.label.record_stream(stream) @@ -1880,8 +1831,6 @@ def forward(self, input: ModelInput) -> ModelInput: ) # stride will be same but features will be joined - assert isinstance(modified_input.idlist_features, KeyedJaggedTensor) - assert isinstance(self._extra_input.idlist_features, KeyedJaggedTensor) modified_input.idlist_features = KeyedJaggedTensor.concat( [modified_input.idlist_features, self._extra_input.idlist_features] ) diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py index 83b4649ee..c7c6ef180 100755 --- a/torchrec/distributed/tests/test_infer_shardings.py +++ b/torchrec/distributed/tests/test_infer_shardings.py @@ -1987,7 +1987,6 @@ def test_sharded_quant_fp_ebc_tw( inputs = [] for model_input in model_inputs: kjt = model_input.idlist_features - assert isinstance(kjt, KeyedJaggedTensor) kjt = kjt.to(local_device) weights = torch.rand( kjt._values.size(0), dtype=torch.float, device=local_device @@ -2167,7 +2166,6 @@ def test_sharded_quant_mc_ec_rw( inputs = [] for model_input in model_inputs: kjt = model_input.idlist_features - assert isinstance(kjt, KeyedJaggedTensor) kjt = kjt.to(local_device) weights = None inputs.append( @@ -2303,7 +2301,6 @@ def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None: ) inputs = [] kjt = model_inputs[0].idlist_features - assert isinstance(kjt, KeyedJaggedTensor) kjt = kjt.to(local_device) weights = torch.rand( kjt._values.size(0), dtype=torch.float, device=local_device diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py index e8dc5eccb..538264c04 100644 --- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py +++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py @@ -75,11 +75,6 @@ def _gen_pipelines( default=100, help="Total number of sparse embeddings to be used.", ) -@click.option( - "--ratio_features_weighted", - default=0.4, - help="percentage of features weighted vs unweighted", -) @click.option( "--dim_emb", type=int, @@ -137,7 +132,6 @@ def _gen_pipelines( def main( world_size: int, n_features: int, - ratio_features_weighted: float, dim_emb: int, n_batches: int, batch_size: int, @@ -155,9 +149,8 @@ def main( os.environ["MASTER_ADDR"] = str("localhost") os.environ["MASTER_PORT"] = str(get_free_port()) - num_weighted_features = int(n_features * ratio_features_weighted) - num_features = n_features - num_weighted_features - + num_features = n_features // 2 + num_weighted_features = n_features // 2 tables = [ EmbeddingBagConfig( num_embeddings=(i + 1) * 1000, @@ -264,7 +257,6 @@ def _generate_data( world_size=world_size, num_float_features=num_float_features, pooling_avg=pooling_factor, - input_type=input_type, )[1] for i in range(num_batches) ] diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index bf708b1f5..9c39b5384 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -306,11 +306,7 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None: # `parameters`. optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01) - data = [ - i.idlist_features - for i in local_model_inputs - if isinstance(i.idlist_features, KeyedJaggedTensor) - ] + data = [i.idlist_features for i in local_model_inputs] dataloader = iter(data) pipeline = TrainPipelinePT2( model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing diff --git a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py index 1c409fcf2..235495494 100644 --- a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py +++ b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py @@ -169,7 +169,6 @@ def generate_kjt( randomize_indices=True, device=device, )[0] - assert isinstance(global_input.idlist_features, KeyedJaggedTensor) return global_input.idlist_features From c5bb731aea3d6dc18ea933edd5bd13f828521e73 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Fri, 24 Jan 2025 11:37:33 -0800 Subject: [PATCH 08/11] [fbgemm_gpu] Fix validate_binaries.sh for the CPU variant of FBGEMM_GPU - Fix validate_binaries.sh for the CPU variant of FBGEMM_GPU --- .github/scripts/validate_binaries.sh | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/.github/scripts/validate_binaries.sh b/.github/scripts/validate_binaries.sh index 6750b5b74..f9958a208 100755 --- a/.github/scripts/validate_binaries.sh +++ b/.github/scripts/validate_binaries.sh @@ -58,15 +58,13 @@ echo "CONDA_ENV: ${CONDA_ENV}" # shellcheck disable=SC2155 export CONDA_PREFIX=$(conda run -n "${CONDA_ENV}" printenv CONDA_PREFIX) -find / -name *cuda* - -if [[ $CUDA_VERSION = cu* ]]; then - # Setting LD_LIBRARY_PATH fixes the runtime error with fbgemm_gpu not - # being able to locate libnvrtc.so - echo "[NOVA] Setting LD_LIBRARY_PATH ..." - conda env config vars set -n ${CONDA_ENV} \ - LD_LIBRARY_PATH="/usr/local/lib:/usr/lib64:${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH}" -fi + +# Set LD_LIBRARY_PATH to fix the runtime error with fbgemm_gpu not +# being able to locate libnvrtc.so +# NOTE: The order of the entries in LD_LIBRARY_PATH matters +echo "[NOVA] Setting LD_LIBRARY_PATH ..." +conda env config vars set -n ${CONDA_ENV} \ + LD_LIBRARY_PATH="${CONDA_PREFIX}/lib:/usr/local/lib:/usr/lib64:${LD_LIBRARY_PATH}" # install pytorch From d4b896f9f5e05f8908d0d7c9c811384a5d1d3bdd Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Wed, 22 Jan 2025 16:37:24 -0800 Subject: [PATCH 09/11] Teach is_signature_compatible() to dig into similar annotations (#2693) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2693 D68450007 updated some annotations in pytorch. This function wasn't correctly evaluating `typing.Dict[X, Y]` and `dict[X, Y]` as the equivalent. Reviewed By: izaitsevfb Differential Revision: D68475380 fbshipit-source-id: 3b71ab41f95e6c20986ebe6fbf6f9cbe3b3d58f9 --- torchrec/schema/utils.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/torchrec/schema/utils.py b/torchrec/schema/utils.py index 0f9b897cb..b4f8a6075 100644 --- a/torchrec/schema/utils.py +++ b/torchrec/schema/utils.py @@ -8,6 +8,32 @@ # pyre-strict import inspect +import typing +from typing import Any + + +def _is_annot_compatible(prev: object, curr: object) -> bool: + if prev == curr: + return True + + if not (prev_origin := typing.get_origin(prev)): + return False + if not (curr_origin := typing.get_origin(curr)): + return False + + if prev_origin != curr_origin: + return False + + prev_args = typing.get_args(prev) + curr_args = typing.get_args(curr) + if len(prev_args) != len(curr_args): + return False + + for prev_arg, curr_arg in zip(prev_args, curr_args): + if not _is_annot_compatible(prev_arg, curr_arg): + return False + + return True def is_signature_compatible( @@ -84,6 +110,8 @@ def is_signature_compatible( return False # TODO: Account for Union Types? - if current_signature.return_annotation != previous_signature.return_annotation: + if not _is_annot_compatible( + previous_signature.return_annotation, current_signature.return_annotation + ): return False return True From cac04c80565086a81149b612fc9a68424f65e950 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 30 Jan 2025 17:53:36 +0000 Subject: [PATCH 10/11] Fix setup.py --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 34987b905..9582f8446 100644 --- a/setup.py +++ b/setup.py @@ -106,7 +106,6 @@ def main(argv: List[str]) -> None: zip_safe=False, # PyPI package information. classifiers=[ - "Development Status :: 5 - Stable", "Intended Audience :: Developers", "Intended Audience :: Science/Research", "License :: OSI Approved :: BSD License", From 2c5f6eef88ab689ec5cbd0764594e2b8c0e36163 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 30 Jan 2025 18:23:21 +0000 Subject: [PATCH 11/11] Update version --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index ae7fb2d41..9084fa2f7 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -1.1.0a0 +1.1.0