diff --git a/onediff_comfy_nodes/modules/booster_cache.py b/onediff_comfy_nodes/modules/booster_cache.py
index 3787e821e..b086d9bf6 100644
--- a/onediff_comfy_nodes/modules/booster_cache.py
+++ b/onediff_comfy_nodes/modules/booster_cache.py
@@ -1,5 +1,4 @@
 import torch
-import traceback
 from collections import OrderedDict
 from comfy.model_patcher import ModelPatcher
 from functools import singledispatch
@@ -14,9 +13,11 @@ def switch_to_cached_model(new_model, cached_model):
 
 @switch_to_cached_model.register
 def _(new_model: ModelPatcher, cached_model):
-    assert type(new_model.model) == type(
-        cached_model
-    ), f"Model type mismatch: expected {type(cached_model)}, got {type(new_model.model)}"
+    if type(new_model.model) != type(cached_model):
+        raise TypeError(
+            f"Model type mismatch: expected {type(cached_model)}, got {type(new_model.model)}"
+        )
+
     cached_model.diffusion_model.load_state_dict(
         new_model.model.diffusion_model.state_dict(), strict=True
     )
@@ -75,8 +76,7 @@ def get_cached_model(self, key, model):
             try:
                 return switch_to_cached_model(model, cached_model)
             except Exception as e:
-                print("An exception occurred when switching to cached model:")
-                print(traceback.format_exc())
+                print(f"An exception occurred when switching to cached model: {e}")
                 del self._cache[key]
                 torch.cuda.empty_cache()
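
For reviewers unfamiliar with the dispatch mechanism this file relies on, here is a minimal, standalone sketch (the `ModelPatcher` below is a stand-in, not the real `comfy.model_patcher.ModelPatcher`) of how `functools.singledispatch` selects the registered overload from the first argument's type annotation, and how the explicit `raise TypeError` reports a wrapped-model mismatch instead of an `AssertionError` (which would be stripped under `python -O`):

```python
# Standalone sketch of the singledispatch pattern used in booster_cache.py.
# ModelPatcher here is a stand-in for comfy.model_patcher.ModelPatcher.
from functools import singledispatch


class ModelPatcher:
    def __init__(self, model):
        self.model = model  # the wrapped diffusion model


@singledispatch
def switch_to_cached_model(new_model, cached_model):
    # Fallback when no overload is registered for type(new_model).
    raise NotImplementedError(f"Unsupported model type: {type(new_model)}")


@switch_to_cached_model.register
def _(new_model: ModelPatcher, cached_model):
    # Selected whenever the first positional argument is a ModelPatcher;
    # the explicit raise replaces the old assert and survives `python -O`.
    if type(new_model.model) != type(cached_model):
        raise TypeError(
            f"Model type mismatch: expected {type(cached_model)}, got {type(new_model.model)}"
        )
    return cached_model


try:
    switch_to_cached_model(ModelPatcher(model="a str, not an int"), cached_model=42)
except TypeError as exc:
    print(exc)  # Model type mismatch: expected <class 'int'>, got <class 'str'>
```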