@@ -1098,15 +1098,13 @@ def _maybe_compute_stride_kjt(
     stride: Optional[int],
     lengths: Optional[torch.Tensor],
     offsets: Optional[torch.Tensor],
-    stride_per_key_per_rank: Optional[torch.IntTensor],
+    stride_per_key_per_rank: Optional[List[List[int]]],
 ) -> int:
     if stride is None:
         if len(keys) == 0:
             stride = 0
-        elif (
-            stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
-        ):
-            stride = int(stride_per_key_per_rank.sum(dim=1).max().item())
+        elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
+            stride = max([sum(s) for s in stride_per_key_per_rank])
         elif offsets is not None and offsets.numel() > 0:
             stride = (offsets.numel() - 1) // len(keys)
         elif lengths is not None:
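
With stride_per_key_per_rank as a plain List[List[int]] (one row per key, one
entry per rank), the stride fallback above becomes ordinary list arithmetic:
sum each key's row to get that key's total batch size, then take the max
across keys. A minimal standalone sketch of that computation (hypothetical
helper, not the torchrec function itself):

    from typing import List

    def stride_from_rows(stride_per_key_per_rank: List[List[int]]) -> int:
        # Each inner list holds one key's per-rank batch sizes; the KJT
        # stride is the largest per-key total.
        return max(sum(per_rank) for per_rank in stride_per_key_per_rank)

    # Two keys across three ranks: totals are 6 and 5, so the stride is 6.
    assert stride_from_rows([[1, 2, 3], [4, 0, 1]]) == 6
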
@@ -1485,8 +1483,8 @@ def _strides_from_kjt(
 def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
     # empty like function fx wrapped, also avoids device hardcoding
     stride, stride_per_key_per_rank = (
-        (None, kjt._stride_per_key_per_rank)
-        if kjt._stride_per_key_per_rank is not None and kjt.variable_stride_per_key()
+        (None, kjt.stride_per_key_per_rank())
+        if kjt.variable_stride_per_key()
         else (kjt.stride(), None)
     )

@@ -1672,20 +1670,14 @@ def _maybe_compute_lengths_offset_per_key(

 def _maybe_compute_stride_per_key(
     stride_per_key: Optional[List[int]],
-    stride_per_key_per_rank: Optional[torch.IntTensor],
+    stride_per_key_per_rank: Optional[List[List[int]]],
     stride: Optional[int],
     keys: List[str],
 ) -> Optional[List[int]]:
     if stride_per_key is not None:
         return stride_per_key
     elif stride_per_key_per_rank is not None:
-        if stride_per_key_per_rank.dim() != 2:
-            # after permute the kjt could be empty
-            return []
-        rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist()
-        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
-            pt2_checks_all_is_size(rt)
-        return rt
+        return [sum(s) for s in stride_per_key_per_rank]
    elif stride is not None:
         return [stride] * len(keys)
     else:
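
The list form also subsumes the removed empty-tensor guard: a comprehension
over an empty List[List[int]] naturally yields [], which is exactly what the
old dim() != 2 branch returned for a KJT emptied by permute. A small sketch of
the equivalence (hypothetical helper name):

    from typing import List

    def stride_per_key_from_rows(rows: List[List[int]]) -> List[int]:
        # One total per key; an empty rows list gives an empty result.
        return [sum(per_rank) for per_rank in rows]

    assert stride_per_key_from_rows([[1, 2], [3, 4]]) == [3, 7]
    assert stride_per_key_from_rows([]) == []  # e.g. empty after permute
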
@@ -1776,9 +1768,7 @@ def __init__(
         lengths: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
         stride: Optional[int] = None,
-        stride_per_key_per_rank: Optional[
-            Union[torch.IntTensor, List[List[int]]]
-        ] = None,
+        stride_per_key_per_rank: Optional[List[List[int]]] = None,
         # Below exposed to ensure torch.script-able
         stride_per_key: Optional[List[int]] = None,
         length_per_key: Optional[List[int]] = None,
@@ -1800,14 +1790,8 @@ def __init__(
         self._lengths: Optional[torch.Tensor] = lengths
         self._offsets: Optional[torch.Tensor] = offsets
         self._stride: Optional[int] = stride
-        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
-            # in pt2.compile the stride_per_key_per_rank has to be torch.Tensor or None
-            # does not take List[List[int]]
-            assert not isinstance(stride_per_key_per_rank, list)
-        self._stride_per_key_per_rank: Optional[torch.IntTensor] = (
-            torch.IntTensor(stride_per_key_per_rank, device="cpu")
-            if isinstance(stride_per_key_per_rank, list)
-            else stride_per_key_per_rank
+        self._stride_per_key_per_rank: Optional[List[List[int]]] = (
+            stride_per_key_per_rank
         )
         self._stride_per_key: Optional[List[int]] = stride_per_key
         self._length_per_key: Optional[List[int]] = length_per_key
@@ -1818,8 +1802,6 @@ def __init__(
         self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = (
             inverse_indices
         )
-        # this is only needed for torch.compile case
-        self._pt2_stride_per_key_per_rank: Optional[List[List[int]]] = None

         # legacy attribute, for backward compatabilibity
         self._variable_stride_per_key: Optional[bool] = None
@@ -1835,6 +1817,10 @@ def _init_pt2_checks(self) -> None:
             return
         if self._stride_per_key is not None:
             pt2_checks_all_is_size(self._stride_per_key)
+        if self._stride_per_key_per_rank is not None:
+            # pyre-ignore [16]
+            for s in self._stride_per_key_per_rank:
+                pt2_checks_all_is_size(s)

     @staticmethod
     def from_offsets_sync(
@@ -2044,7 +2030,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
         kjt_stride, kjt_stride_per_key_per_rank = (
             (stride_per_key[0], None)
             if all(s == stride_per_key[0] for s in stride_per_key)
-            else (None, torch.IntTensor(stride_per_key, device="cpu").reshape(-1, 1))
+            else (None, [[stride] for stride in stride_per_key])
         )
         kjt = KeyedJaggedTensor(
             keys=kjt_keys,
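
When the per-key strides disagree, from_jt_dict treats each key as
variable-stride with a single rank, so the nested list is just one singleton
row per key. A quick illustration of the resulting shape (plain Python,
assumed values):

    stride_per_key = [2, 3, 2]
    # One row per key with a single "rank" entry, replacing the old
    # torch.IntTensor(...).reshape(-1, 1) construction.
    stride_per_key_per_rank = [[stride] for stride in stride_per_key]
    assert stride_per_key_per_rank == [[2], [3], [2]]
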
@@ -2209,32 +2195,12 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
         Returns:
             List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
         """
-        # making a local reference to the class variable to make jit.script behave
-        _stride_per_key_per_rank = self._stride_per_key_per_rank
-        if (
-            not torch.jit.is_scripting()
-            and is_torchdynamo_compiling()
-            and _stride_per_key_per_rank is not None
-        ):
-            if self._pt2_stride_per_key_per_rank is not None:
-                return self._pt2_stride_per_key_per_rank
-            stride_per_key_per_rank = _stride_per_key_per_rank.tolist()
-            for stride_per_rank in stride_per_key_per_rank:
-                pt2_checks_all_is_size(stride_per_rank)
-            self._pt2_stride_per_key_per_rank = stride_per_key_per_rank
-            return stride_per_key_per_rank
-        return (
-            []
-            if _stride_per_key_per_rank is None
-            else _stride_per_key_per_rank.tolist()
-        )
+        stride_per_key_per_rank = self._stride_per_key_per_rank
+        return stride_per_key_per_rank if stride_per_key_per_rank is not None else []

     def variable_stride_per_key(self) -> bool:
         """
         Returns whether the KeyedJaggedTensor has variable stride per key.
-        NOTE: `self._variable_stride_per_key` could be `False` when `self._stride_per_key_per_rank`
-        is not `None`. It might be assigned to False externally/intentionally, usually the
-        `self._stride_per_key_per_rank` is trivial.

         Returns:
             bool: whether the KeyedJaggedTensor has variable stride per key.
@@ -2379,16 +2345,13 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
         start_offset = 0
         _length_per_key = self.length_per_key()
         _offset_per_key = self.offset_per_key()
-        # use local copy/ref for self._stride_per_key_per_rank to satisfy jit.script
-        _stride_per_key_per_rank = self._stride_per_key_per_rank
         for segment in segments:
             end = start + segment
             end_offset = _offset_per_key[end]
             keys: List[str] = self._keys[start:end]
             stride_per_key_per_rank = (
-                _stride_per_key_per_rank[start:end, :]
+                self.stride_per_key_per_rank()[start:end]
                 if self.variable_stride_per_key()
-                and _stride_per_key_per_rank is not None
                 else None
             )
             if segment == len(self._keys):
@@ -2536,24 +2499,17 @@ def permute(

         length_per_key = self.length_per_key()
         permuted_keys: List[str] = []
+        permuted_stride_per_key_per_rank: List[List[int]] = []
         permuted_length_per_key: List[int] = []
         permuted_length_per_key_sum = 0
         for index in indices:
             key = self.keys()[index]
             permuted_keys.append(key)
             permuted_length_per_key.append(length_per_key[index])
-
-        stride_per_key = self._stride_per_key
-        permuted_stride_per_key = (
-            [stride_per_key[i] for i in indices] if stride_per_key is not None else None
-        )
-
-        _stride_per_key_per_rank = self._stride_per_key_per_rank
-        permuted_stride_per_key_per_rank = (
-            _stride_per_key_per_rank[indices, :]
-            if self.variable_stride_per_key() and _stride_per_key_per_rank is not None
-            else None
-        )
+            if self.variable_stride_per_key():
+                permuted_stride_per_key_per_rank.append(
+                    self.stride_per_key_per_rank()[index]
+                )

         permuted_length_per_key_sum = sum(permuted_length_per_key)
         if not torch.jit.is_scripting() and is_non_strict_exporting():
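
Gathering permuted rows from a nested list is a per-index append instead of
the old tensor fancy indexing (_stride_per_key_per_rank[indices, :]). Sketched
standalone with assumed inputs (not the method itself):

    from typing import List

    def permute_rows(rows: List[List[int]], indices: List[int]) -> List[List[int]]:
        # List-of-lists equivalent of tensor indexing rows[indices, :].
        return [rows[index] for index in indices]

    rows = [[1, 2], [3, 4], [5, 6]]  # one row per key
    assert permute_rows(rows, [2, 0]) == [[5, 6], [1, 2]]
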
@@ -2605,16 +2561,18 @@ def permute(
             self.weights_or_none(),
             permuted_length_per_key_sum,
         )
-
+        stride_per_key_per_rank = (
+            permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None
+        )
         kjt = KeyedJaggedTensor(
             keys=permuted_keys,
             values=permuted_values,
             weights=permuted_weights,
             lengths=permuted_lengths.view(-1),
             offsets=None,
             stride=self._stride,
-            stride_per_key_per_rank=permuted_stride_per_key_per_rank,
-            stride_per_key=permuted_stride_per_key,
+            stride_per_key_per_rank=stride_per_key_per_rank,
+            stride_per_key=None,
             length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
             lengths_offset_per_key=None,
             offset_per_key=None,
@@ -2933,7 +2891,7 @@ def dist_init(

     if variable_stride_per_key:
         assert stride_per_rank_per_key is not None
-        stride_per_key_per_rank: torch.Tensor = stride_per_rank_per_key.view(
+        stride_per_key_per_rank_tensor: torch.Tensor = stride_per_rank_per_key.view(
             num_workers, len(keys)
         ).T.cpu()

@@ -2970,18 +2928,23 @@ def dist_init(
             weights,
         )

-        if stride_per_key_per_rank.numel() == 0:
-            stride_per_key_per_rank = torch.zeros(
-                (len(keys), 1), device="cpu", dtype=torch.int64
-            )
+        stride_per_key_per_rank = torch.jit.annotate(
+            List[List[int]], stride_per_key_per_rank_tensor.tolist()
+        )

+        if not stride_per_key_per_rank:
+            stride_per_key_per_rank = [[0]] * len(keys)
         if stagger > 1:
+            stride_per_key_per_rank_stagger: List[List[int]] = []
             local_world_size = num_workers // stagger
-            indices = [
-                list(range(i, num_workers, local_world_size))
-                for i in range(local_world_size)
-            ]
-            stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
+            for i in range(len(keys)):
+                stride_per_rank_stagger: List[int] = []
+                for j in range(local_world_size):
+                    stride_per_rank_stagger.extend(
+                        stride_per_key_per_rank[i][j::local_world_size]
+                    )
+                stride_per_key_per_rank_stagger.append(stride_per_rank_stagger)
+            stride_per_key_per_rank = stride_per_key_per_rank_stagger
         kjt = KeyedJaggedTensor(
             keys=keys,
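
The stagger un-shuffle is now pure list slicing: for each key, the per-rank
entries are regrouped so that positions congruent modulo local_world_size come
out contiguously, mirroring the old column reindexing on the tensor. A
self-contained sketch with made-up sizes (num_workers=4, stagger=2, so
local_world_size=2):

    from typing import List

    def unstagger(per_rank: List[int], local_world_size: int) -> List[int]:
        # Gather positions j, j + lws, j + 2*lws, ... for each j, exactly
        # what stride_per_key_per_rank[i][j::local_world_size] collects.
        out: List[int] = []
        for j in range(local_world_size):
            out.extend(per_rank[j::local_world_size])
        return out

    # Positions [0, 2, 1, 3] of the input are selected, matching the old
    # indices = [list(range(i, num_workers, local_world_size)) ...] layout.
    assert unstagger([10, 11, 12, 13], local_world_size=2) == [10, 12, 11, 13]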