Skip to content

Commit 7c87718

Browse files
generatedunixname89002005232357 and facebook-github-bot
authored and committed
Revert D74366343
Summary: This diff reverts D74366343 (The context such as a Sandcastle job, Task, SEV, etc. was not provided.)
Depends on D74366343
Reviewed By: jd7-tr
Differential Revision: D76870690
fbshipit-source-id: 8c08ef3272cdcf5345c3d98c0bccf2e02d5846a6
1 parent cc92389 commit 7c87718

File tree

3 files changed

+47
-86
lines changed

3 files changed

+47
-86
lines changed

torchrec/pt2/utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ def kjt_for_pt2_tracing(
5454
values=values,
5555
lengths=lengths,
5656
weights=kjt.weights_or_none(),
57-
stride_per_key_per_rank=torch.IntTensor([[stride]] * n, device="cpu"),
57+
stride_per_key_per_rank=[[stride]] * n,
5858
inverse_indices=(kjt.keys(), inverse_indices_tensor),
5959
)
6060

@@ -85,7 +85,7 @@ def kjt_for_pt2_tracing(
8585
lengths=lengths,
8686
weights=weights,
8787
stride=stride if not is_vb else None,
88-
stride_per_key_per_rank=kjt._stride_per_key_per_rank if is_vb else None,
88+
stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
8989
inverse_indices=inverse_indices,
9090
)
9191

torchrec/schema/api_tests/test_jagged_tensor_schema.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import inspect
1111
import unittest
12-
from typing import Dict, List, Optional, Tuple, Union
12+
from typing import Dict, List, Optional, Tuple
1313

