From 6d257617e10cbbe6e1e73da83311e500ccf7e4e8 Mon Sep 17 00:00:00 2001 From: Yury Date: Thu, 22 Dec 2022 17:21:32 +0100 Subject: [PATCH 01/17] add support for sklearn transformers --- mlem/constants.py | 1 + mlem/contrib/scipy.py | 64 +++++++++++++++++++++++++++++++++++ mlem/contrib/sklearn.py | 25 ++++++++++++++ mlem/ext.py | 1 + setup.py | 3 ++ tests/contrib/test_sklearn.py | 39 +++++++++++++++++++-- 6 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 mlem/contrib/scipy.py diff --git a/mlem/constants.py b/mlem/constants.py index 26f3f516..6914825e 100644 --- a/mlem/constants.py +++ b/mlem/constants.py @@ -4,5 +4,6 @@ PREDICT_METHOD_NAME = "predict" PREDICT_PROBA_METHOD_NAME = "predict_proba" PREDICT_ARG_NAME = "data" +TRANSFORM_METHOD_NAME = "transform" MLEM_CONFIG_FILE_NAME = ".mlem.yaml" diff --git a/mlem/contrib/scipy.py b/mlem/contrib/scipy.py new file mode 100644 index 00000000..4d23cecf --- /dev/null +++ b/mlem/contrib/scipy.py @@ -0,0 +1,64 @@ +from typing import Any, ClassVar, Iterator, Tuple + +import scipy +from scipy import sparse +from scipy.sparse import csr_matrix + +from mlem.core.artifacts import Artifacts, Storage +from mlem.core.data_type import ( + DT, + DataHook, + DataReader, + DataType, + DataWriter, + WithDefaultSerializer, +) +from mlem.core.hooks import IsInstanceHookMixin +from mlem.core.requirements import InstallableRequirement, Requirements + + +class ScipySparceMatrix( + WithDefaultSerializer, DataType, DataHook, IsInstanceHookMixin +): + type: ClassVar[str] = "csr_matrix" + valid_types: ClassVar = csr_matrix + dtype: str + + def get_requirements(self) -> Requirements: + return Requirements.new([InstallableRequirement.from_module(scipy)]) + + @classmethod + def process(cls, obj: Any, **kwargs) -> DataType: + return ScipySparceMatrix(dtype=obj.dtype.name) + + def get_writer( + self, project: str = None, filename: str = None, **kwargs + ) -> DataWriter: + return ScipyWriter(**kwargs) + + +class ScipyWriter(DataWriter): + def write( + self, data: DT, storage: Storage, path: str + ) -> Tuple[DataReader[DT], Artifacts]: + with storage.open(path) as (f, art): + sparse.save_npz(f, art) + return ScipyReader(data_type=data), {self.art_name: art} + + +class ScipyReader(DataReader): + type: ClassVar[str] = "csr_matrix" + + def read_batch( + self, artifacts: Artifacts, batch_size: int + ) -> Iterator[DT]: + raise NotImplementedError + + def read(self, artifacts: Artifacts) -> Iterator[DataType]: + if DataWriter.art_name not in artifacts: + raise ValueError( + f"Wrong artifacts {artifacts}: should be one {DataWriter.art_name} file" + ) + with artifacts[DataWriter.art_name].open() as f: + data = sparse.load_npz(f) + return self.data_type.copy().bind(data) diff --git a/mlem/contrib/sklearn.py b/mlem/contrib/sklearn.py index 86aea130..601d0615 100644 --- a/mlem/contrib/sklearn.py +++ b/mlem/contrib/sklearn.py @@ -7,6 +7,7 @@ import sklearn from sklearn.base import ClassifierMixin, RegressorMixin +from sklearn.feature_extraction.text import TransformerMixin, _VectorizerMixin from sklearn.pipeline import Pipeline from mlem.core.hooks import IsInstanceHookMixin @@ -132,3 +133,27 @@ def process( **predict_proba_args ) return mt + + +class SklearnTransformer(SklearnModel): + valid_types: ClassVar = ( + TransformerMixin, + _VectorizerMixin, + ) + type: ClassVar = "sklearn_transformer" + + @classmethod + def process( + cls, obj: Any, sample_data: Optional[Any] = None, **kwargs + ) -> ModelType: + methods = { + "transform": Signature.from_method( + obj.transform, 
+ auto_infer=sample_data is not None, + raw_documents=sample_data, + ), + } + + return SklearnTransformer(io=SimplePickleIO(), methods=methods).bind( + obj + ) diff --git a/mlem/ext.py b/mlem/ext.py index 9408781d..a3d3f237 100644 --- a/mlem/ext.py +++ b/mlem/ext.py @@ -129,6 +129,7 @@ class ExtensionLoader: False, ), Extension("mlem.contrib.git", ["pygit2"], True), + Extension("mlem.contrib.scipy", ["scipy"], False), Extension("mlem.contrib.torchvision", ["torchvision"], False), ) diff --git a/setup.py b/setup.py index 777ce550..8a6386f7 100644 --- a/setup.py +++ b/setup.py @@ -237,6 +237,9 @@ "serializer.xgboost_dmatrix = mlem.contrib.xgboost:DMatrixSerializer", "model_type.xgboost = mlem.contrib.xgboost:XGBoostModel", "model_io.xgboost_io = mlem.contrib.xgboost:XGBoostModelIO", + "data_type.csr_matrix = mlem.contrib.scipy:ScipySparceMatrix", + "data_writer.csr_matrix = mlem.contrib.scipy:ScipyWriter", + "data_reader.csr_matrix = mlem.contrib.scipy:ScipyReader", ], "mlem.config": [ "core = mlem.config:MlemConfig", diff --git a/tests/contrib/test_sklearn.py b/tests/contrib/test_sklearn.py index da456667..62a295c1 100644 --- a/tests/contrib/test_sklearn.py +++ b/tests/contrib/test_sklearn.py @@ -3,14 +3,16 @@ import lightgbm as lgb import numpy as np import pytest +from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.linear_model import LinearRegression, LogisticRegression from sklearn.pipeline import Pipeline from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC -from mlem.constants import PREDICT_METHOD_NAME +from mlem.constants import PREDICT_METHOD_NAME, TRANSFORM_METHOD_NAME from mlem.contrib.numpy import NumpyNdarrayType -from mlem.contrib.sklearn import SklearnModel +from mlem.contrib.scipy import ScipySparceMatrix +from mlem.contrib.sklearn import SklearnModel, SklearnTransformer from mlem.core.artifacts import LOCAL_STORAGE from mlem.core.data_type import DataAnalyzer from mlem.core.model import Argument, ModelAnalyzer @@ -24,6 +26,11 @@ def inp_data(): return [[1, 2, 3], [3, 2, 1]] +@pytest.fixture +def inp_data_text(): + return ["Is that peanut butter on my nose? 
Mlem!"] + + @pytest.fixture def out_data(): return [1, 2] @@ -43,6 +50,13 @@ def regressor(inp_data, out_data): return lr +@pytest.fixture +def transformer(inp_data_text): + tf_idf = TfidfVectorizer() + tf_idf.fit(inp_data_text) + return tf_idf + + @pytest.fixture() def pipeline(inp_data, out_data): pipe = Pipeline([("scaler", StandardScaler()), ("svc", SVC())]) @@ -76,6 +90,27 @@ def test_hook(model_fixture, inp_data, request): assert signature.returns == returns +def test_hook_transformer(transformer, inp_data_text): + data_type = DataAnalyzer.analyze(inp_data_text) + model_type = ModelAnalyzer.analyze(transformer, sample_data=inp_data_text) + assert isinstance(model_type, SklearnTransformer) + assert TRANSFORM_METHOD_NAME in model_type.methods + signature = model_type.methods[TRANSFORM_METHOD_NAME] + returns = ScipySparceMatrix(dtype="float64") + assert signature.name == TRANSFORM_METHOD_NAME + assert signature.args[0] == Argument(name="raw_documents", type_=data_type) + assert signature.returns == returns + + +def test_model_type__transform(transformer, inp_data_text): + model_type = ModelAnalyzer.analyze(transformer, sample_data=inp_data_text) + + np.testing.assert_array_almost_equal( + transformer.transform(inp_data_text).todense(), + model_type.call_method("transform", inp_data_text).todense(), + ) + + def test_hook_lgb(lgbm_model, inp_data): data_type = DataAnalyzer.analyze(inp_data) model_type = ModelAnalyzer.analyze(lgbm_model, sample_data=inp_data) From f0e393fd16d4e96b1699a1974d9b640906ebbe00 Mon Sep 17 00:00:00 2001 From: Yury Date: Fri, 23 Dec 2022 11:54:59 +0100 Subject: [PATCH 02/17] fix PR comments --- mlem/contrib/scipy.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mlem/contrib/scipy.py b/mlem/contrib/scipy.py index 4d23cecf..612d1311 100644 --- a/mlem/contrib/scipy.py +++ b/mlem/contrib/scipy.py @@ -6,7 +6,6 @@ from mlem.core.artifacts import Artifacts, Storage from mlem.core.data_type import ( - DT, DataHook, DataReader, DataType, @@ -39,8 +38,8 @@ def get_writer( class ScipyWriter(DataWriter): def write( - self, data: DT, storage: Storage, path: str - ) -> Tuple[DataReader[DT], Artifacts]: + self, data: DataType, storage: Storage, path: str + ) -> Tuple[DataReader[DataType], Artifacts]: with storage.open(path) as (f, art): sparse.save_npz(f, art) return ScipyReader(data_type=data), {self.art_name: art} @@ -51,7 +50,7 @@ class ScipyReader(DataReader): def read_batch( self, artifacts: Artifacts, batch_size: int - ) -> Iterator[DT]: + ) -> Iterator[DataType]: raise NotImplementedError def read(self, artifacts: Artifacts) -> Iterator[DataType]: From 424ee8edc63e1f7a3fd9aeb294a395fe98aea4b0 Mon Sep 17 00:00:00 2001 From: Yury Date: Fri, 23 Dec 2022 15:06:12 +0100 Subject: [PATCH 03/17] fix PR comments --- mlem/contrib/scipy.py | 2 +- mlem/contrib/sklearn.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/mlem/contrib/scipy.py b/mlem/contrib/scipy.py index 612d1311..d48b0c35 100644 --- a/mlem/contrib/scipy.py +++ b/mlem/contrib/scipy.py @@ -36,7 +36,7 @@ def get_writer( return ScipyWriter(**kwargs) -class ScipyWriter(DataWriter): +class ScipyWriter(DataWriter[[ScipySparceMatrix]]): def write( self, data: DataType, storage: Storage, path: str ) -> Tuple[DataReader[DataType], Artifacts]: diff --git a/mlem/contrib/sklearn.py b/mlem/contrib/sklearn.py index 601d0615..af6928eb 100644 --- a/mlem/contrib/sklearn.py +++ b/mlem/contrib/sklearn.py @@ -10,6 +10,7 @@ from sklearn.feature_extraction.text import 
TransformerMixin, _VectorizerMixin from sklearn.pipeline import Pipeline +from mlem.constants import TRANSFORM_METHOD_NAME from mlem.core.hooks import IsInstanceHookMixin from mlem.core.model import ( ModelHook, @@ -144,10 +145,18 @@ class SklearnTransformer(SklearnModel): @classmethod def process( - cls, obj: Any, sample_data: Optional[Any] = None, **kwargs + cls, + obj: Any, + sample_data: Optional[Any] = None, + methods_sample_data: Optional[Dict[str, Any]] = None, + **kwargs ) -> ModelType: + methods_sample_data = methods_sample_data or {} + sample_data = methods_sample_data.get( + TRANSFORM_METHOD_NAME, sample_data + ) methods = { - "transform": Signature.from_method( + TRANSFORM_METHOD_NAME: Signature.from_method( obj.transform, auto_infer=sample_data is not None, raw_documents=sample_data, From 1f376d5fc1822ea370e339482813cb1fa0113d20 Mon Sep 17 00:00:00 2001 From: Yury Date: Sat, 31 Dec 2022 16:17:58 +0100 Subject: [PATCH 04/17] add tests; add onehotencoder support --- mlem/contrib/scipy.py | 12 ++++---- mlem/contrib/sklearn.py | 10 +++---- tests/contrib/test_scipy.py | 32 +++++++++++++++++++++ tests/contrib/test_sklearn.py | 53 +++++++++++++++++++++++------------ 4 files changed, 77 insertions(+), 30 deletions(-) create mode 100644 tests/contrib/test_scipy.py diff --git a/mlem/contrib/scipy.py b/mlem/contrib/scipy.py index d48b0c35..97a60f34 100644 --- a/mlem/contrib/scipy.py +++ b/mlem/contrib/scipy.py @@ -16,11 +16,11 @@ from mlem.core.requirements import InstallableRequirement, Requirements -class ScipySparceMatrix( +class ScipySparseMatrix( WithDefaultSerializer, DataType, DataHook, IsInstanceHookMixin ): type: ClassVar[str] = "csr_matrix" - valid_types: ClassVar = csr_matrix + valid_types: ClassVar = (csr_matrix,) dtype: str def get_requirements(self) -> Requirements: @@ -28,7 +28,7 @@ def get_requirements(self) -> Requirements: @classmethod def process(cls, obj: Any, **kwargs) -> DataType: - return ScipySparceMatrix(dtype=obj.dtype.name) + return ScipySparseMatrix(dtype=obj.dtype.name) def get_writer( self, project: str = None, filename: str = None, **kwargs @@ -36,12 +36,12 @@ def get_writer( return ScipyWriter(**kwargs) -class ScipyWriter(DataWriter[[ScipySparceMatrix]]): +class ScipyWriter(DataWriter[ScipySparseMatrix]): def write( self, data: DataType, storage: Storage, path: str - ) -> Tuple[DataReader[DataType], Artifacts]: + ) -> Tuple[DataReader, Artifacts]: with storage.open(path) as (f, art): - sparse.save_npz(f, art) + sparse.save_npz(f, data.data) return ScipyReader(data_type=data), {self.art_name: art} diff --git a/mlem/contrib/sklearn.py b/mlem/contrib/sklearn.py index af6928eb..3e3afc24 100644 --- a/mlem/contrib/sklearn.py +++ b/mlem/contrib/sklearn.py @@ -7,8 +7,9 @@ import sklearn from sklearn.base import ClassifierMixin, RegressorMixin -from sklearn.feature_extraction.text import TransformerMixin, _VectorizerMixin +from sklearn.feature_extraction.text import TransformerMixin from sklearn.pipeline import Pipeline +from sklearn.preprocessing._encoders import _BaseEncoder from mlem.constants import TRANSFORM_METHOD_NAME from mlem.core.hooks import IsInstanceHookMixin @@ -137,10 +138,7 @@ def process( class SklearnTransformer(SklearnModel): - valid_types: ClassVar = ( - TransformerMixin, - _VectorizerMixin, - ) + valid_types: ClassVar = (TransformerMixin, _BaseEncoder) type: ClassVar = "sklearn_transformer" @classmethod @@ -159,7 +157,7 @@ def process( TRANSFORM_METHOD_NAME: Signature.from_method( obj.transform, auto_infer=sample_data is not None, - 
raw_documents=sample_data, + X=sample_data, ), } diff --git a/tests/contrib/test_scipy.py b/tests/contrib/test_scipy.py new file mode 100644 index 00000000..86efed0e --- /dev/null +++ b/tests/contrib/test_scipy.py @@ -0,0 +1,32 @@ +import numpy as np +import pytest +from scipy.sparse import csr_matrix + +from mlem.contrib.scipy import ScipySparseMatrix +from mlem.core.data_type import DataAnalyzer +from tests.conftest import data_write_read_check + + +@pytest.fixture +def test_data(): + row = np.array([0, 0, 1, 2, 2, 2]) + col = np.array([0, 2, 2, 0, 1, 2]) + data = np.array([1, 2, 3, 4, 5, 6]) + return csr_matrix((data, (row, col)), shape=(3, 3), dtype="float32") + + +def test_sparce_matrix(test_data): + assert ScipySparseMatrix.is_object_valid(test_data) + sdt = DataAnalyzer.analyze(test_data) + assert sdt.dict() == {"dtype": "float32", "type": "csr_matrix"} + assert isinstance(sdt, ScipySparseMatrix) + assert sdt.dtype == "float32" + assert sdt.get_requirements().modules == ["scipy"] + + +def test_write_read(test_data): + sdt = DataAnalyzer.analyze(test_data) + sdt = sdt.bind(test_data) + data_write_read_check( + sdt, custom_eq=lambda x, y: np.array_equal(x.todense(), y.todense()) + ) diff --git a/tests/contrib/test_sklearn.py b/tests/contrib/test_sklearn.py index 62a295c1..4245b8d5 100644 --- a/tests/contrib/test_sklearn.py +++ b/tests/contrib/test_sklearn.py @@ -3,15 +3,17 @@ import lightgbm as lgb import numpy as np import pytest -from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.feature_extraction.text import TfidfTransformer from sklearn.linear_model import LinearRegression, LogisticRegression from sklearn.pipeline import Pipeline -from sklearn.preprocessing import StandardScaler +from sklearn.preprocessing import OneHotEncoder, StandardScaler from sklearn.svm import SVC from mlem.constants import PREDICT_METHOD_NAME, TRANSFORM_METHOD_NAME from mlem.contrib.numpy import NumpyNdarrayType -from mlem.contrib.scipy import ScipySparceMatrix + +# from mlem.contrib.scipy import ScipySparceMatrix +from mlem.contrib.scipy import ScipySparseMatrix from mlem.contrib.sklearn import SklearnModel, SklearnTransformer from mlem.core.artifacts import LOCAL_STORAGE from mlem.core.data_type import DataAnalyzer @@ -26,9 +28,9 @@ def inp_data(): return [[1, 2, 3], [3, 2, 1]] -@pytest.fixture -def inp_data_text(): - return ["Is that peanut butter on my nose? Mlem!"] +# @pytest.fixture +# def inp_data_text(): +# return ["Is that peanut butter on my nose? 
Mlem!"] @pytest.fixture @@ -51,12 +53,19 @@ def regressor(inp_data, out_data): @pytest.fixture -def transformer(inp_data_text): - tf_idf = TfidfVectorizer() - tf_idf.fit(inp_data_text) +def transformer(inp_data): + tf_idf = TfidfTransformer() + tf_idf.fit(inp_data) return tf_idf +@pytest.fixture +def onehotencoder(inp_data): + encoder = OneHotEncoder() + encoder.fit(inp_data) + return encoder + + @pytest.fixture() def pipeline(inp_data, out_data): pipe = Pipeline([("scaler", StandardScaler()), ("svc", SVC())]) @@ -90,24 +99,32 @@ def test_hook(model_fixture, inp_data, request): assert signature.returns == returns -def test_hook_transformer(transformer, inp_data_text): - data_type = DataAnalyzer.analyze(inp_data_text) - model_type = ModelAnalyzer.analyze(transformer, sample_data=inp_data_text) +@pytest.mark.parametrize( + "transformer_fixture", ["transformer", "onehotencoder"] +) +def test_hook_transformer(transformer_fixture, inp_data, request): + transformer = request.getfixturevalue(transformer_fixture) + data_type = DataAnalyzer.analyze(inp_data) + model_type = ModelAnalyzer.analyze(transformer, sample_data=inp_data) assert isinstance(model_type, SklearnTransformer) assert TRANSFORM_METHOD_NAME in model_type.methods signature = model_type.methods[TRANSFORM_METHOD_NAME] - returns = ScipySparceMatrix(dtype="float64") + returns = ScipySparseMatrix(dtype="float64") assert signature.name == TRANSFORM_METHOD_NAME - assert signature.args[0] == Argument(name="raw_documents", type_=data_type) + assert signature.args[0] == Argument(name="X", type_=data_type) assert signature.returns == returns -def test_model_type__transform(transformer, inp_data_text): - model_type = ModelAnalyzer.analyze(transformer, sample_data=inp_data_text) +@pytest.mark.parametrize( + "transformer_fixture", ["transformer", "onehotencoder"] +) +def test_model_type__transform(transformer_fixture, inp_data, request): + transformer = request.getfixturevalue(transformer_fixture) + model_type = ModelAnalyzer.analyze(transformer, sample_data=inp_data) np.testing.assert_array_almost_equal( - transformer.transform(inp_data_text).todense(), - model_type.call_method("transform", inp_data_text).todense(), + transformer.transform(inp_data).todense(), + model_type.call_method("transform", inp_data).todense(), ) From 8443036a6b5dda2b9bebfb33321b3a8431bcec48 Mon Sep 17 00:00:00 2001 From: Yury Date: Sat, 31 Dec 2022 17:46:22 +0100 Subject: [PATCH 05/17] add preprocess test --- tests/contrib/test_sklearn.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/tests/contrib/test_sklearn.py b/tests/contrib/test_sklearn.py index 4245b8d5..36425b9e 100644 --- a/tests/contrib/test_sklearn.py +++ b/tests/contrib/test_sklearn.py @@ -9,10 +9,9 @@ from sklearn.preprocessing import OneHotEncoder, StandardScaler from sklearn.svm import SVC +from mlem.api import apply, load, save from mlem.constants import PREDICT_METHOD_NAME, TRANSFORM_METHOD_NAME from mlem.contrib.numpy import NumpyNdarrayType - -# from mlem.contrib.scipy import ScipySparceMatrix from mlem.contrib.scipy import ScipySparseMatrix from mlem.contrib.sklearn import SklearnModel, SklearnTransformer from mlem.core.artifacts import LOCAL_STORAGE @@ -28,11 +27,6 @@ def inp_data(): return [[1, 2, 3], [3, 2, 1]] -# @pytest.fixture -# def inp_data_text(): -# return ["Is that peanut butter on my nose? 
Mlem!"] - - @pytest.fixture def out_data(): return [1, 2] @@ -128,6 +122,26 @@ def test_model_type__transform(transformer_fixture, inp_data, request): ) +@pytest.mark.parametrize("transformer_fixture", ["transformer"]) +def test_preprocess_transformer( + classifier, transformer_fixture, inp_data, tmpdir, out_data, request +): + transformer = request.getfixturevalue(transformer_fixture) + model_file = "clf" + clf = LogisticRegression() + train_data = transformer.transform(inp_data) + clf.fit(train_data, out_data) + save( + clf, + str(tmpdir / model_file), + sample_data=inp_data, + preprocess=transformer.transform, + ) + clf = load(str(tmpdir / model_file)) + output = apply(clf, inp_data) + assert np.array_equal(output, out_data) + + def test_hook_lgb(lgbm_model, inp_data): data_type = DataAnalyzer.analyze(inp_data) model_type = ModelAnalyzer.analyze(lgbm_model, sample_data=inp_data) From 7da0a0dd350da38a35b960ed491f9ddd42727e20 Mon Sep 17 00:00:00 2001 From: Yury Date: Sun, 1 Jan 2023 14:59:50 +0100 Subject: [PATCH 06/17] add support for all scipy sparse matrices --- mlem/contrib/scipy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlem/contrib/scipy.py b/mlem/contrib/scipy.py index 97a60f34..3128a473 100644 --- a/mlem/contrib/scipy.py +++ b/mlem/contrib/scipy.py @@ -2,7 +2,7 @@ import scipy from scipy import sparse -from scipy.sparse import csr_matrix +from scipy.sparse import spmatrix from mlem.core.artifacts import Artifacts, Storage from mlem.core.data_type import ( @@ -20,7 +20,7 @@ class ScipySparseMatrix( WithDefaultSerializer, DataType, DataHook, IsInstanceHookMixin ): type: ClassVar[str] = "csr_matrix" - valid_types: ClassVar = (csr_matrix,) + valid_types: ClassVar = (spmatrix,) dtype: str def get_requirements(self) -> Requirements: From 1bd69c7c48409636a66808f65a957bd8aba91a23 Mon Sep 17 00:00:00 2001 From: Yury Date: Mon, 9 Jan 2023 11:30:48 +0100 Subject: [PATCH 07/17] fix tests for onehotencoder --- tests/contrib/test_sklearn.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/contrib/test_sklearn.py b/tests/contrib/test_sklearn.py index 36425b9e..cfb334a8 100644 --- a/tests/contrib/test_sklearn.py +++ b/tests/contrib/test_sklearn.py @@ -9,7 +9,7 @@ from sklearn.preprocessing import OneHotEncoder, StandardScaler from sklearn.svm import SVC -from mlem.api import apply, load, save +from mlem.api import apply, load_meta, save from mlem.constants import PREDICT_METHOD_NAME, TRANSFORM_METHOD_NAME from mlem.contrib.numpy import NumpyNdarrayType from mlem.contrib.scipy import ScipySparseMatrix @@ -122,7 +122,9 @@ def test_model_type__transform(transformer_fixture, inp_data, request): ) -@pytest.mark.parametrize("transformer_fixture", ["transformer"]) +@pytest.mark.parametrize( + "transformer_fixture", ["transformer", "onehotencoder"] +) def test_preprocess_transformer( classifier, transformer_fixture, inp_data, tmpdir, out_data, request ): @@ -135,9 +137,9 @@ def test_preprocess_transformer( clf, str(tmpdir / model_file), sample_data=inp_data, - preprocess=transformer.transform, + preprocess=transformer, ) - clf = load(str(tmpdir / model_file)) + clf = load_meta(str(tmpdir / model_file)) output = apply(clf, inp_data) assert np.array_equal(output, out_data) From 37e4fdc4b29b58f02bd4f37eabaa1ddd880eb6e1 Mon Sep 17 00:00:00 2001 From: Yury Date: Fri, 13 Jan 2023 16:53:30 +0100 Subject: [PATCH 08/17] add serializer and tests for scipy matrices --- mlem/contrib/scipy.py | 91 +++++++++++++++++++++++++++++++++-- 
mlem/contrib/sklearn.py | 4 ++ tests/contrib/test_scipy.py | 83 ++++++++++++++++++++++++++++---- tests/contrib/test_sklearn.py | 4 +- 4 files changed, 169 insertions(+), 13 deletions(-) diff --git a/mlem/contrib/scipy.py b/mlem/contrib/scipy.py index 3128a473..f8c231ac 100644 --- a/mlem/contrib/scipy.py +++ b/mlem/contrib/scipy.py @@ -1,17 +1,31 @@ -from typing import Any, ClassVar, Iterator, Tuple +"""Scipy Sparse matrices support +Extension type: data + +DataType, Reader and Writer implementations for `scipy.sparse` +""" +from typing import ClassVar, Iterator, List, Optional, Tuple, Type, Union import scipy +from pydantic import BaseModel +from pydantic.main import create_model +from pydantic.types import conlist from scipy import sparse from scipy.sparse import spmatrix +from mlem.contrib.numpy import ( + np_type_from_string, + python_type_from_np_string_repr, +) from mlem.core.artifacts import Artifacts, Storage from mlem.core.data_type import ( DataHook, DataReader, + DataSerializer, DataType, DataWriter, WithDefaultSerializer, ) +from mlem.core.errors import DeserializationError, SerializationError from mlem.core.hooks import IsInstanceHookMixin from mlem.core.requirements import InstallableRequirement, Requirements @@ -19,24 +33,44 @@ class ScipySparseMatrix( WithDefaultSerializer, DataType, DataHook, IsInstanceHookMixin ): + """ + DataType implementation for scipy sparse matrix + """ + type: ClassVar[str] = "csr_matrix" valid_types: ClassVar = (spmatrix,) + shape: Optional[Tuple] + """shape of `sparse.csr_matrix` object in data""" dtype: str + """dtype of `sparse.csr_matrix` object in data""" def get_requirements(self) -> Requirements: return Requirements.new([InstallableRequirement.from_module(scipy)]) @classmethod - def process(cls, obj: Any, **kwargs) -> DataType: - return ScipySparseMatrix(dtype=obj.dtype.name) + def process(cls, obj: sparse.csr_matrix, **kwargs) -> DataType: + return ScipySparseMatrix(dtype=obj.dtype.name, shape=obj.shape) def get_writer( self, project: str = None, filename: str = None, **kwargs ) -> DataWriter: return ScipyWriter(**kwargs) + def subtype(self, subshape: Tuple[Optional[int], ...]): + if len(subshape) == 0: + return python_type_from_np_string_repr(self.dtype) + return conlist( + self.subtype(subshape[1:]), + min_items=subshape[0], + max_items=subshape[0], + ) + class ScipyWriter(DataWriter[ScipySparseMatrix]): + """ + write scipy matrix to npz format + """ + def write( self, data: DataType, storage: Storage, path: str ) -> Tuple[DataReader, Artifacts]: @@ -46,6 +80,10 @@ def write( class ScipyReader(DataReader): + """ + read scipy matrix from npz format + """ + type: ClassVar[str] = "csr_matrix" def read_batch( @@ -61,3 +99,50 @@ def read(self, artifacts: Artifacts) -> Iterator[DataType]: with artifacts[DataWriter.art_name].open() as f: data = sparse.load_npz(f) return self.data_type.copy().bind(data) + + +class ScipySparseMatrixSerializer(DataSerializer[ScipySparseMatrix]): + """ + serializer for scipy sparse matrices + """ + + is_default: ClassVar = True + data_class: ClassVar = ScipySparseMatrix + + def get_model( + self, data_type: ScipySparseMatrix, prefix: str = "" + ) -> Union[Type[BaseModel], type]: + item_type = List[data_type.subtype(data_type.shape[1:])] # type: ignore[index] + return create_model( + prefix + "ScipySparse", + __root__=(item_type, ...), + ) + + def serialize(self, data_type: ScipySparseMatrix, instance: spmatrix): + data_type.check_type(instance, sparse.csr_matrix, SerializationError) + if instance.dtype != 
np_type_from_string(data_type.dtype): + raise SerializationError( + f"given matrix is of dtype: {instance.dtype}, " + f"expected: {data_type.dtype}" + ) + coordinate_matrix = instance.tocoo() + data = coordinate_matrix.data + row = coordinate_matrix.row + col = coordinate_matrix.col + return data, (row, col) + + def deserialize(self, data_type, obj) -> sparse.csr_matrix: + + try: + mat = sparse.csr_matrix( + obj, dtype=data_type.dtype, shape=data_type.shape + ) + except ValueError as e: + raise DeserializationError( + f"Given object {obj} could not be converted" + f"to sparse matrix of type: {data_type.type}" + ) from e + return mat + + # def get_model(self, data_type: DT, prefix: str = "") -> Union[Type[BaseModel], type]: + # pass diff --git a/mlem/contrib/sklearn.py b/mlem/contrib/sklearn.py index 3e3afc24..99a81102 100644 --- a/mlem/contrib/sklearn.py +++ b/mlem/contrib/sklearn.py @@ -138,6 +138,10 @@ def process( class SklearnTransformer(SklearnModel): + """ + Model Type implementation for sklearn transformers + """ + valid_types: ClassVar = (TransformerMixin, _BaseEncoder) type: ClassVar = "sklearn_transformer" diff --git a/tests/contrib/test_scipy.py b/tests/contrib/test_scipy.py index 86efed0e..436a511a 100644 --- a/tests/contrib/test_scipy.py +++ b/tests/contrib/test_scipy.py @@ -4,29 +4,94 @@ from mlem.contrib.scipy import ScipySparseMatrix from mlem.core.data_type import DataAnalyzer +from mlem.core.errors import DeserializationError, SerializationError from tests.conftest import data_write_read_check @pytest.fixture -def test_data(): +def raw_data(): row = np.array([0, 0, 1, 2, 2, 2]) col = np.array([0, 2, 2, 0, 1, 2]) data = np.array([1, 2, 3, 4, 5, 6]) - return csr_matrix((data, (row, col)), shape=(3, 3), dtype="float32") + return data, (row, col) -def test_sparce_matrix(test_data): - assert ScipySparseMatrix.is_object_valid(test_data) - sdt = DataAnalyzer.analyze(test_data) - assert sdt.dict() == {"dtype": "float32", "type": "csr_matrix"} +@pytest.fixture +def sparse_mat(raw_data): + return csr_matrix(raw_data, shape=(3, 3), dtype="float32") + + +@pytest.fixture +def schema(): + return { + "title": "ScipySparse", + "type": "array", + "items": { + "type": "array", + "items": {"type": "number"}, + "minItems": 3, + "maxItems": 3, + }, + } + + +@pytest.fixture +def sparse_data_type(sparse_mat): + return DataAnalyzer.analyze(sparse_mat) + + +def test_sparce_matrix(sparse_mat, schema): + assert ScipySparseMatrix.is_object_valid(sparse_mat) + sdt = DataAnalyzer.analyze(sparse_mat) + assert sdt.dict() == { + "dtype": "float32", + "type": "csr_matrix", + "shape": (3, 3), + } + model = sdt.get_model() + assert model.__name__ == "ScipySparse" + assert model.schema() == schema assert isinstance(sdt, ScipySparseMatrix) assert sdt.dtype == "float32" assert sdt.get_requirements().modules == ["scipy"] -def test_write_read(test_data): - sdt = DataAnalyzer.analyze(test_data) - sdt = sdt.bind(test_data) +def test_serialization(raw_data, sparse_mat): + sdt = DataAnalyzer.analyze(sparse_mat) + payload = sdt.serialize(sparse_mat) + deserialized_data = sdt.deserialize(payload) + assert np.array_equal(sparse_mat.todense(), deserialized_data.todense()) + + +def test_write_read(sparse_mat): + sdt = DataAnalyzer.analyze(sparse_mat) + sdt = sdt.bind(sparse_mat) data_write_read_check( sdt, custom_eq=lambda x, y: np.array_equal(x.todense(), y.todense()) ) + + +@pytest.mark.parametrize( + "obj", + [ + 1, # wrong type + csr_matrix( + ([1], ([1], [0])), shape=(3, 3), dtype="float64" + ), # wrong dtype + 
csr_matrix( + ([1], ([1], [0])), shape=(2, 2), dtype="float32" + ), # wrong shape + ], +) +def test_serialize_failure(sparse_mat, obj): + sdt = DataAnalyzer.analyze(sparse_mat) + with pytest.raises(SerializationError): + sdt.serialize(obj) + + +@pytest.mark.parametrize( + "obj", [1, ([1, 1], ([0, 6], [1, 6]))] # wrong type # wrong shape +) +def test_desiarilze_failure(sparse_data_type, obj): + with pytest.raises(DeserializationError): + sparse_data_type.deserialize(obj) diff --git a/tests/contrib/test_sklearn.py b/tests/contrib/test_sklearn.py index cfb334a8..a7105108 100644 --- a/tests/contrib/test_sklearn.py +++ b/tests/contrib/test_sklearn.py @@ -103,7 +103,9 @@ def test_hook_transformer(transformer_fixture, inp_data, request): assert isinstance(model_type, SklearnTransformer) assert TRANSFORM_METHOD_NAME in model_type.methods signature = model_type.methods[TRANSFORM_METHOD_NAME] - returns = ScipySparseMatrix(dtype="float64") + cols = len(transformer.get_feature_names_out()) + rows = len(inp_data) + returns = ScipySparseMatrix(dtype="float64", shape=(rows, cols)) assert signature.name == TRANSFORM_METHOD_NAME assert signature.args[0] == Argument(name="X", type_=data_type) assert signature.returns == returns From c4ab73595a2b188468a28c29813a160692d17abf Mon Sep 17 00:00:00 2001 From: Yury Date: Fri, 13 Jan 2023 17:46:35 +0100 Subject: [PATCH 09/17] fix bug with not using call_order when using apply function --- mlem/api/commands.py | 13 ++++++++----- mlem/contrib/scipy.py | 10 +++++----- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/mlem/api/commands.py b/mlem/api/commands.py index a97cf097..49b507e7 100644 --- a/mlem/api/commands.py +++ b/mlem/api/commands.py @@ -32,6 +32,7 @@ MlemLink, MlemModel, MlemObject, + _ModelMethodCall, ) from mlem.runtime.client import Client from mlem.runtime.interface import ModelInterface @@ -85,18 +86,20 @@ def apply( except WrongMethodError: resolved_method = PREDICT_METHOD_NAME echo(EMOJI_APPLY + f"Applying `{resolved_method}` method...") + method_call = _ModelMethodCall( + name=resolved_method, + order=model.call_orders[resolved_method], + model=model, + ) if batch_size: res: Any = [] for part in data: batch_data = get_data_value(part, batch_size) for batch in batch_data: - preds = w.call_method(resolved_method, batch.data) + preds = method_call(batch.data) res += [*preds] # TODO: merge results else: - res = [ - w.call_method(resolved_method, get_data_value(part)) - for part in data - ] + res = [method_call(get_data_value(part)) for part in data] if output is None: if len(res) == 1: return res[0] diff --git a/mlem/contrib/scipy.py b/mlem/contrib/scipy.py index f8c231ac..fee2edd6 100644 --- a/mlem/contrib/scipy.py +++ b/mlem/contrib/scipy.py @@ -40,9 +40,9 @@ class ScipySparseMatrix( type: ClassVar[str] = "csr_matrix" valid_types: ClassVar = (spmatrix,) shape: Optional[Tuple] - """shape of `sparse.csr_matrix` object in data""" + """Shape of `sparse.csr_matrix` object in data""" dtype: str - """dtype of `sparse.csr_matrix` object in data""" + """Dtype of `sparse.csr_matrix` object in data""" def get_requirements(self) -> Requirements: return Requirements.new([InstallableRequirement.from_module(scipy)]) @@ -68,7 +68,7 @@ def subtype(self, subshape: Tuple[Optional[int], ...]): class ScipyWriter(DataWriter[ScipySparseMatrix]): """ - write scipy matrix to npz format + Write scipy matrix to npz format """ def write( @@ -81,7 +81,7 @@ def write( class ScipyReader(DataReader): """ - read scipy matrix from npz format + Read scipy matrix from npz 
format """ type: ClassVar[str] = "csr_matrix" @@ -103,7 +103,7 @@ def read(self, artifacts: Artifacts) -> Iterator[DataType]: class ScipySparseMatrixSerializer(DataSerializer[ScipySparseMatrix]): """ - serializer for scipy sparse matrices + Serializer for scipy sparse matrices """ is_default: ClassVar = True From 036a990e9d4377329be782290f4c1b4fc8d7bebb Mon Sep 17 00:00:00 2001 From: Yury Date: Sat, 14 Jan 2023 16:48:42 +0100 Subject: [PATCH 10/17] fix tests --- mlem/contrib/scipy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlem/contrib/scipy.py b/mlem/contrib/scipy.py index fee2edd6..eb62c02b 100644 --- a/mlem/contrib/scipy.py +++ b/mlem/contrib/scipy.py @@ -71,6 +71,8 @@ class ScipyWriter(DataWriter[ScipySparseMatrix]): Write scipy matrix to npz format """ + type: ClassVar[str] = "csr_matrix" + def write( self, data: DataType, storage: Storage, path: str ) -> Tuple[DataReader, Artifacts]: From 068ac6517d2828531777414078629d566a55a8e1 Mon Sep 17 00:00:00 2001 From: Alexander Guschin <1aguschin@gmail.com> Date: Mon, 16 Jan 2023 13:05:10 +0600 Subject: [PATCH 11/17] add scipy to extras --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 8a6386f7..54486e3e 100644 --- a/setup.py +++ b/setup.py @@ -61,6 +61,7 @@ "pandas": ["pandas"], "numpy": ["numpy"], "sklearn": ["scikit-learn"], + "scipy": ["scipy"], "onnx": ["onnx"], "onnxruntime": [ "protobuf==3.20.1", From 0f5f4fea4998ebf60126c78df9de40851bfa449f Mon Sep 17 00:00:00 2001 From: Alexander Guschin <1aguschin@gmail.com> Date: Mon, 16 Jan 2023 13:05:39 +0600 Subject: [PATCH 12/17] fix typo --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 54486e3e..a75c868e 100644 --- a/setup.py +++ b/setup.py @@ -238,7 +238,7 @@ "serializer.xgboost_dmatrix = mlem.contrib.xgboost:DMatrixSerializer", "model_type.xgboost = mlem.contrib.xgboost:XGBoostModel", "model_io.xgboost_io = mlem.contrib.xgboost:XGBoostModelIO", - "data_type.csr_matrix = mlem.contrib.scipy:ScipySparceMatrix", + "data_type.csr_matrix = mlem.contrib.scipy:ScipySparseMatrix", "data_writer.csr_matrix = mlem.contrib.scipy:ScipyWriter", "data_reader.csr_matrix = mlem.contrib.scipy:ScipyReader", ], From f3b13dd730d7e79079e9a64dc0ea7cf4946997db Mon Sep 17 00:00:00 2001 From: Alexander Guschin <1aguschin@gmail.com> Date: Mon, 16 Jan 2023 13:24:46 +0600 Subject: [PATCH 13/17] mute flake8 --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index de26369f..e8146186 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,7 @@ ignore = P1, # unindexed parameters in the str.format, see: B902, # Invalid first argument 'cls' used for instance method. 
B024, # ABCs without methods + B028, # Use f"{obj!r}" instead of f"'{obj}'" # https://pypi.org/project/flake8-string-format/ max_line_length = 79 max-complexity = 15 From c10e0bfc029e626ffd6930ec1750f72edb830e70 Mon Sep 17 00:00:00 2001 From: Alexander Guschin <1aguschin@gmail.com> Date: Mon, 16 Jan 2023 13:34:56 +0600 Subject: [PATCH 14/17] add missing entrypoints to setup.py --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index a75c868e..e5e56a9e 100644 --- a/setup.py +++ b/setup.py @@ -241,6 +241,8 @@ "data_type.csr_matrix = mlem.contrib.scipy:ScipySparseMatrix", "data_writer.csr_matrix = mlem.contrib.scipy:ScipyWriter", "data_reader.csr_matrix = mlem.contrib.scipy:ScipyReader", + "model_type.sklearn_transformer = mlem.contrib.sklearn:SklearnTransformer", + "serializer.csr_matrix = mlem.contrib.scipy:ScipySparseMatrixSerializer", ], "mlem.config": [ "core = mlem.config:MlemConfig", From e72e8e39105298ced29d8a4dc76a44911fa30cb3 Mon Sep 17 00:00:00 2001 From: Yury Date: Mon, 16 Jan 2023 10:34:48 +0100 Subject: [PATCH 15/17] remove old code --- mlem/contrib/scipy.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlem/contrib/scipy.py b/mlem/contrib/scipy.py index eb62c02b..4d247bd9 100644 --- a/mlem/contrib/scipy.py +++ b/mlem/contrib/scipy.py @@ -145,6 +145,3 @@ def deserialize(self, data_type, obj) -> sparse.csr_matrix: f"to sparse matrix of type: {data_type.type}" ) from e return mat - - # def get_model(self, data_type: DT, prefix: str = "") -> Union[Type[BaseModel], type]: - # pass From 104839efff7dc0f620fc0f03ba6744aec4b1efee Mon Sep 17 00:00:00 2001 From: Yury Date: Wed, 18 Jan 2023 17:31:41 +0100 Subject: [PATCH 16/17] fix failing test --- mlem/contrib/scipy.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mlem/contrib/scipy.py b/mlem/contrib/scipy.py index 4d247bd9..a14f53a5 100644 --- a/mlem/contrib/scipy.py +++ b/mlem/contrib/scipy.py @@ -65,6 +65,22 @@ def subtype(self, subshape: Tuple[Optional[int], ...]): max_items=subshape[0], ) + def check_shape(self, array, exc_type): + if self.shape is not None: + if len(array.shape) != len(self.shape): + raise exc_type( + f"given array is of rank: {len(array.shape)}, expected: {len(self.shape)}" + ) + + array_shape = tuple( + None if expected_dim is None else array_dim + for array_dim, expected_dim in zip(array.shape, self.shape) + ) + if tuple(array_shape) != self.shape: + raise exc_type( + f"given array is of shape: {array_shape}, expected: {self.shape}" + ) + class ScipyWriter(DataWriter[ScipySparseMatrix]): """ @@ -127,6 +143,7 @@ def serialize(self, data_type: ScipySparseMatrix, instance: spmatrix): f"given matrix is of dtype: {instance.dtype}, " f"expected: {data_type.dtype}" ) + data_type.check_shape(instance, SerializationError) coordinate_matrix = instance.tocoo() data = coordinate_matrix.data row = coordinate_matrix.row @@ -144,4 +161,5 @@ def deserialize(self, data_type, obj) -> sparse.csr_matrix: f"Given object {obj} could not be converted" f"to sparse matrix of type: {data_type.type}" ) from e + data_type.check_shape(mat, DeserializationError) return mat From ac3f23f599e0cd6a2bad1e72c5f11514cd338785 Mon Sep 17 00:00:00 2001 From: Yury Date: Wed, 18 Jan 2023 17:57:10 +0100 Subject: [PATCH 17/17] fix failing test --- mlem/contrib/numpy.py | 37 +++++++++++++++++++------------------ mlem/contrib/scipy.py | 21 +++------------------ 2 files changed, 22 insertions(+), 36 deletions(-) diff --git a/mlem/contrib/numpy.py b/mlem/contrib/numpy.py index 5861b11f..74c96976 
100644
--- a/mlem/contrib/numpy.py
+++ b/mlem/contrib/numpy.py
@@ -44,6 +44,23 @@ def np_type_from_string(string_repr) -> np.dtype:
         raise ValueError(f"Unknown numpy type {string_repr}") from e
 
 
+def check_shape(shape, array, exc_type):
+    if shape is not None:
+        if len(array.shape) != len(shape):
+            raise exc_type(
+                f"given array is of rank: {len(array.shape)}, expected: {len(shape)}"
+            )
+
+        array_shape = tuple(
+            None if expected_dim is None else array_dim
+            for array_dim, expected_dim in zip(array.shape, shape)
+        )
+        if tuple(array_shape) != shape:
+            raise exc_type(
+                f"given array is of shape: {array_shape}, expected: {shape}"
+            )
+
+
 class NumpyNumberType(
     WithDefaultSerializer, LibRequirementsMixin, DataType, DataHook
 ):
@@ -123,22 +140,6 @@ def subtype(self, subshape: Tuple[Optional[int], ...]):
             max_items=subshape[0],
         )
 
-    def check_shape(self, array, exc_type):
-        if self.shape is not None:
-            if len(array.shape) != len(self.shape):
-                raise exc_type(
-                    f"given array is of rank: {len(array.shape)}, expected: {len(self.shape)}"
-                )
-
-            array_shape = tuple(
-                None if expected_dim is None else array_dim
-                for array_dim, expected_dim in zip(array.shape, self.shape)
-            )
-            if tuple(array_shape) != self.shape:
-                raise exc_type(
-                    f"given array is of shape: {array_shape}, expected: {self.shape}"
-                )
-
     def get_writer(self, project: str = None, filename: str = None, **kwargs):
         return NumpyArrayWriter()
 
@@ -171,7 +172,7 @@ def deserialize(self, data_type, obj):
                 f"given object: {obj} could not be converted to array "
                 f"of type: {np_type_from_string(data_type.dtype)}"
             ) from e
-        data_type.check_shape(ret, DeserializationError)
+        check_shape(data_type.shape, ret, DeserializationError)
        return ret
 
     def serialize(self, data_type, instance: np.ndarray):
@@ -181,7 +182,7 @@ def serialize(self, data_type, instance: np.ndarray):
             raise SerializationError(
                 f"given array is of type: {instance.dtype}, expected: {exp_type}"
             )
-        data_type.check_shape(instance, SerializationError)
+        check_shape(data_type.shape, instance, SerializationError)
         return instance.tolist()
 
 
diff --git a/mlem/contrib/scipy.py b/mlem/contrib/scipy.py
index a14f53a5..75c95336 100644
--- a/mlem/contrib/scipy.py
+++ b/mlem/contrib/scipy.py
@@ -13,6 +13,7 @@
 from scipy.sparse import spmatrix
 
 from mlem.contrib.numpy import (
+    check_shape,
     np_type_from_string,
     python_type_from_np_string_repr,
 )
@@ -65,22 +66,6 @@ def subtype(self, subshape: Tuple[Optional[int], ...]):
             max_items=subshape[0],
         )
 
-    def check_shape(self, array, exc_type):
-        if self.shape is not None:
-            if len(array.shape) != len(self.shape):
-                raise exc_type(
-                    f"given array is of rank: {len(array.shape)}, expected: {len(self.shape)}"
-                )
-
-            array_shape = tuple(
-                None if expected_dim is None else array_dim
-                for array_dim, expected_dim in zip(array.shape, self.shape)
-            )
-            if tuple(array_shape) != self.shape:
-                raise exc_type(
-                    f"given array is of shape: {array_shape}, expected: {self.shape}"
-                )
-
 
 class ScipyWriter(DataWriter[ScipySparseMatrix]):
     """
@@ -143,7 +128,7 @@ def serialize(self, data_type: ScipySparseMatrix, instance: spmatrix):
             f"given matrix is of dtype: {instance.dtype}, "
             f"expected: {data_type.dtype}"
         )
-        data_type.check_shape(instance, SerializationError)
+        check_shape(data_type.shape, instance, SerializationError)
         coordinate_matrix = instance.tocoo()
         data = coordinate_matrix.data
         row = coordinate_matrix.row
@@ -161,5 +146,5 @@ def deserialize(self, data_type, obj) -> sparse.csr_matrix:
                 f"Given object {obj} could not be converted"
                 f"to sparse matrix of type: {data_type.type}"
             ) from e
-        data_type.check_shape(mat, DeserializationError)
+        check_shape(data_type.shape, mat, DeserializationError)
         return mat
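
Usage sketch — not part of the patch series itself, but an illustration of what it
enables, based directly on test_preprocess_transformer added above. It assumes a
local MLEM setup with scikit-learn and scipy installed; the save path "clf" and
the toy data are illustrative, everything else mirrors the tests.

# A fitted sklearn transformer can now be attached to a model as a
# preprocessor, with its sparse `transform` output handled by the new
# scipy data type.
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.linear_model import LogisticRegression

from mlem.api import apply, load_meta, save

inp_data = [[1, 2, 3], [3, 2, 1]]
out_data = [1, 2]

# ModelAnalyzer picks the fitted transformer up as SklearnTransformer and
# infers a `transform` signature returning a ScipySparseMatrix.
tf_idf = TfidfTransformer().fit(inp_data)

# Train the downstream model on the transformed (sparse) features.
clf = LogisticRegression().fit(tf_idf.transform(inp_data), out_data)

# Persist the model with the transformer attached as a preprocessor ...
save(clf, "clf", sample_data=inp_data, preprocess=tf_idf)

# ... so apply() runs transform + predict in the stored call order.
print(apply(load_meta("clf"), inp_data))  # expected: [1 2]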