additional logging
prabhuteja12 committed Oct 6, 2023
1 parent 9bfcee4 commit dd3333f
Showing 3 changed files with 70 additions and 6 deletions.
13 changes: 9 additions & 4 deletions src/renate/updaters/model_updater.py
@@ -22,10 +22,11 @@
from renate.utils.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
from renate.utils.distributed_strategies import create_strategy
from renate.utils.file import unlink_file_or_folder
from renate.utils.misc import int_or_str
from renate.utils.misc import AdditionalTrainingMetrics, int_or_str
from .learner import Learner, ReplayLearner
from ..models import RenateModule


logging_logger = logging.getLogger(__name__)


@@ -40,28 +41,32 @@ def __init__(self, val_enabled: bool):
super().__init__()
self._report = Reporter()
self._val_enabled = val_enabled
self._additional_metrics = AdditionalTrainingMetrics()

@rank_zero_only
def _log(self, trainer: Trainer, training: bool) -> None:
def _log(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Report the current epoch's results to Syne Tune.
If validation was run `_val_enabled` is True, the results are reported at the end of
the validation epoch. Otherwise, they are reported at the end of the training epoch.
"""

training = pl_module.training
if trainer.sanity_checking or (training and self._val_enabled):
return
to_report = {k: v.item() for k, v in trainer.logged_metrics.items()}
to_report.update(self._additional_metrics(pl_module))
self._report(
**to_report,
step=trainer.current_epoch,
epoch=trainer.current_epoch + 1,
)

def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
self._log(trainer=trainer, training=pl_module.training)
self._log(trainer=trainer, pl_module=pl_module)

def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
self._log(trainer=trainer, training=pl_module.training)
self._log(trainer=trainer, pl_module=pl_module)


class RenateModelCheckpoint(ModelCheckpoint):
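For context, a minimal, self-contained sketch of what the updated `_log` now reports each epoch: the trainer's logged metrics merged with the new additional metrics before being handed to Syne Tune's `Reporter`. The model and the `logged_metrics` dict below are illustrative stand-ins for `pl_module` and `trainer.logged_metrics`; they are not part of this commit.

import torch

from renate.utils.misc import AdditionalTrainingMetrics

# Stand-ins for pl_module and trainer.logged_metrics (illustrative only).
model = torch.nn.Linear(4, 2)
logged_metrics = {"train_loss": torch.tensor(0.42), "val_accuracy": torch.tensor(0.87)}

additional_metrics = AdditionalTrainingMetrics()
to_report = {k: v.item() for k, v in logged_metrics.items()}
to_report.update(additional_metrics(model))
# to_report now also contains curr_running_time, peak_memory_usage,
# trainable_params, and total_params; in the callback it is forwarded to
# Syne Tune via self._report(**to_report, step=..., epoch=...).
print(to_report)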
37 changes: 36 additions & 1 deletion src/renate/utils/misc.py
@@ -1,6 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Set, Union
import time
from typing import Dict, Optional, Set, Tuple, Union

import torch

@@ -37,3 +38,37 @@ def maybe_populate_mask_and_ignore_logits(
logits.index_fill_(1, class_mask.to(logits.device), -float("inf"))

return logits, class_mask


class AdditionalTrainingMetrics:
def __init__(self) -> None:
"""We gather memory stats, Time stats, number of trainable params, total params"""
self._train_start_time = time.time()

def __call__(self, model: torch.nn.Module) -> Dict[str, Union[float, int]]:
curr_running_time = time.time() - self._train_start_time
# Peak GPU memory (in bytes) allocated during training; 0 if CUDA is unavailable.
peak_memory_usage = (
torch.cuda.memory_stats()["allocated_bytes.all.peak"]
if torch.cuda.is_available()
else 0
)
trainable_params, total_params = self.parameters_count(model)

return dict(
curr_running_time=curr_running_time,
peak_memory_usage=peak_memory_usage,
trainable_params=trainable_params,
total_params=total_params,
)

def parameters_count(self, model: torch.nn.Module) -> Tuple[int, int]:
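"""Returns the number of trainable parameters and the total number of parameters of ``model``."""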
trainable_params, total_params = 0, 0
for param in model.parameters():
num_params = param.numel()
total_params += num_params
if param.requires_grad:
trainable_params += num_params

return trainable_params, total_params
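As a rough illustration of the parameter counting (the frozen-encoder model below is hypothetical, not from this commit): freezing a submodule with `requires_grad_(False)` removes its parameters from the trainable count but not from the total.

import torch

from renate.utils.misc import AdditionalTrainingMetrics

encoder = torch.nn.Linear(10, 10).requires_grad_(False)  # 110 parameters, frozen
head = torch.nn.Linear(10, 2)  # 22 parameters, trainable
stats = AdditionalTrainingMetrics()(torch.nn.Sequential(encoder, head))

assert stats["trainable_params"] == 22
assert stats["total_params"] == 132
# peak_memory_usage is 0 on CPU-only machines; curr_running_time is the time in
# seconds since the AdditionalTrainingMetrics instance was created.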
26 changes: 25 additions & 1 deletion test/renate/utils/test_misc.py
@@ -3,7 +3,11 @@
import pytest
import torch

from renate.utils.misc import int_or_str, maybe_populate_mask_and_ignore_logits
from renate.utils.misc import (
AdditionalTrainingMetrics,
int_or_str,
maybe_populate_mask_and_ignore_logits,
)


@pytest.mark.parametrize(
@@ -69,3 +73,23 @@ def test_possibly_populate_mask_and_ignore_logits(
assert out_logits.data_ptr() == logits.data_ptr()
if class_mask is not None:
assert class_mask.data_ptr() == out_cm.data_ptr()


@pytest.mark.parametrize(
"model,gnd",
[
(
torch.nn.Linear(2, 2),
{"peak_memory_usage": 0, "trainable_params": 6, "total_params": 6},
),
(
torch.nn.Linear(2, 2).requires_grad_(False),
{"peak_memory_usage": 0, "trainable_params": 0, "total_params": 6},
),
],
)
def test_additional_metrics(model, gnd):
metrics = AdditionalTrainingMetrics()
out = metrics(model)
for k in gnd:
assert gnd[k] == out[k]
