Skip to content

Commit

Permalink
wip
Browse files (browse the repository at this point in the history)
  • Loading branch information
gsheni committed Sep 19, 2023
1 parent 3163eb7 commit 0a29bd4
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 22 deletions.
2 changes: 1 addition & 1 deletion tests/ops/test_op_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def test_get_aggregation_ops():
assert issubclass(instance, AggregationOpBase)
assert get_aggregation_ops() == [
CountAggregationOp,
ExistsAggregationOp,
SumAggregationOp,
AvgAggregationOp,
MaxAggregationOp,
MinAggregationOp,
MajorityAggregationOp,
ExistsAggregationOp,
FirstAggregationOp,
LastAggregationOp,
]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_denormalize_dataframes.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_denormalize_two_tables():
flat = flat.set_index("id").sort_values("id")
assert flat["product_id"].tolist() == [1, 2, 3, 1, 2]
assert flat["session_id"].tolist() == [1, 1, 2, 2, 2]
assert flat["products.price"].tolist() == [10, 20, 30, 10, 20]
assert flat["products.price"].tolist() == [10.5, 20.25, 30.01, 10.5, 20.25]


def test_denormalize_two_tables_change():
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_denormalize_three_tables():
flat = flat.set_index("id").sort_values("id")
assert flat["product_id"].tolist() == [1, 2, 3, 1, 2]
assert flat["session_id"].tolist() == [1, 1, 2, 2, 2]
assert flat["products.price"].tolist() == [10, 20, 30, 10, 20]
assert flat["products.price"].tolist() == [10.5, 20.25, 30.01, 10.5, 20.25]
assert flat["sessions.customer_id"].tolist() == [0, 0, 0, 0, 0]


Expand Down Expand Up @@ -162,7 +162,7 @@ def test_denormalize_four_tables():
assert flat["id"].is_unique
assert sorted(flat["id"].tolist()) == [1, 2, 3, 4, 5]
flat = flat.set_index("id").sort_values("id")
assert flat["products.price"].tolist() == [10, 20, 30, 10, 20]
assert flat["products.price"].tolist() == [10.5, 20.25, 30.01, 10.5, 20.25]
assert flat["sessions.customer_id"].tolist() == [0, 0, 0, 0, 0]
assert flat["sessions.customers.age"].tolist() == [33, 33, 33, 33, 33]
assert flat["sessions.customers.région_id"].tolist() == ["United States"] * 5
Expand Down
12 changes: 6 additions & 6 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ def test_set_primary_key(single_metadata):
single_metadata.set_primary_key("column_4")


def test_set_time_index(single_metadata):
single_metadata.reset_time_index()
def test_reset_time_key(single_metadata):
single_metadata.reset_time_key()
assert single_metadata.time_index is None
single_metadata.set_time_index("purchase_date")
single_metadata.set_time_key("purchase_date")
match = "Time index must be of type Datetime"
with pytest.raises(ValueError, match=match):
single_metadata.set_time_index("card_type")
single_metadata.set_time_key("card_type")


def test_from_dataframe_single():
Expand Down Expand Up @@ -174,9 +174,9 @@ def test_set_time_index_multi(multitable_metadata):
},
)
multitable_metadata.set_primary_key("products", "column_9")
multitable_metadata.set_time_index("products", "column_10")
multitable_metadata.set_time_key("products", "column_10")
with pytest.raises(ValueError):
multitable_metadata.set_time_index("products", "column_9")
multitable_metadata.set_time_key("products", "column_9")


def test_set_type_multi(multitable_metadata):
Expand Down
2 changes: 1 addition & 1 deletion trane/core/problem_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ProblemGenerator:
def __init__(
self,
metadata,
window_size,
window_size=None,
target_table: str = None,
entity_columns: List[str] = None,
):
Expand Down
42 changes: 31 additions & 11 deletions trane/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@ def __init__(self):
def set_primary_key(self):
raise NotImplementedError

def set_time_index(self):
def set_time_key(self):
raise NotImplementedError

