diff --git a/torchrec/modules/hash_mc_metrics.py b/torchrec/modules/hash_mc_metrics.py index 0f5054727..714cf8c2a 100644 --- a/torchrec/modules/hash_mc_metrics.py +++ b/torchrec/modules/hash_mc_metrics.py @@ -46,11 +46,10 @@ def __init__( ) -> None: super().__init__() - # persist scalar logger steps in checkpoint to make sure it is not reset after training job restarted self.register_buffer( ScalarLogger.STEPS_BUFFER, torch.tensor(1, dtype=torch.int64), - persistent=True, + persistent=False, ) self._name: str = name