[NEVIS] Minor changes to ckpt loading logic.
PiperOrigin-RevId: 494966802
Change-Id: I806a4c0a67c6f5f5de0619dea328cbac18ec9511
LSCLDev authored and arthurdouillard committed Dec 13, 2022
1 parent d109541 commit 6de783b
Showing 5 changed files with 25 additions and 21 deletions.
7 changes: 4 additions & 3 deletions experiments_jax/configs/finetuning_ind_pretrained.py
@@ -17,7 +17,7 @@
 
 from dm_nevis.benchmarker.datasets import test_stream
 from dm_nevis.benchmarker.environment import logger_utils
-from experiments_jax.environment import checkpoint_loader
+from experiments_jax.environment import pretrained_model_loader
 from experiments_jax.learners.finetuning import finetuning_learner
 from experiments_jax.training import augmentations
 from experiments_jax.training import modules
@@ -33,7 +33,7 @@
 DEFAULT_CHECKPOINT_DIR = os.environ.get('NEVIS_CHECKPOINT_DIR',
                                         '/tmp/nevis_checkpoint_dir')
 DEFAULT_PRETRAIN_CHECKPOINT_PATH = os.path.join(DEFAULT_CHECKPOINT_DIR,
-                                                'pretraining.ckpt')
+                                                'pretraining.pkl')
 
 FREEZE_PRETRAINED_BACKBONE = False
 
@@ -95,7 +95,8 @@ def get_config() -> ml_collections.ConfigDict:
             # Optionally load and/or freeze pretrained parameters.
             'load_params_fn': None,
             'load_params_fn_with_kwargs': {
-                'fun': checkpoint_loader.load_ckpt_params,
+                'fun':
+                    pretrained_model_loader.load_model_params_from_ckpt,
                 'kwargs': {
                     'freeze_pretrained_backbone':
                         FREEZE_PRETRAINED_BACKBONE,
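
As a side note on how this config entry is likely consumed: a minimal sketch, assuming the learner binds 'kwargs' onto 'fun' before use; the make_load_params_fn helper and the config access path below are hypothetical, not part of finetuning_learner.

import functools


def make_load_params_fn(entry):
  """Hypothetical helper: binds the configured kwargs onto the loader callable."""
  if entry is None:
    return None
  return functools.partial(entry['fun'], **entry['kwargs'])


# Usage sketch:
# load_params_fn = make_load_params_fn(config.load_params_fn_with_kwargs)
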
15 changes: 9 additions & 6 deletions experiments_jax/environment/pretrained_model_loader.py
@@ -4,11 +4,11 @@
 
 from absl import logging
 import chex
-from experiments_jax.environment import pickle_checkpointer
+from experiments_jax.training import trainer
 import haiku as hk
 
 
-def load_ckpt_params(
+def load_model_params_from_ckpt(
     params: hk.Params,
     state: hk.State,
     freeze_pretrained_backbone: bool = False,
@@ -26,12 +26,15 @@ def load_ckpt_params(
     updated params split into trainable and frozen, updated states.
   """
 
-  checkpointer = pickle_checkpointer.PickleCheckpointer(checkpoint_path)
-  restored_params = checkpointer.restore()
-
-  if restored_params is None:
+  trainer_state = trainer.restore_train_state(checkpoint_path)
+  if trainer_state is None or trainer_state.trainable_params is None or trainer_state.frozen_params is None:
     return params, {}, state
 
+  restored_params = {
+      **trainer_state.trainable_params,
+      **trainer_state.frozen_params
+  }
+
   def filter_fn(module_name, *unused_args):
     del unused_args
     return module_name.startswith('backbone')
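
The filter_fn above hints at how parameters are split by module name. A minimal sketch of such a split with Haiku's data_structures utilities; the exact handling of freeze_pretrained_backbone in the full (not shown) function body is assumed here:

import haiku as hk


def split_backbone_params(params: hk.Params, freeze_pretrained_backbone: bool):
  # Partition parameters into backbone vs. the rest by module name prefix.
  backbone, rest = hk.data_structures.partition(
      lambda module_name, name, value: module_name.startswith('backbone'),
      params)
  if freeze_pretrained_backbone:
    return rest, backbone  # (trainable, frozen)
  return hk.data_structures.merge(backbone, rest), {}  # everything trainable
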
2 changes: 1 addition & 1 deletion experiments_jax/training/trainer.py
@@ -76,7 +76,7 @@ def init_train_state(
   if load_params_fn:
     trainable_params, frozen_params, state = load_params_fn(params, state)
   else:
-    trainable_params, frozen_params = params, []
+    trainable_params, frozen_params = params, {}
 
   opt_state = opt.init(trainable_params)
 
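
The [] → {} fix matters because Haiku parameters are dict-like mappings, so downstream merging and optimizer initialization expect a mapping even when nothing is frozen. A minimal sketch under that assumption, with optax.adam standing in for the configured optimizer:

import haiku as hk
import optax


def init_without_loaded_params(params: hk.Params):
  # Mirrors the `else` branch above: nothing is frozen, so frozen_params is an
  # empty mapping rather than a list.
  trainable_params, frozen_params = params, {}
  opt_state = optax.adam(1e-3).init(trainable_params)
  # At apply time both groups are merged back into a single mapping.
  all_params = hk.data_structures.merge(trainable_params, frozen_params)
  return trainable_params, frozen_params, opt_state, all_params
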
7 changes: 4 additions & 3 deletions experiments_torch/configs/finetuning_ind_pretrained.py
@@ -16,7 +16,7 @@
 import os
 
 from dm_nevis.benchmarker.environment import logger_utils
-from experiments_torch.environment import checkpoint_loader
+from experiments_torch.environment import pretrained_model_loader
 from experiments_torch.learners.finetuning import finetuning_learner
 from experiments_torch.training import augmentations
 from experiments_torch.training import resnet
@@ -31,7 +31,7 @@
 DEFAULT_CHECKPOINT_DIR = os.environ.get('NEVIS_CHECKPOINT_DIR',
                                         '/tmp/nevis_checkpoint_dir')
 DEFAULT_PRETRAIN_CHECKPOINT_PATH = os.path.join(DEFAULT_CHECKPOINT_DIR,
-                                                'pretraining.ckpt')
+                                                'pretraining.pkl')
 
 FREEZE_PRETRAINED_BACKBONE = False
 
@@ -93,7 +93,8 @@ def get_config() -> ml_collections.ConfigDict:
             # Optionally load and/or freeze pretrained parameters.
             'load_params_fn': None,
             'load_params_fn_with_kwargs': {
-                'fun': checkpoint_loader.load_ckpt_params,
+                'fun':
+                    pretrained_model_loader.load_model_params_from_ckpt,
                 'kwargs': {
                     'freeze_pretrained_backbone':
                         FREEZE_PRETRAINED_BACKBONE,
15 changes: 7 additions & 8 deletions experiments_torch/environment/pretrained_model_loader.py
@@ -3,12 +3,12 @@
 from typing import Tuple, Union, Dict
 
 from absl import logging
-from experiments_torch.environment import pickle_checkpointer
 from experiments_torch.training import models
+from experiments_torch.training import trainer
 from torch.nn import parameter
 
 
-def load_ckpt_params(
+def load_model_params_from_ckpt(
     model: models.Model,
     freeze_pretrained_backbone: bool = False,
     checkpoint_path: str = '',
@@ -23,19 +23,18 @@ def load_ckpt_params(
   Returns:
     updated params split into trainable and frozen.
   """
-
-  checkpointer = pickle_checkpointer.PickleCheckpointer(checkpoint_path)
-  restored_model = checkpointer.restore()
-
-  if restored_model is None:
+  trainer_state = trainer.restore_train_state(checkpoint_path)
+  if trainer_state is None or trainer_state.model is None:
     return model.backbone.parameters(), {}
 
+  restored_model = trainer_state.model
+
   assert isinstance(restored_model, models.Model)
   logging.info('Loading pretrained model finished.')
 
   for model_param, restored_model_param in zip(
       model.backbone.parameters(), restored_model.backbone.parameters()):
-    assert model_param.data.shape == restored_model_param.data
+    assert model_param.data.shape == restored_model_param.data.shape
     model_param.data = restored_model_param.data
     model_param.requires_grad = not freeze_pretrained_backbone
 
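
On the Torch side, clearing requires_grad on the copied backbone parameters is what freezes them; a common follow-up, assumed here rather than shown in this commit, is to pass only the still-trainable parameters to the optimizer:

import torch


def trainable_parameters(model: torch.nn.Module):
  # Yields only parameters that still require gradients (i.e. not frozen).
  return (p for p in model.parameters() if p.requires_grad)


# Usage sketch (hypothetical optimizer choice):
# optimizer = torch.optim.SGD(trainable_parameters(model), lr=0.01)
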
