Skip to content

Commit a719848

Browse files
jingshfacebook-github-bot
authored andcommitted
Remove unused function (#3109)
Summary: Pull Request resolved: #3109 Functions is not used anywhere, remove it. - _to_lengths - _merge_weights_or_none - _strides_from_kjt - _maybe_compute_variable_stride_per_key Reviewed By: jcwchen, TroyGarden Differential Revision: D76849500 fbshipit-source-id: c60506ed891ea589966f14698b09618ad770ecfe
1 parent 7c87718 commit a719848

File tree

1 file changed

+0
-43
lines changed

1 file changed

+0
-43
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,6 @@ def _to_offsets(lengths: torch.Tensor) -> torch.Tensor:
8080
return torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
8181

8282

83-
def _to_lengths(offsets: torch.Tensor) -> torch.Tensor:
84-
return offsets[1:] - offsets[:-1]
85-
86-
8783
@torch.jit.script_if_tracing
8884
def _batched_lengths_to_offsets(lengths: torch.Tensor) -> torch.Tensor:
8985
(f, b) = lengths.shape
@@ -1452,33 +1448,6 @@ def _maybe_compute_kjt_to_jt_dict(
14521448
return _jt_dict
14531449

14541450

1455-
@torch.fx.wrap
1456-
def _merge_weights_or_none(
1457-
a_weights: Optional[torch.Tensor],
1458-
b_weights: Optional[torch.Tensor],
1459-
) -> Optional[torch.Tensor]:
1460-
assert not (
1461-
(a_weights is None) ^ (b_weights is None)
1462-
), "Can only merge weighted or unweighted KJTs."
1463-
if a_weights is None:
1464-
return None
1465-
# pyre-ignore[6]
1466-
return torch.cat([a_weights, b_weights], dim=0)
1467-
1468-
1469-
@torch.fx.wrap
1470-
def _strides_from_kjt(
1471-
kjt: "KeyedJaggedTensor",
1472-
) -> Tuple[Optional[int], Optional[List[List[int]]]]:
1473-
stride, stride_per_key_per_rank = (
1474-
(None, kjt.stride_per_key_per_rank())
1475-
if kjt.variable_stride_per_key()
1476-
else (kjt.stride(), None)
1477-
)
1478-
1479-
return stride, stride_per_key_per_rank
1480-
1481-
14821451
@torch.fx.wrap
14831452
def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
14841453
# empty like function fx wrapped, also avoids device hardcoding
@@ -1684,18 +1653,6 @@ def _maybe_compute_stride_per_key(
16841653
return None
16851654

16861655

1687-
def _maybe_compute_variable_stride_per_key(
1688-
variable_stride_per_key: Optional[bool],
1689-
stride_per_key_per_rank: Optional[List[List[int]]],
1690-
) -> bool:
1691-
if variable_stride_per_key is not None:
1692-
return variable_stride_per_key
1693-
elif stride_per_key_per_rank is not None:
1694-
return True
1695-
else:
1696-
return False
1697-
1698-
16991656
class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
17001657
"""Represents an (optionally weighted) keyed jagged tensor.
17011658

0 commit comments

Comments
 (0)