1414
import torch
1515
from torchrec.schema.utils import is_signature_compatible
@@ -112,9 +112,7 @@ def __init__(
112112
lengths: Optional[torch.Tensor] = None,
113113
offsets: Optional[torch.Tensor] = None,
114114
stride: Optional[int] = None,
115-
stride_per_key_per_rank: Optional[
116-
Union[List[List[int]], torch.IntTensor]
117-
] = None,
115+
stride_per_key_per_rank: Optional[List[List[int]]] = None,
118116
# Below exposed to ensure torch.script-able
119117
stride_per_key: Optional[List[int]] = None,
120118
length_per_key: Optional[List[int]] = None,

torchrec/sparse/jagged_tensor.py

Lines changed: 43 additions & 80 deletions
Original file line number | Diff line number | Diff line change
@@ -1098,15 +1098,13 @@ def _maybe_compute_stride_kjt(
10981098
stride: Optional[int],
10991099
lengths: Optional[torch.Tensor],
11001100
offsets: Optional[torch.Tensor],
1101-
stride_per_key_per_rank: Optional[torch.IntTensor],
1101+
stride_per_key_per_rank: Optional[List[List[int]]],
11021102
) -> int:
11031103
if stride is None:
11041104
if len(keys) == 0:
11051105
stride = 0
1106-
elif (
1107-
stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
1108-
):
1109-
stride = int(stride_per_key_per_rank.sum(dim=1).max().item())
1106+
elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
1107+
stride = max([sum(s) for s in stride_per_key_per_rank])
11101108
elif offsets is not None and offsets.numel() > 0:
11111109
stride = (offsets.numel() - 1) // len(keys)
11121110
elif lengths is not None:
@@ -1485,8 +1483,8 @@ def _strides_from_kjt(
14851483
def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
14861484
# empty like function fx wrapped, also avoids device hardcoding
14871485
stride, stride_per_key_per_rank = (
1488-
(None, kjt._stride_per_key_per_rank)
1489-
if kjt._stride_per_key_per_rank is not None and kjt.variable_stride_per_key()
1486+
(None, kjt.stride_per_key_per_rank())
1487+
if kjt.variable_stride_per_key()
14901488
else (kjt.stride(), None)
14911489
)
14921490

@@ -1672,20 +1670,14 @@ def _maybe_compute_lengths_offset_per_key(
16721670

16731671
def _maybe_compute_stride_per_key(
16741672
stride_per_key: Optional[List[int]],
1675-
stride_per_key_per_rank: Optional[torch.IntTensor],
1673+
stride_per_key_per_rank: Optional[List[List[int]]],
16761674
stride: Optional[int],
16771675
keys: List[str],
16781676
) -> Optional[List[int]]:
16791677
if stride_per_key is not None:
16801678
return stride_per_key
16811679
elif stride_per_key_per_rank is not None:
1682-
if stride_per_key_per_rank.dim() != 2:
1683-
# after permute the kjt could be empty
1684-
return []
1685-
rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist()
1686-
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
1687-
pt2_checks_all_is_size(rt)
1688-
return rt
1680+
return [sum(s) for s in stride_per_key_per_rank]
16891681
elif stride is not None:
16901682
return [stride] * len(keys)
16911683
else:
@@ -1776,9 +1768,7 @@ def __init__(
17761768
lengths: Optional[torch.Tensor] = None,
17771769
offsets: Optional[torch.Tensor] = None,
17781770
stride: Optional[int] = None,
1779-
stride_per_key_per_rank: Optional[
1780-
Union[torch.IntTensor, List[List[int]]]
1781-
] = None,
1771+
stride_per_key_per_rank: Optional[List[List[int]]] = None,
17821772
# Below exposed to ensure torch.script-able
17831773
stride_per_key: Optional[List[int]] = None,
17841774
length_per_key: Optional[List[int]] = None,
@@ -1800,14 +1790,8 @@ def __init__(
18001790
self._lengths: Optional[torch.Tensor] = lengths
18011791
self._offsets: Optional[torch.Tensor] = offsets
18021792
self._stride: Optional[int] = stride
1803-
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
1804-
# in pt2.compile the stride_per_key_per_rank has to be torch.Tensor or None
1805-
# does not take List[List[int]]
1806-
assert not isinstance(stride_per_key_per_rank, list)
1807-
self._stride_per_key_per_rank: Optional[torch.IntTensor] = (
1808-
torch.IntTensor(stride_per_key_per_rank, device="cpu")
1809-
if isinstance(stride_per_key_per_rank, list)
1810-
else stride_per_key_per_rank
1793+
self._stride_per_key_per_rank: Optional[List[List[int]]] = (
1794+
stride_per_key_per_rank
18111795
)
18121796
self._stride_per_key: Optional[List[int]] = stride_per_key
18131797
self._length_per_key: Optional[List[int]] = length_per_key
@@ -1818,8 +1802,6 @@ def __init__(
18181802
self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = (
18191803
inverse_indices
18201804
)
1821-
# this is only needed for torch.compile case
1822-
self._pt2_stride_per_key_per_rank: Optional[List[List[int]]] = None
18231805

18241806
# legacy attribute, for backward compatibility
18251807
self._variable_stride_per_key: Optional[bool] = None
@@ -1835,6 +1817,10 @@ def _init_pt2_checks(self) -> None:
18351817
return
18361818
if self._stride_per_key is not None:
18371819
pt2_checks_all_is_size(self._stride_per_key)
1820+
if self._stride_per_key_per_rank is not None:
1821+
# pyre-ignore [16]
1822+
for s in self._stride_per_key_per_rank:
1823+
pt2_checks_all_is_size(s)
18381824

18391825
@staticmethod
18401826
def from_offsets_sync(
@@ -2044,7 +2030,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
20442030
kjt_stride, kjt_stride_per_key_per_rank = (
20452031
(stride_per_key[0], None)
20462032
if all(s == stride_per_key[0] for s in stride_per_key)
2047-
else (None, torch.IntTensor(stride_per_key, device="cpu").reshape(-1, 1))
2033+
else (None, [[stride] for stride in stride_per_key])
20482034
)
20492035
kjt = KeyedJaggedTensor(
20502036
keys=kjt_keys,
@@ -2209,32 +2195,12 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
22092195
Returns:
22102196
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
22112197
"""
2212-
# making a local reference to the class variable to make jit.script behave
2213-
_stride_per_key_per_rank = self._stride_per_key_per_rank
2214-
if (
2215-
not torch.jit.is_scripting()
2216-
and is_torchdynamo_compiling()
2217-
and _stride_per_key_per_rank is not None
2218-
):
2219-
if self._pt2_stride_per_key_per_rank is not None:
2220-
return self._pt2_stride_per_key_per_rank
2221-
stride_per_key_per_rank = _stride_per_key_per_rank.tolist()
2222-
for stride_per_rank in stride_per_key_per_rank:
2223-
pt2_checks_all_is_size(stride_per_rank)
2224-
self._pt2_stride_per_key_per_rank = stride_per_key_per_rank
2225-
return stride_per_key_per_rank
2226-
return (
2227-
[]
2228-
if _stride_per_key_per_rank is None
2229-
else _stride_per_key_per_rank.tolist()
2230-
)
2198+
stride_per_key_per_rank = self._stride_per_key_per_rank
2199+
return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
22312200

22322201
def variable_stride_per_key(self) -> bool:
22332202
"""
22342203
Returns whether the KeyedJaggedTensor has variable stride per key.
2235-
NOTE: `self._variable_stride_per_key` could be `False` when `self._stride_per_key_per_rank`
2236-
is not `None`. It might be assigned to False externally/intentionally, usually the
2237-
`self._stride_per_key_per_rank` is trivial.
22382204
22392205
Returns:
22402206
bool: whether the KeyedJaggedTensor has variable stride per key.
@@ -2379,16 +2345,13 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
23792345
start_offset = 0
23802346
_length_per_key = self.length_per_key()
23812347
_offset_per_key = self.offset_per_key()
2382-
# use local copy/ref for self._stride_per_key_per_rank to satisfy jit.script
2383-
_stride_per_key_per_rank = self._stride_per_key_per_rank
23842348
for segment in segments:
23852349
end = start + segment
23862350
end_offset = _offset_per_key[end]
23872351
keys: List[str] = self._keys[start:end]
23882352
stride_per_key_per_rank = (
2389-
_stride_per_key_per_rank[start:end, :]
2353+
self.stride_per_key_per_rank()[start:end]
23902354
if self.variable_stride_per_key()
2391-
and _stride_per_key_per_rank is not None
23922355
else None
23932356
)
23942357
if segment == len(self._keys):
@@ -2536,24 +2499,17 @@ def permute(
25362499

25372500
length_per_key = self.length_per_key()
25382501
permuted_keys: List[str] = []
2502+
permuted_stride_per_key_per_rank: List[List[int]] = []
25392503
permuted_length_per_key: List[int] = []
25402504
permuted_length_per_key_sum = 0
25412505
for index in indices:
25422506
key = self.keys()[index]
25432507
permuted_keys.append(key)
25442508
permuted_length_per_key.append(length_per_key[index])
2545-
2546-
stride_per_key = self._stride_per_key
2547-
permuted_stride_per_key = (
2548-
[stride_per_key[i] for i in indices] if stride_per_key is not None else None
2549-
)
2550-
2551-
_stride_per_key_per_rank = self._stride_per_key_per_rank
2552-
permuted_stride_per_key_per_rank = (
2553-
_stride_per_key_per_rank[indices, :]
2554-
if self.variable_stride_per_key() and _stride_per_key_per_rank is not None
2555-
else None
2556-
)
2509+
if self.variable_stride_per_key():
2510+
permuted_stride_per_key_per_rank.append(
2511+
self.stride_per_key_per_rank()[index]
2512+
)
25572513

25582514
permuted_length_per_key_sum = sum(permuted_length_per_key)
25592515
if not torch.jit.is_scripting() and is_non_strict_exporting():
@@ -2605,16 +2561,18 @@ def permute(
26052561
self.weights_or_none(),
26062562
permuted_length_per_key_sum,
26072563
)
2608-
2564+
stride_per_key_per_rank = (
2565+
permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None
2566+
)
26092567
kjt = KeyedJaggedTensor(
26102568
keys=permuted_keys,
26112569
values=permuted_values,
26122570
weights=permuted_weights,
26132571
lengths=permuted_lengths.view(-1),
26142572
offsets=None,
26152573
stride=self._stride,
2616-
stride_per_key_per_rank=permuted_stride_per_key_per_rank,
2617-
stride_per_key=permuted_stride_per_key,
2574+
stride_per_key_per_rank=stride_per_key_per_rank,
2575+
stride_per_key=None,
26182576
length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
26192577
lengths_offset_per_key=None,
26202578
offset_per_key=None,
@@ -2933,7 +2891,7 @@ def dist_init(
29332891

29342892
if variable_stride_per_key:
29352893
assert stride_per_rank_per_key is not None
2936-
stride_per_key_per_rank: torch.Tensor = stride_per_rank_per_key.view(
2894+
stride_per_key_per_rank_tensor: torch.Tensor = stride_per_rank_per_key.view(
29372895
num_workers, len(keys)
29382896
).T.cpu()
29392897

@@ -2970,18 +2928,23 @@ def dist_init(
29702928
weights,
29712929
)
29722930

2973-
if stride_per_key_per_rank.numel() == 0:
2974-
stride_per_key_per_rank = torch.zeros(
2975-
(len(keys), 1), device="cpu", dtype=torch.int64
2976-
)
2931+
stride_per_key_per_rank = torch.jit.annotate(
2932+
List[List[int]], stride_per_key_per_rank_tensor.tolist()
2933+
)
29772934

2935+
if not stride_per_key_per_rank:
2936+
stride_per_key_per_rank = [[0]] * len(keys)
29782937
if stagger > 1:
2938+
stride_per_key_per_rank_stagger: List[List[int]] = []
29792939
local_world_size = num_workers // stagger
2980-
indices = [
2981-
list(range(i, num_workers, local_world_size))
2982-
for i in range(local_world_size)
2983-
]
2984-
stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
2940+
for i in range(len(keys)):
2941+
stride_per_rank_stagger: List[int] = []
2942+
for j in range(local_world_size):
2943+
stride_per_rank_stagger.extend(
2944+
stride_per_key_per_rank[i][j::local_world_size]
2945+
)
2946+
stride_per_key_per_rank_stagger.append(stride_per_rank_stagger)
2947+
stride_per_key_per_rank = stride_per_key_per_rank_stagger
29852948

29862949
kjt = KeyedJaggedTensor(
29872950
keys=keys,

0 commit comments

Comments
 (0)