Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion QEfficient/base/onnx_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@
#
# ----------------------------------------------------------------------------

import gc
import logging
from typing import Optional, Tuple

import numpy as np
from onnx import ModelProto, external_data_helper, numpy_helper

from QEfficient.utils.constants import ONNX_TRANSFROM_MEMORY_CLEANUP_INTERVAL

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: TRANSFORM - spell check

logger = logging.getLogger(__name__)


class OnnxTransform:
"""
Expand All @@ -31,6 +37,27 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
"""
raise NotImplementedError("Use subclasses for ONNX transform")

@classmethod
def _check_external_data_loaded(cls, model: ModelProto) -> bool:
"""
Check if external data is already loaded in the model.

:param model: The ONNX model to check
:returns: True if external data is already loaded, False otherwise
"""
for tensor in external_data_helper._get_all_tensors(model):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we think of skipping this extra loop for checking whether for all the tensors external data has been loaded or not. The place where we are loading the external data there we can maintain a flag. This flag by default will be set to false and then once all the external data is loaded we can mark it to TRUE. Then in code we may have to just check the flag. or may not need this function if you want to directly use the flag.

# Check if tensor has external data but no raw data loaded
if len(tensor.external_data) > 0 and not tensor.HasField("raw_data"):
return False
return True

@classmethod
def _cleanup_memory(cls):
"""
Force garbage collection to free up memory after tensor processing.
"""
gc.collect()


class FP16ClipTransform(OnnxTransform):
"""
Expand All @@ -47,6 +74,7 @@ def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwar
fp16_min = finfo.min
transformed = False

processed_count = 0
for tensor in external_data_helper._get_all_tensors(model):
nptensor = numpy_helper.to_array(tensor, onnx_base_dir)
if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)):
Expand All @@ -61,6 +89,15 @@ def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwar
tensor.CopyFrom(new_tensor)
transformed = True

del neg_inf_mask, clipped_tensor, new_tensor

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this loop itself you can check and then update flag

del nptensor
processed_count += 1

if processed_count % ONNX_TRANSFROM_MEMORY_CLEANUP_INTERVAL == 0:
cls._cleanup_memory()

cls._cleanup_memory()
return model, transformed


Expand Down Expand Up @@ -89,7 +126,16 @@ def apply(
file_num = 0
current_file_size = 0
transformed = False
external_data_helper.load_external_data_for_model(model, onnx_base_dir)

# Check if external data is already loaded to avoid redundant loading
external_data_already_loaded = cls._check_external_data_loaded(model)

if not external_data_already_loaded:
external_data_helper.load_external_data_for_model(model, onnx_base_dir)
else:
logger.info("External data already loaded, skipping redundant load operation")

processed_count = 0
for tensor in external_data_helper._get_all_tensors(model):
if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold):
transformed = True
Expand All @@ -98,4 +144,11 @@ def apply(
file_num += 1
current_file_size = tsize
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")

processed_count += 1
if processed_count % ONNX_TRANSFROM_MEMORY_CLEANUP_INTERVAL == 0:
cls._cleanup_memory()

cls._cleanup_memory()

return model, transformed
2 changes: 2 additions & 0 deletions QEfficient/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def get_models_dir():
ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99
ONNX_EXPORT_OPSET = 13

ONNX_TRANSFROM_MEMORY_CLEANUP_INTERVAL = 100

COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"]

# InternVL constants
Expand Down
Loading