Skip to content

Backward compatibility for corrdiff checkpoints #857

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Added backward compatibility utilities to load existing UNet and SongUnet
checkpoints (e.g., those used in CorrDiff)

### Security

### Dependencies
Expand Down
10 changes: 6 additions & 4 deletions examples/generative/corrdiff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -413,13 +413,15 @@ model. During training, you can fine-tune various parameters. The most commonly

> **Note on Patch Size Selection**
> When implementing a patch-based training, choosing the right patch size is critical for model performance. The patch dimensions are controlled by `patch_shape_x` and `patch_shape_y` in your configuration file. To determine optimal patch sizes:
> 1. Calculate the auto-correlation function of your data using the provided utilities in [`inference/power_spectra.py`](./inference/power_spectra.py):
> 1. Train a regression model on the full domain.
> 2. Compute the residuals `x_res = x_data - regression_model(x_data)` on multiple samples, where `x_data` are ground truth samples.
> 3. Calculate the auto-correlation function of your residuals using the provided utilities in [`inference/power_spectra.py`](./inference/power_spectra.py):
> - `average_power_spectrum()`
> - `power_spectra_to_acf()`
> 2. Set patch dimensions to match or exceed the distance at which auto-correlation approaches zero
> 3. This ensures each patch captures the full spatial correlation structure of your data
> 4. Set patch dimensions to match or exceed the distance at which auto-correlation approaches zero.
> 5. This ensures each patch captures the full spatial correlation structure of your data.
>
> This analysis helps balance computational efficiency with the preservation of important physical relationships in your data.
> This analysis helps balance computational efficiency with the preservation of local physical relationships in your data.

### Generation configuration

Expand Down
27 changes: 24 additions & 3 deletions physicsnemo/models/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import torch

import physicsnemo
import physicsnemo.models.utils_compatibility as bwc
from physicsnemo.models.meta import ModelMetaData
from physicsnemo.registry import ModelRegistry
from physicsnemo.utils.filesystem import _download_cached, _get_fs
Expand Down Expand Up @@ -120,7 +121,11 @@ def _safe_members(tar, local_path):
print(f"Skipping potentially malicious file: {member.name}")

@classmethod
def instantiate(cls, arg_dict: Dict[str, Any]) -> "Module":
def instantiate(
cls,
arg_dict: Dict[str, Any],
backward_compatibility: bool = False,
) -> "Module":
"""Instantiate a model from a dictionary of arguments

Parameters
Expand All @@ -131,6 +136,9 @@ def instantiate(cls, arg_dict: Dict[str, Any]) -> "Module":
are used to import the class and the last is used to instantiate
the class. The '__args__' key should be a dictionary of arguments
to pass to the class's __init__ function.
backward_compatibility : bool, optional
Whether to apply backward compatibility patches to the arguments, by
default False.

Returns
-------
Expand Down Expand Up @@ -162,6 +170,10 @@ def instantiate(cls, arg_dict: Dict[str, Any]) -> "Module":
)
"""

# Backward compatibility: handle old checkpoints (class renamed)
if backward_compatibility:
bwc._update_class_name(arg_dict)

_cls_name = arg_dict["__name__"]
registry = ModelRegistry()
if cls.__name__ == arg_dict["__name__"]: # If cls is the class
Expand Down Expand Up @@ -192,6 +204,10 @@ def instantiate(cls, arg_dict: Dict[str, Any]) -> "Module":
if isinstance(_cls, importlib.metadata.EntryPoint):
_cls = _cls.load()

# Backward compatibility: handle old checkpoints (args renamed)
if backward_compatibility:
bwc._update_init_args(_cls, arg_dict["__args__"])

return _cls(**arg_dict["__args__"])

