diff --git a/pyproject.toml b/pyproject.toml index 8d88732..b0473ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,8 @@ dependencies = [ "scipy >= 1.10.0", "tqdm >= 4.65.0", "importlib_resources >= 6.0.0", - "pyarrow >= 11.0.0" + "pyarrow >= 14.0.1", + "humanize >= 4.9.0" ] [project.urls] diff --git a/tests/minimal_requirements.txt b/tests/minimal_requirements.txt index 7d35850..3838010 100644 --- a/tests/minimal_requirements.txt +++ b/tests/minimal_requirements.txt @@ -3,4 +3,4 @@ pandas==2.0.0 scipy==1.10.0 tqdm==4.65.0 importlib_resources==6.0.0 -pyarrow==11.0.0 +pyarrow==14.0.1 diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 9321ee4..8c19dcc 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -238,6 +238,40 @@ def test_add_relationships_new_table(multitable_metadata): ) +def test_sample_data_single(): + dataframe, ml_types, _, primary_key, time_index = generate_mock_data( + tables=["products"], + ) + dataframe = dataframe["products"] + primary_key = "key does not exist" + match = "does not exist in sample data's columns" + with pytest.raises(ValueError, match=match): + SingleTableMetadata( + ml_types=ml_types, + primary_key=primary_key, + time_index=time_index, + sample_data=dataframe, + ) + + +def test_sample_data_multi(): + dataframes, ml_types, relationships, primary_keys, time_indices = ( + generate_mock_data( + tables=["products", "logs"], + ) + ) + primary_keys["products"] = "key does not exist" + match = "does not exist in sample data's table" + with pytest.raises(ValueError, match=match): + MultiTableMetadata( + ml_types=ml_types, + primary_keys=primary_keys, + time_indices=time_indices, + relationships=relationships, + sample_data=dataframes, + ) + + def verify_ml_types(metadata, expected_ml_types): if metadata.get_metadata_type() == "single": for key in expected_ml_types: diff --git a/tests/test_problem_generator.py b/tests/test_problem_generator.py index 4e07856..7d27929 100644 --- a/tests/test_problem_generator.py +++ b/tests/test_problem_generator.py @@ -81,6 +81,8 @@ def test_problem_generator_multi(tables, target_table): num_columns = len(problems[0].metadata.ml_types.keys()) print(f"generated {len(problems)} problems from {num_columns} columns") for p in tqdm(problems): + string_repr = p.__repr__() + assert "2 days" in string_repr if p.has_parameters_set() is True: labels = p.create_target_values(dataframes) check_problem_type(labels, p.get_problem_type()) diff --git a/trane/core/problem.py b/trane/core/problem.py index dff7d07..da40570 100644 --- a/trane/core/problem.py +++ b/trane/core/problem.py @@ -1,3 +1,4 @@ +import humanize import pandas as pd from trane.core.utils import calculate_target_values @@ -188,8 +189,10 @@ def __str__(self): description += filter_op.generate_description() if self.window_size: - description += " " + "in next {} days".format( - self.window_size, + window_size = pd.to_timedelta(self.window_size) + human_readble = humanize.naturaldelta(window_size) + description += " " + "in next {}".format( + human_readble, ) return description diff --git a/trane/metadata/__init__.py b/trane/metadata/__init__.py index fe335db..b2a0ae3 100644 --- a/trane/metadata/__init__.py +++ b/trane/metadata/__init__.py @@ -1 +1 @@ -from trane.metadata.metadata import * +from trane.metadata.metadata import SingleTableMetadata, MultiTableMetadata diff --git a/trane/metadata/metadata.py b/trane/metadata/metadata.py index 1c924d4..1a29950 100644 --- a/trane/metadata/metadata.py +++ b/trane/metadata/metadata.py @@ -33,8 +33,10 @@ def __init__( primary_key: str = None, time_index: str = None, original_multi_table_metadata=None, + sample_data: pd.DataFrame = None, ): self.ml_types = _parse_ml_types(ml_types, type_=self.get_metadata_type()) + self.sample_data = sample_data self.primary_key = None if primary_key: self.set_primary_key(primary_key) @@ -53,7 +55,14 @@ def __repr__(self): return result def set_primary_key(self, primary_key): - if primary_key not in self.ml_types: + if ( + self.sample_data is not None + and primary_key not in self.sample_data.columns.tolist() + ): + raise ValueError( + f"Index {primary_key} does not exist in sample data's columns", + ) + elif primary_key not in self.ml_types: raise ValueError("Index does not exist in ml_types") elif ( self.primary_key @@ -109,11 +118,13 @@ class MultiTableMetadata(BaseMetadata): def __init__( self, ml_types: dict = None, - primary_keys=None, - time_indices=None, + primary_keys: dict = None, + time_indices: dict = None, relationships: list = None, + sample_data: pd.DataFrame = None, ): self.ml_types = _parse_ml_types(ml_types, type_=self.get_metadata_type()) + self.sample_data = sample_data self.primary_keys = {} if primary_keys: self.set_primary_keys(primary_keys) @@ -200,7 +211,14 @@ def check_if_table_exists(self, table): raise ValueError(f"Table: {table} does not exist") def check_if_column_exists(self, table, column): - if column not in self.ml_types[table]: + if ( + self.sample_data is not None + and column not in self.sample_data[table].columns + ): + raise ValueError( + f"Column: {column} does not exist in sample data's table: {table}", + ) + elif column not in self.ml_types[table]: raise ValueError(f"Column: {column} does not exist in Table: {table}") def remove_table(self, table):