@@ -80,10 +80,6 @@ def _to_offsets(lengths: torch.Tensor) -> torch.Tensor:
80
80
return torch .ops .fbgemm .asynchronous_complete_cumsum (lengths )
81
81
82
82
83
- def _to_lengths (offsets : torch .Tensor ) -> torch .Tensor :
84
- return offsets [1 :] - offsets [:- 1 ]
85
-
86
-
87
83
@torch .jit .script_if_tracing
88
84
def _batched_lengths_to_offsets (lengths : torch .Tensor ) -> torch .Tensor :
89
85
(f , b ) = lengths .shape
@@ -1452,33 +1448,6 @@ def _maybe_compute_kjt_to_jt_dict(
1452
1448
return _jt_dict
1453
1449
1454
1450
1455
- @torch .fx .wrap
1456
- def _merge_weights_or_none (
1457
- a_weights : Optional [torch .Tensor ],
1458
- b_weights : Optional [torch .Tensor ],
1459
- ) -> Optional [torch .Tensor ]:
1460
- assert not (
1461
- (a_weights is None ) ^ (b_weights is None )
1462
- ), "Can only merge weighted or unweighted KJTs."
1463
- if a_weights is None :
1464
- return None
1465
- # pyre-ignore[6]
1466
- return torch .cat ([a_weights , b_weights ], dim = 0 )
1467
-
1468
-
1469
- @torch .fx .wrap
1470
- def _strides_from_kjt (
1471
- kjt : "KeyedJaggedTensor" ,
1472
- ) -> Tuple [Optional [int ], Optional [List [List [int ]]]]:
1473
- stride , stride_per_key_per_rank = (
1474
- (None , kjt .stride_per_key_per_rank ())
1475
- if kjt .variable_stride_per_key ()
1476
- else (kjt .stride (), None )
1477
- )
1478
-
1479
- return stride , stride_per_key_per_rank
1480
-
1481
-
1482
1451
@torch .fx .wrap
1483
1452
def _kjt_empty_like (kjt : "KeyedJaggedTensor" ) -> "KeyedJaggedTensor" :
1484
1453
# empty like function fx wrapped, also avoids device hardcoding
@@ -1684,18 +1653,6 @@ def _maybe_compute_stride_per_key(
1684
1653
return None
1685
1654
1686
1655
1687
- def _maybe_compute_variable_stride_per_key (
1688
- variable_stride_per_key : Optional [bool ],
1689
- stride_per_key_per_rank : Optional [List [List [int ]]],
1690
- ) -> bool :
1691
- if variable_stride_per_key is not None :
1692
- return variable_stride_per_key
1693
- elif stride_per_key_per_rank is not None :
1694
- return True
1695
- else :
1696
- return False
1697
-
1698
-
1699
1656
class KeyedJaggedTensor (Pipelineable , metaclass = JaggedTensorMeta ):
1700
1657
"""Represents an (optionally weighted) keyed jagged tensor.
1701
1658
0 commit comments