def debug(self):
Expand Down Expand Up @@ -337,13 +353,18 @@ def load(
self.load_state_dict(model_dict, strict=strict)

@classmethod
def from_checkpoint(cls, file_name: str) -> "Module":
def from_checkpoint(
cls, file_name: str, backward_compatibility: bool = False
) -> "Module":
"""Simple utility for constructing a model from a checkpoint

Parameters
----------
file_name : str
Checkpoint file name
backward_compatibility : bool, optional
Whether to apply backward compatibility patches to the arguments, by
default False.

Returns
-------
Expand Down Expand Up @@ -374,7 +395,7 @@ def from_checkpoint(cls, file_name: str) -> "Module":
# Load model arguments and instantiate the model
with open(local_path.joinpath("args.json"), "r") as f:
args = json.load(f)
model = cls.instantiate(args)
model = cls.instantiate(args, backward_compatibility=backward_compatibility)

# Load the model weights
model_dict = torch.load(
Expand Down
95 changes: 95 additions & 0 deletions physicsnemo/models/utils_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from typing import Any, Dict, Type

# Backward-compatibility tables. Each table maps a legacy constructor argument
# name (as stored in an old checkpoint) to its replacement name. A value of
# ``None`` means the argument has no replacement and is removed.

# -- Diffusion UNet prior to 327d9928abc17983ad7aa3df94da9566c197c468 -- #
# For args renaming/deletion: all four legacy arguments are dropped (None).
old_to_new_args_UNet_327d9928 = {
"img_channels": None,
"sigma_min": None,
"sigma_max": None,
"sigma_data": None,
}
# For unpacked kwargs dict
# NOTE(review): this table is empty and is not referenced by the helpers
# visible in this file; kept for symmetry with the table below.
old_to_new_kwargs_UNet_327d9928 = {} # No unpacked kwargs for UNet

# -- EDMPrecondSuperResolution prior to 327d9928abc17983ad7aa3df94da9566c197c468 -- #
# For args renaming/deletion: the legacy ``img_channels`` argument is dropped.
old_to_new_args_EDMPrecondSuperResolution_327d9928 = {"img_channels": None}
# For unpacked kwargs dict (currently empty, so applying it is a no-op)
old_to_new_kwargs_EDMPrecondSuperResolution_327d9928 = {}
# For class renaming: the value under "__name__" is the current class name that
# replaces the legacy "EDMPrecondSR" name in a checkpoint's argument dict.
old_to_new_class_name_EDMPrecondSuperResolution_327d9928 = {
"__name__": "EDMPrecondSuperResolution"
}


def _update_args(args, old_to_new_args):
for k, v in old_to_new_args.items():
if v is not None:
args[v] = args.pop(k)
else:
del args[k]


def _update_init_args(cls: Type, args: Dict[str, Any]):
    """Patch constructor arguments loaded from an older checkpoint.

    The arguments are only modified when ``cls`` is one of the classes with a
    known legacy signature *and* every legacy argument name of that signature
    is present in ``args``. Otherwise ``args`` is left untouched.

    Parameters
    ----------
    cls : type
        The class that is about to be instantiated.
    args : dict
        The arguments that will be passed to ``cls.__init__``. Modified in
        place.
    """
    # Resolved at call time rather than at module import.
    # NOTE(review): presumably this avoids a circular import with the models
    # package -- confirm.
    diffusion = importlib.import_module("physicsnemo.models.diffusion")

    def _has_all(legacy_names) -> bool:
        # True when every legacy argument name is present in ``args``.
        return all(name in args for name in legacy_names)

    # Diffusion UNet prior to 327d9928abc17983ad7aa3df94da9566c197c468
    if cls is diffusion.UNet and _has_all(old_to_new_args_UNet_327d9928):
        _update_args(args, old_to_new_args_UNet_327d9928)
    # EDMPrecondSuperResolution prior to 327d9928abc17983ad7aa3df94da9566c197c468
    elif cls is diffusion.EDMPrecondSuperResolution and _has_all(
        old_to_new_args_EDMPrecondSuperResolution_327d9928
    ):
        _update_args(args, old_to_new_args_EDMPrecondSuperResolution_327d9928)
        _update_args(args, old_to_new_kwargs_EDMPrecondSuperResolution_327d9928)


def _update_class_name(arg_dict: Dict[str, Any]):
    """Replace a legacy class name stored in a checkpoint argument dictionary.

    Parameters
    ----------
    arg_dict : dict
        The argument dictionary to update, modified in place. It should
        contain a "__name__" key that represents the class name and an
        "__args__" key that represents the arguments passed to the class
        constructor.
    """
    # EDMPrecondSuperResolution prior to 327d9928abc17983ad7aa3df94da9566c197c468
    if arg_dict["__name__"] != "EDMPrecondSR":
        return

    # Only treat the entry as a legacy checkpoint when every legacy argument
    # name is present among the stored constructor arguments.
    for legacy_name in old_to_new_args_EDMPrecondSuperResolution_327d9928:
        if legacy_name not in arg_dict["__args__"]:
            return

    current_name = old_to_new_class_name_EDMPrecondSuperResolution_327d9928["__name__"]
    arg_dict["__name__"] = current_name