Skip to content

Commit

Permalink
Merge branch 'main' into print-model-analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
horheynm authored Feb 2, 2024
2 parents 960a9fa + 1a9ee4b commit e649c45
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/sparsezoo/analyze/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from typing import Any, Dict, List, Optional, Union

import numpy
import onnx
import yaml
from onnx import ModelProto, NodeProto
from pydantic import BaseModel, Field, PositiveFloat, PositiveInt
Expand Down Expand Up @@ -68,6 +67,7 @@
is_parameterized_prunable_layer,
is_quantized_layer,
is_sparse_layer,
load_model,
)


Expand Down Expand Up @@ -914,7 +914,7 @@ def from_onnx(cls, onnx_file_path: Union[str, ModelProto]):
model_onnx = onnx_file_path
model_name = ""
else:
model_onnx = onnx.load(onnx_file_path)
model_onnx = load_model(onnx_file_path)
model_name = str(onnx_file_path)

model_graph = ONNXGraph(model_onnx)
Expand Down
12 changes: 7 additions & 5 deletions src/sparsezoo/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Tuple, Union

import numpy
Expand Down Expand Up @@ -88,6 +89,7 @@ def __init__(self, source: str, download_path: Optional[str] = None):
_setup_args = self.initialize_model_from_stub(stub=self.source)
files, path, url, validation_results, compressed_size = _setup_args
if download_path is not None:
download_path = str(Path(download_path).expanduser().resolve())
path = download_path # overwrite cache path with user input
else:
# initializing the model from the path
Expand Down Expand Up @@ -703,23 +705,23 @@ def _download(
)
):
file.download(destination_path=directory_path)
return True
validations = True
else:
_LOGGER.warning(
f"Failed to download file {file.name}. The url of the file is None."
)
return False
validations = False

elif isinstance(file, Recipes):
validations = (
validations = all(
self._download(_file, directory_path) for _file in file.recipes
)

else:
validations = (
validations = all(
self._download(_file, directory_path) for _file in file.values()
)
return all(validations)
return validations

def _sample_outputs_list_to_dict(
self,
Expand Down
4 changes: 2 additions & 2 deletions src/sparsezoo/utils/onnx/external_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def validate_onnx(model: Union[str, ModelProto]):
raise ValueError(f"Invalid onnx model: {err}")


def load_model(model: Union[str, ModelProto]) -> ModelProto:
def load_model(model: Union[str, ModelProto, Path]) -> ModelProto:
"""
Load an ONNX model from an onnx model file path. If a ModelProto
is given, then it is returned.
Expand All @@ -185,7 +185,7 @@ def load_model(model: Union[str, ModelProto]) -> ModelProto:
if isinstance(model, ModelProto):
return model

if isinstance(model, str):
if isinstance(model, (Path, str)):
return onnx.load(clean_path(model))

raise ValueError(f"unknown type given for model: {type(model)}")
Expand Down

0 comments on commit e649c45

Please sign in to comment.