Skip to content

Commit 1c17a39

Browse files
TroyGarden authored and facebook-github-bot committed
fix validate nightly binaries (#3117)
Summary: # context * original diff D74366343 broke cogwheel test and was reverted * the error stack P1844048578 is shown below: ``` File "/dev/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/dev/torch/nn/modules/module.py", line 1784, in _call_impl return forward_call(*args, **kwargs) File "/dev/torchrec/distributed/train_pipeline/runtime_forwards.py", line 84, in __call__ data = request.wait() File "/dev/torchrec/distributed/types.py", line 334, in wait ret: W = self._wait_impl() File "/dev/torchrec/distributed/embedding_sharding.py", line 655, in _wait_impl kjts.append(w.wait()) File "/dev/torchrec/distributed/types.py", line 334, in wait ret: W = self._wait_impl() File "/dev/torchrec/distributed/dist_data.py", line 426, in _wait_impl return type(self._input).dist_init( File "/dev/torchrec/sparse/jagged_tensor.py", line 2993, in dist_init return kjt.sync() File "/dev/torchrec/sparse/jagged_tensor.py", line 2067, in sync self.length_per_key() File "/dev/torchrec/sparse/jagged_tensor.py", line 2281, in length_per_key _length_per_key = _maybe_compute_length_per_key( File "/dev/torchrec/sparse/jagged_tensor.py", line 1192, in _maybe_compute_length_per_key _length_per_key_from_stride_per_key(lengths, stride_per_key) File "/dev/torchrec/sparse/jagged_tensor.py", line 1144, in _length_per_key_from_stride_per_key if _use_segment_sum_csr(stride_per_key): File "/dev/torchrec/sparse/jagged_tensor.py", line 1131, in _use_segment_sum_csr elements_per_segment = sum(stride_per_key) / len(stride_per_key) ZeroDivisionError: division by zero ``` * the complaint is `stride_per_key` is an empty list, which comes from the following function call: ``` stride_per_key = _maybe_compute_stride_per_key( self._stride_per_key, self._stride_per_key_per_rank, self.stride(), self._keys, ) ``` * the only place this `stride_per_key` could be empty is when the `stride_per_key_per_rank.dim() != 2` ``` def _maybe_compute_stride_per_key( 
stride_per_key: Optional[List[int]], stride_per_key_per_rank: Optional[torch.IntTensor], stride: Optional[int], keys: List[str], ) -> Optional[List[int]]: if stride_per_key is not None: return stride_per_key elif stride_per_key_per_rank is not None: if stride_per_key_per_rank.dim() != 2: # after permute the kjt could be empty return [] rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist() if not torch.jit.is_scripting() and is_torchdynamo_compiling(): pt2_checks_all_is_size(rt) return rt elif stride is not None: return [stride] * len(keys) else: return None ``` # the main change from D74366343 is that the `stride_per_key_per_rank` in `dist_init`: * baseline ``` if stagger > 1: stride_per_key_per_rank_stagger: List[List[int]] = [] local_world_size = num_workers // stagger for i in range(len(keys)): stride_per_rank_stagger: List[int] = [] for j in range(local_world_size): stride_per_rank_stagger.extend( stride_per_key_per_rank[i][j::local_world_size] ) stride_per_key_per_rank_stagger.append(stride_per_rank_stagger) stride_per_key_per_rank = stride_per_key_per_rank_stagger ``` * D76875546 (correct, this diff) ``` if stagger > 1: indices = torch.arange(num_workers).view(stagger, -1).T.reshape(-1) stride_per_key_per_rank = stride_per_key_per_rank[:, indices] ``` * D74366343 (incorrect, reverted) ``` if stagger > 1: local_world_size = num_workers // stagger indices = [ list(range(i, num_workers, local_world_size)) for i in range(local_world_size) ] stride_per_key_per_rank = stride_per_key_per_rank[:, indices] ``` Differential Revision: D76875546
1 parent 57abf4e commit 1c17a39

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

.github/scripts/validate_binaries.sh

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,22 +73,20 @@ conda env config vars set -n ${CONDA_ENV} \
7373
# export PYTORCH_CUDA_PKG="pytorch-cuda=${MATRIX_GPU_ARCH_VERSION}"
7474
# fi
7575

76-
conda run -n "${CONDA_ENV}" pip install importlib-metadata
77-
7876
conda run -n "${CONDA_ENV}" pip install torch --index-url "$PYTORCH_URL"
7977

8078
# install fbgemm
8179
conda run -n "${CONDA_ENV}" pip install fbgemm-gpu --index-url "$PYTORCH_URL"
8280

83-
# install requirements from pypi
84-
conda run -n "${CONDA_ENV}" pip install torchmetrics==1.0.3
85-
8681
# install tensordict from pypi
8782
conda run -n "${CONDA_ENV}" pip install tensordict==0.8.1
8883

8984
# install torchrec
9085
conda run -n "${CONDA_ENV}" pip install torchrec --index-url "$PYTORCH_URL"
9186

87+
# install other requirements
88+
conda run -n "${CONDA_ENV}" pip install -r requirements.txt
89+
9290
# Run small import test
9391
conda run -n "${CONDA_ENV}" python -c "import torch; import fbgemm_gpu; import torchrec"
9492

.github/workflows/validate-nightly-binaries.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@ on:
1111
branches:
1212
- main
1313
paths:
14-
- .github/workflows/validate-nightly-binaries.yml
15-
- .github/workflows/validate-binaries.yml
16-
- .github/scripts/validate-binaries.sh
14+
- ./.github/workflows/validate-nightly-binaries.yml
15+
- ./.github/workflows/validate-binaries.yml
16+
- ./.github/scripts/validate-binaries.sh
1717
pull_request:
1818
paths:
19-
- .github/workflows/validate-nightly-binaries.yml
20-
- .github/workflows/validate-binaries.yml
21-
- .github/scripts/validate-binaries.sh
19+
- ./.github/workflows/validate-nightly-binaries.yml
20+
- ./.github/workflows/validate-binaries.yml
21+
- ./.github/scripts/validate-binaries.sh
2222
jobs:
2323
nightly:
2424
uses: ./.github/workflows/validate-binaries.yml

0 commit comments

Comments (0)