Skip to content

Commit

Permalink
Add sample data parameter to Metadata, add humanize for readable wind…
Browse files Browse the repository at this point in the history
…ow size (#178)

* added humanize

* added humanize test
  • Loading branch information
gsheni authored Dec 26, 2023
1 parent 61cefb0 commit 41539cb
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 9 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ dependencies = [
"scipy >= 1.10.0",
"tqdm >= 4.65.0",
"importlib_resources >= 6.0.0",
"pyarrow >= 11.0.0"
"pyarrow >= 14.0.1",
"humanize >= 4.9.0"
]

[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion tests/minimal_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ pandas==2.0.0
scipy==1.10.0
tqdm==4.65.0
importlib_resources==6.0.0
pyarrow==11.0.0
pyarrow==14.0.1
34 changes: 34 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,40 @@ def test_add_relationships_new_table(multitable_metadata):
)


def test_sample_data_single():
    """Check that SingleTableMetadata rejects a primary key missing from sample data.

    The primary key passed in is deliberately a name that is not a column of
    the supplied sample DataFrame, so construction must raise ``ValueError``
    with the "sample data's columns" message.
    """
    # The mock primary key is discarded: the test substitutes a bogus one.
    mock_frames, ml_types, _, _, time_index = generate_mock_data(
        tables=["products"],
    )
    products_frame = mock_frames["products"]
    expected_error = "does not exist in sample data's columns"
    with pytest.raises(ValueError, match=expected_error):
        SingleTableMetadata(
            ml_types=ml_types,
            primary_key="key does not exist",
            time_index=time_index,
            sample_data=products_frame,
        )


def test_sample_data_multi():
    """Check that MultiTableMetadata rejects a primary key missing from sample data.

    The "products" primary key is overwritten with a name that is not a column
    of that table's sample data, so construction must raise ``ValueError``
    with the "sample data's table" message.
    """
    mock = generate_mock_data(tables=["products", "logs"])
    dataframes, ml_types, relationships, primary_keys, time_indices = mock

    # Corrupt one table's primary key so validation against sample data fails.
    primary_keys["products"] = "key does not exist"

    expected_error = "does not exist in sample data's table"
    with pytest.raises(ValueError, match=expected_error):
        MultiTableMetadata(
            ml_types=ml_types,
            primary_keys=primary_keys,
            time_indices=time_indices,
            relationships=relationships,
            sample_data=dataframes,
        )


def verify_ml_types(metadata, expected_ml_types):
if metadata.get_metadata_type() == "single":
for key in expected_ml_types:
Expand Down
2 changes: 2 additions & 0 deletions tests/test_problem_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def test_problem_generator_multi(tables, target_table):
num_columns = len(problems[0].metadata.ml_types.keys())
print(f"generated {len(problems)} problems from {num_columns} columns")
for p in tqdm(problems):
string_repr = p.__repr__()
assert "2 days" in string_repr
if p.has_parameters_set() is True:
labels = p.create_target_values(dataframes)
check_problem_type(labels, p.get_problem_type())
Expand Down
7 changes: 5 additions & 2 deletions trane/core/problem.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import humanize
import pandas as pd

from trane.core.utils import calculate_target_values
Expand Down Expand Up @@ -188,8 +189,10 @@ def __str__(self):
description += filter_op.generate_description()

if self.window_size:
description += " " + "in next {} days".format(
self.window_size,
window_size = pd.to_timedelta(self.window_size)
human_readble = humanize.naturaldelta(window_size)
description += " " + "in next {}".format(
human_readble,
)
return description

Expand Down
2 changes: 1 addition & 1 deletion trane/metadata/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from trane.metadata.metadata import *
from trane.metadata.metadata import SingleTableMetadata, MultiTableMetadata
26 changes: 22 additions & 4 deletions trane/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ def __init__(
primary_key: str = None,
time_index: str = None,
original_multi_table_metadata=None,
sample_data: pd.DataFrame = None,
):
self.ml_types = _parse_ml_types(ml_types, type_=self.get_metadata_type())
self.sample_data = sample_data
self.primary_key = None
if primary_key:
self.set_primary_key(primary_key)
Expand All @@ -53,7 +55,14 @@ def __repr__(self):
return result

def set_primary_key(self, primary_key):
if primary_key not in self.ml_types:
if (
self.sample_data is not None
and primary_key not in self.sample_data.columns.tolist()
):
raise ValueError(
f"Index {primary_key} does not exist in sample data's columns",
)
elif primary_key not in self.ml_types:
raise ValueError("Index does not exist in ml_types")
elif (
self.primary_key
Expand Down Expand Up @@ -109,11 +118,13 @@ class MultiTableMetadata(BaseMetadata):
def __init__(
self,
ml_types: dict = None,
primary_keys=None,
time_indices=None,
primary_keys: dict = None,
time_indices: dict = None,
relationships: list = None,
sample_data: pd.DataFrame = None,
):
self.ml_types = _parse_ml_types(ml_types, type_=self.get_metadata_type())
self.sample_data = sample_data
self.primary_keys = {}
if primary_keys:
self.set_primary_keys(primary_keys)
Expand Down Expand Up @@ -200,7 +211,14 @@ def check_if_table_exists(self, table):
raise ValueError(f"Table: {table} does not exist")

def check_if_column_exists(self, table, column):
if column not in self.ml_types[table]:
if (
self.sample_data is not None
and column not in self.sample_data[table].columns
):
raise ValueError(
f"Column: {column} does not exist in sample data's table: {table}",
)
elif column not in self.ml_types[table]:
raise ValueError(f"Column: {column} does not exist in Table: {table}")

def remove_table(self, table):
Expand Down

0 comments on commit 41539cb

Please sign in to comment.