Skip to content

Commit

Permalink
Fix is_using_oneflow_backend check (#1112)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced backend detection logic for improved compatibility with the
OneFlow library.
- Added a function to check for OneFlow library availability and CUDA
support.

- **Bug Fixes**
- Improved messaging for cases when the OneFlow backend is not detected.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
ccssu authored Sep 20, 2024
1 parent 7c32525 commit 9231f55
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion onediff_comfy_nodes/modules/oneflow/utils/booster_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Union

import oneflow

import torch
from comfy import model_management
from comfy.model_base import BaseModel, SVD_img2vid
Expand All @@ -9,6 +11,7 @@
OneflowDeployableModule as DeployableModule,
)
from onediff.utils import set_boolean_env_var
from onediff.utils.import_utils import is_oneflow_available

from ..patch_management import create_patch_executor, PatchType

Expand Down Expand Up @@ -63,6 +66,15 @@ def set_environment_for_svd_img2vid(model: ModelPatcher):


def is_using_oneflow_backend(module):
# First, check if oneflow is available and CUDA is enabled
if is_oneflow_available() and not oneflow.cuda.is_available():
print("OneFlow CUDA support is not available")
return False

# Check if the module
if isinstance(module, oneflow.nn.Module):
return True

dc_patch_executor = create_patch_executor(PatchType.DCUNetExecutorPatch)
if isinstance(module, ModelPatcher):
deep_cache_module = dc_patch_executor.get_patch(module)
Expand All @@ -85,7 +97,17 @@ def is_using_oneflow_backend(module):
if isinstance(module, DeployableModule):
return True

raise RuntimeError("")
if hasattr(module, "parameters"):
for param in module.parameters():
if isinstance(param, oneflow.Tensor):
return True

warn_msg = (
f"OneFlow backend is not detected for the module, the module is {type(module)}"
)
print(warn_msg)
# If none of the above conditions are met, it's not using OneFlow backend
return False


def clear_deployable_module_cache_and_unbind(
Expand Down

0 comments on commit 9231f55

Please sign in to comment.