Backward compatibility for corrdiff checkpoints #857

Open · wants to merge 3 commits into base: main

Conversation

@CharlelieLrt (Collaborator) commented Apr 25, 2025

PhysicsNeMo Pull Request

Description

This PR introduces a patch to be able to load existing UNet and SongUNet checkpoints, similar to those used in CorrDiff.

To apply backward compatibility patches and be able to load old checkpoints, run:

python

model = Module.from_checkpoint("chckpt.mdlus", backward_compatibility=True)

By default, backward_compatibility=False, in which case the checkpoint loader raises an error when attempting to load an old checkpoint.
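
The opt-in behaviour described above can be sketched with a toy loader. This is not the PhysicsNeMo implementation: the checkpoint layout, the `version` field, `load_args`, `upgrade_args`, and the renamed argument are all hypothetical, chosen only to show a flag that defaults to False and errors out on old checkpoints.

```python
# Hypothetical stand-in for a checkpoint loader; not PhysicsNeMo code.
CURRENT_VERSION = 2  # assumed version tag of the current model definition


def upgrade_args(args: dict) -> dict:
    """Assumed patch: rename a constructor argument that changed between versions."""
    patched = dict(args)
    if "img_channels" in patched:
        patched["in_channels"] = patched.pop("img_channels")
    return patched


def load_args(checkpoint: dict, backward_compatibility: bool = False) -> dict:
    """Return the constructor arguments stored in a checkpoint dictionary."""
    if checkpoint["version"] == CURRENT_VERSION:
        return dict(checkpoint["args"])
    if not backward_compatibility:
        # Default path: refuse to load an old checkpoint silently.
        raise ValueError(
            "Old checkpoint detected; pass backward_compatibility=True to load it."
        )
    return upgrade_args(checkpoint["args"])


old = {"version": 1, "args": {"img_channels": 3}}
try:
    load_args(old)
except ValueError as err:
    print(f"default behaviour: {err}")
print(load_args(old, backward_compatibility=True))
```

Keeping the flag off by default means an outdated checkpoint is never patched without the user asking for it.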

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

None.

@CharlelieLrt CharlelieLrt self-assigned this Apr 25, 2025
@CharlelieLrt CharlelieLrt added the "3 - Ready for Review" label Apr 25, 2025
@CharlelieLrt (Collaborator, Author)

/blossom-ci

@CharlelieLrt (Collaborator, Author)

@pzharrington
Following our offline discussion, I attempted a refactor built around a function that explicitly converts checkpoints, such as convert_checkpoint(old_chckpt.mdlus, new_chckpt.mdlus). However, I realized that this approach might not work, for two reasons:

  1. In utils_compatibility.py I need a statement if cls is ... to identify the class to which the patch applies. This check can only run once cls has been determined, that is, inside instantiate, so it cannot be done in a separate convert function.
  2. With an explicit convert function, we would need to update all our affected checkpoints hosted on NGC (we can't really expect users to always run convert for those). But updating multiple checkpoints on NGC could quickly become unmanageable if more checkpoints are affected.
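
To illustrate the first point, here is a rough sketch of a patch that dispatches on the resolved class. The stub classes, `REGISTRY`, `apply_patch`, the simplified `instantiate` signature, and the placeholder "patched" argument are all assumptions made for illustration and do not reflect the actual contents of utils_compatibility.py.

```python
# Hypothetical sketch; class stubs and patch contents are assumptions.
class UNet:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


class SongUNet:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


REGISTRY = {"UNet": UNet, "SongUNet": SongUNet}


def apply_patch(cls, args):
    """Patch constructor arguments based on the already-resolved class."""
    args = dict(args)
    if cls is UNet:
        args.setdefault("patched", "unet")  # placeholder for a real fix
    elif cls is SongUNet:
        args.setdefault("patched", "songunet")  # placeholder for a real fix
    return args


def instantiate(spec, backward_compatibility=False):
    cls = REGISTRY[spec["name"]]  # the class is only known from this point on
    args = spec["args"]
    if backward_compatibility:
        args = apply_patch(cls, args)
    return cls(**args)


model = instantiate(
    {"name": "SongUNet", "args": {"img_resolution": 64}},
    backward_compatibility=True,
)
print(type(model).__name__, model.kwargs)
```

Because `apply_patch` compares against class objects rather than names, it has to run after the class lookup, which is why the patch sits inside `instantiate` in this sketch.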

So, I proposed another refactor that does essentially the same thing but is enabled by a keyword argument passed when loading the checkpoint: model = Module.from_checkpoint("chckpt.mdlus", backward_compatibility=True).
