Summary:
# context
* the original diff D74366343 broke a cogwheel test and was reverted
* the error stack trace (P1844048578) is shown below:
```
File "/dev/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/dev/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/dev/torchrec/distributed/train_pipeline/runtime_forwards.py", line 84, in __call__
data = request.wait()
File "/dev/torchrec/distributed/types.py", line 334, in wait
ret: W = self._wait_impl()
File "/dev/torchrec/distributed/embedding_sharding.py", line 655, in _wait_impl
kjts.append(w.wait())
File "/dev/torchrec/distributed/types.py", line 334, in wait
ret: W = self._wait_impl()
File "/dev/torchrec/distributed/dist_data.py", line 426, in _wait_impl
return type(self._input).dist_init(
File "/dev/torchrec/sparse/jagged_tensor.py", line 2993, in dist_init
return kjt.sync()
File "/dev/torchrec/sparse/jagged_tensor.py", line 2067, in sync
self.length_per_key()
File "/dev/torchrec/sparse/jagged_tensor.py", line 2281, in length_per_key
_length_per_key = _maybe_compute_length_per_key(
File "/dev/torchrec/sparse/jagged_tensor.py", line 1192, in _maybe_compute_length_per_key
_length_per_key_from_stride_per_key(lengths, stride_per_key)
File "/dev/torchrec/sparse/jagged_tensor.py", line 1144, in _length_per_key_from_stride_per_key
if _use_segment_sum_csr(stride_per_key):
File "/dev/torchrec/sparse/jagged_tensor.py", line 1131, in _use_segment_sum_csr
elements_per_segment = sum(stride_per_key) / len(stride_per_key)
ZeroDivisionError: division by zero
```
* the complaint is that `stride_per_key` is an empty list, which comes from the following function call (a minimal repro of the resulting crash follows the snippet):
```
stride_per_key = _maybe_compute_stride_per_key(
self._stride_per_key,
self._stride_per_key_per_rank,
self.stride(),
self._keys,
)
```
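For reference, the crash site from the trace is easy to reproduce in isolation. This is a minimal standalone sketch, not the actual torchrec call path; the values are assumed:
```
# Minimal sketch of the crash: _use_segment_sum_csr (jagged_tensor.py:1131 in the
# trace above) averages stride_per_key, so an empty list divides by zero.
stride_per_key = []  # what the call above returns in the broken case
elements_per_segment = sum(stride_per_key) / len(stride_per_key)  # ZeroDivisionError
```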
* the only place `stride_per_key` can end up empty is when `stride_per_key_per_rank.dim() != 2` (illustrated after the function below):
```
def _maybe_compute_stride_per_key(
stride_per_key: Optional[List[int]],
stride_per_key_per_rank: Optional[torch.IntTensor],
stride: Optional[int],
keys: List[str],
) -> Optional[List[int]]:
if stride_per_key is not None:
return stride_per_key
elif stride_per_key_per_rank is not None:
if stride_per_key_per_rank.dim() != 2:
# after permute the kjt could be empty
return []
rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist()
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
pt2_checks_all_is_size(rt)
return rt
elif stride is not None:
return [stride] * len(keys)
else:
return None
```
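To illustrate the two branches with made-up shapes (2 keys x 4 ranks, assumed values rather than values from the failing test): a 2-D `stride_per_key_per_rank` is summed across ranks, while anything that is not 2-D falls into the early return and leaves `stride_per_key` empty:
```
import torch

# Illustrative only: 2 keys x 4 ranks (assumed sizes).
stride_per_key_per_rank = torch.tensor([[2, 3, 1, 2],
                                        [4, 0, 2, 1]])
print(stride_per_key_per_rank.sum(dim=1).tolist())  # [8, 7] -> one stride per key

# A tensor that is not 2-D (e.g. the result of bad advanced indexing) hits the
# `dim() != 2` branch and stride_per_key becomes [].
print(stride_per_key_per_rank.unsqueeze(-1).dim())  # 3
```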
# the main change from D74366343 is how `stride_per_key_per_rank` is staggered in `dist_init`:
* baseline
```
if stagger > 1:
stride_per_key_per_rank_stagger: List[List[int]] = []
local_world_size = num_workers // stagger
for i in range(len(keys)):
stride_per_rank_stagger: List[int] = []
for j in range(local_world_size):
stride_per_rank_stagger.extend(
stride_per_key_per_rank[i][j::local_world_size]
)
stride_per_key_per_rank_stagger.append(stride_per_rank_stagger)
stride_per_key_per_rank = stride_per_key_per_rank_stagger
```
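As a worked example of what the baseline loop produces (with assumed values `num_workers = 8`, `stagger = 2`, so `local_world_size = 4`), each key's per-rank strides are reordered from rank order `[0..7]` into the staggered order `[0, 4, 1, 5, 2, 6, 3, 7]`:
```
# Assumed example values, not taken from the failing test.
num_workers, stagger = 8, 2
local_world_size = num_workers // stagger  # 4
ranks = list(range(num_workers))           # stand-in for one key's per-rank strides

stride_per_rank_stagger = []
for j in range(local_world_size):
    stride_per_rank_stagger.extend(ranks[j::local_world_size])
print(stride_per_rank_stagger)  # [0, 4, 1, 5, 2, 6, 3, 7]
```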
* D76875546 (correct, this diff)
```
if stagger > 1:
indices = torch.arange(num_workers).view(stagger, -1).T.reshape(-1)
stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
```
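With the same assumed values, the new index tensor reproduces exactly the baseline order, and because `indices` is 1-D the indexed result stays 2-D, so `stride_per_key` remains non-empty (sketch only):
```
import torch

num_workers, stagger = 8, 2  # same assumed example values as above
indices = torch.arange(num_workers).view(stagger, -1).T.reshape(-1)
print(indices.tolist())  # [0, 4, 1, 5, 2, 6, 3, 7] -- matches the baseline loop

stride_per_key_per_rank = torch.ones(2, num_workers, dtype=torch.int64)  # 2 assumed keys
print(stride_per_key_per_rank[:, indices].dim())  # 2 -- the dim() == 2 path is kept
```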
* D74366343 (incorrect, reverted)
```
if stagger > 1:
local_world_size = num_workers // stagger
indices = [
list(range(i, num_workers, local_world_size))
for i in range(local_world_size)
]
stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
```
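The reverted version builds `indices` as a nested Python list, which acts as a 2-D advanced index: `stride_per_key_per_rank[:, indices]` becomes 3-D, `_maybe_compute_stride_per_key` returns `[]`, and the division by zero above follows (sketch with the same assumed sizes):
```
import torch

num_workers, stagger = 8, 2  # assumed example values
local_world_size = num_workers // stagger
indices = [
    list(range(i, num_workers, local_world_size))
    for i in range(local_world_size)
]
print(indices)  # [[0, 4], [1, 5], [2, 6], [3, 7]] -- a 2-D index

stride_per_key_per_rank = torch.ones(2, num_workers, dtype=torch.int64)  # 2 assumed keys
print(stride_per_key_per_rank[:, indices].shape)  # torch.Size([2, 4, 2]) -- dim() != 2
```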
Differential Revision: D76875546