def get_ml_type(self):
def get_type(self):
raise NotImplementedError

def set_type(self):
raise NotImplementedError

def set_types(self):
raise NotImplementedError


Expand All @@ -32,7 +38,7 @@ def __init__(
self.set_primary_key(primary_key)
self.time_index = None
if time_index:
self.set_time_index(time_index)
self.set_time_key(time_index)
self.original_multi_table_metadata = original_multi_table_metadata

def set_primary_key(self, primary_key):
Expand All @@ -51,18 +57,18 @@ def reset_primary_key(self):
self.ml_types[self.primary_key].remove_tag("primary_key")
self.primary_key = None

def set_time_index(self, time_index):
def set_time_key(self, time_index):
if time_index not in self.ml_types:
raise ValueError("Time index does not exist in ml_types")
elif self.time_index and self.ml_types[self.time_index].has_tag("time_index"):
self.ml_types[self.time_index].remove_tag("time_index")

if time_index and not isinstance(self.get_ml_type(time_index), Datetime):
if time_index and not isinstance(self.get_type(time_index), Datetime):
raise ValueError("Time index must be of type Datetime")
self.time_index = time_index
self.ml_types[time_index].tags.add("time_index")

def reset_time_index(self):
def reset_time_key(self):
if self.time_index:
self.ml_types[self.time_index].remove_tag("time_index")
self.time_index = None
Expand All @@ -72,7 +78,10 @@ def set_type(self, column, ml_type, tags=set()):
ml_type.tags = tags
self.ml_types[column] = ml_type

def get_ml_type(self, column):
def set_types(self, ml_types: dict):
self.ml_types = _parse_ml_types(ml_types, type_=self.get_metadata_type())

[Codecov / codecov/patch annotation] Check warning on line 82 in trane/metadata/metadata.py (trane/metadata/metadata.py#L82): added line #L82 was not covered by tests.

def get_type(self, column):
return self.ml_types[column]

@staticmethod
Expand Down Expand Up @@ -110,11 +119,11 @@ def get_metadata_type():

def set_time_indices(self, time_indices):
for table, time_index_column in time_indices.items():
self.set_time_index(table, time_index_column)
self.set_time_key(table, time_index_column)

def set_time_index(self, table, column):
def set_time_key(self, table, column):
self.check_if_table_exists(table)
if not isinstance(self.get_ml_type(table, column), Datetime):
if not isinstance(self.get_type(table, column), Datetime):
raise ValueError("Time index must be of type Datetime")
self.time_indices[table] = column
self.ml_types[table][column] = Datetime()
Expand Down Expand Up @@ -145,14 +154,17 @@ def add_table(self, table, ml_types):
type_="single",
)

def get_ml_type(self, table, column):
def get_type(self, table, column):
self.check_if_table_exists(table)
return self.ml_types[table][column]

def set_type(self, table, column, ml_type):
ml_type = check_ml_type(ml_type)
self.ml_types[table][column] = ml_type

def set_types(self, table, ml_types):
self.ml_types[table] = _parse_ml_types(ml_types, type_="single")

[Codecov / codecov/patch annotation] Check warning on line 166 in trane/metadata/metadata.py (trane/metadata/metadata.py#L166): added line #L166 was not covered by tests.

def add_relationships(self, relationships):
if not isinstance(relationships, list):
relationships = [relationships]
Expand All @@ -161,6 +173,14 @@ def add_relationships(self, relationships):
if rel not in self.relationships:
self.relationships.append(rel)

parent_table_name, parent_key, _, _ = rel
if (
"primary_key"
not in self.ml_types[parent_table_name][parent_key].get_tags()
):
self.ml_types[parent_table_name][parent_key].tags.add("primary_key")
self.primary_keys[parent_table_name] = parent_key

def check_if_table_exists(self, table):
if table not in self.ml_types:
raise ValueError(f"Table: {table} does not exist")
Expand Down

0 comments on commit 0a29bd4

Please sign in to comment.