-
Notifications
You must be signed in to change notification settings - Fork 51
Added memory optimization for onnx transforms #538
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,11 +5,17 @@ | |
# | ||
# ---------------------------------------------------------------------------- | ||
|
||
import gc | ||
import logging | ||
from typing import Optional, Tuple | ||
|
||
import numpy as np | ||
from onnx import ModelProto, external_data_helper, numpy_helper | ||
|
||
from QEfficient.utils.constants import ONNX_TRANSFROM_MEMORY_CLEANUP_INTERVAL | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class OnnxTransform: | ||
""" | ||
|
@@ -31,6 +37,27 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]: | |
""" | ||
raise NotImplementedError("Use subclasses for ONNX transform") | ||
|
||
@classmethod | ||
def _check_external_data_loaded(cls, model: ModelProto) -> bool: | ||
""" | ||
Check if external data is already loaded in the model. | ||
|
||
:param model: The ONNX model to check | ||
:returns: True if external data is already loaded, False otherwise | ||
""" | ||
for tensor in external_data_helper._get_all_tensors(model): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. Can we skip this extra loop that checks whether the external data has been loaded for every tensor? We can maintain a flag at the place where the external data is loaded: it is set to False by default and marked True once all the external data has been loaded. The code then only has to check the flag, or, if you prefer to use the flag directly, this function may not be needed at all. |
||
# Check if tensor has external data but no raw data loaded | ||
if len(tensor.external_data) > 0 and not tensor.HasField("raw_data"): | ||
return False | ||
return True | ||
|
||
@classmethod | ||
def _cleanup_memory(cls): | ||
""" | ||
Force garbage collection to free up memory after tensor processing. | ||
""" | ||
gc.collect() | ||
|
||
|
||
class FP16ClipTransform(OnnxTransform): | ||
""" | ||
|
@@ -47,6 +74,7 @@ def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwar | |
fp16_min = finfo.min | ||
transformed = False | ||
|
||
processed_count = 0 | ||
for tensor in external_data_helper._get_all_tensors(model): | ||
nptensor = numpy_helper.to_array(tensor, onnx_base_dir) | ||
if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)): | ||
|
@@ -61,6 +89,15 @@ def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwar | |
tensor.CopyFrom(new_tensor) | ||
transformed = True | ||
|
||
del neg_inf_mask, clipped_tensor, new_tensor | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. In this same loop you can do the check and then update the flag. |
||
del nptensor | ||
processed_count += 1 | ||
|
||
if processed_count % ONNX_TRANSFROM_MEMORY_CLEANUP_INTERVAL == 0: | ||
cls._cleanup_memory() | ||
|
||
cls._cleanup_memory() | ||
return model, transformed | ||
|
||
|
||
|
@@ -89,7 +126,16 @@ def apply( | |
file_num = 0 | ||
current_file_size = 0 | ||
transformed = False | ||
external_data_helper.load_external_data_for_model(model, onnx_base_dir) | ||
|
||
# Check if external data is already loaded to avoid redundant loading | ||
external_data_already_loaded = cls._check_external_data_loaded(model) | ||
|
||
if not external_data_already_loaded: | ||
external_data_helper.load_external_data_for_model(model, onnx_base_dir) | ||
else: | ||
logger.info("External data already loaded, skipping redundant load operation") | ||
|
||
processed_count = 0 | ||
for tensor in external_data_helper._get_all_tensors(model): | ||
if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold): | ||
transformed = True | ||
|
@@ -98,4 +144,11 @@ def apply( | |
file_num += 1 | ||
current_file_size = tsize | ||
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data") | ||
|
||
processed_count += 1 | ||
if processed_count % ONNX_TRANSFROM_MEMORY_CLEANUP_INTERVAL == 0: | ||
cls._cleanup_memory() | ||
|
||
cls._cleanup_memory() | ||
|
||
return model, transformed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: spell check — `TRANSFROM` should be `TRANSFORM` (in `ONNX_TRANSFROM_MEMORY_CLEANUP_INTERVAL`).