diff --git a/torchrec/sparse/tensor_dict.py b/torchrec/sparse/tensor_dict.py index 3f00d5275..23b7ebdd0 100644 --- a/torchrec/sparse/tensor_dict.py +++ b/torchrec/sparse/tensor_dict.py @@ -15,6 +15,7 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +@torch.fx.wrap def maybe_td_to_kjt( features: KeyedJaggedTensor, keys: Optional[List[str]] = None ) -> KeyedJaggedTensor: