diff --git a/tests/ops/test_op_utils.py b/tests/ops/test_op_utils.py index 0174f9a..9893539 100644 --- a/tests/ops/test_op_utils.py +++ b/tests/ops/test_op_utils.py @@ -29,12 +29,12 @@ def test_get_aggregation_ops(): assert issubclass(instance, AggregationOpBase) assert get_aggregation_ops() == [ CountAggregationOp, + ExistsAggregationOp, SumAggregationOp, AvgAggregationOp, MaxAggregationOp, MinAggregationOp, MajorityAggregationOp, - ExistsAggregationOp, FirstAggregationOp, LastAggregationOp, ] diff --git a/tests/test_denormalize_dataframes.py b/tests/test_denormalize_dataframes.py index 084e389..65be635 100755 --- a/tests/test_denormalize_dataframes.py +++ b/tests/test_denormalize_dataframes.py @@ -53,7 +53,7 @@ def test_denormalize_two_tables(): flat = flat.set_index("id").sort_values("id") assert flat["product_id"].tolist() == [1, 2, 3, 1, 2] assert flat["session_id"].tolist() == [1, 1, 2, 2, 2] - assert flat["products.price"].tolist() == [10, 20, 30, 10, 20] + assert flat["products.price"].tolist() == [10.5, 20.25, 30.01, 10.5, 20.25] def test_denormalize_two_tables_change(): @@ -125,7 +125,7 @@ def test_denormalize_three_tables(): flat = flat.set_index("id").sort_values("id") assert flat["product_id"].tolist() == [1, 2, 3, 1, 2] assert flat["session_id"].tolist() == [1, 1, 2, 2, 2] - assert flat["products.price"].tolist() == [10, 20, 30, 10, 20] + assert flat["products.price"].tolist() == [10.5, 20.25, 30.01, 10.5, 20.25] assert flat["sessions.customer_id"].tolist() == [0, 0, 0, 0, 0] @@ -162,7 +162,7 @@ def test_denormalize_four_tables(): assert flat["id"].is_unique assert sorted(flat["id"].tolist()) == [1, 2, 3, 4, 5] flat = flat.set_index("id").sort_values("id") - assert flat["products.price"].tolist() == [10, 20, 30, 10, 20] + assert flat["products.price"].tolist() == [10.5, 20.25, 30.01, 10.5, 20.25] assert flat["sessions.customer_id"].tolist() == [0, 0, 0, 0, 0] assert flat["sessions.customers.age"].tolist() == [33, 33, 33, 33, 33] assert flat["sessions.customers.région_id"].tolist() == ["United States"] * 5 diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 17ee818..3a3d0d5 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -69,13 +69,13 @@ def test_set_primary_key(single_metadata): single_metadata.set_primary_key("column_4") -def test_set_time_index(single_metadata): - single_metadata.reset_time_index() +def test_reset_time_key(single_metadata): + single_metadata.reset_time_key() assert single_metadata.time_index is None - single_metadata.set_time_index("purchase_date") + single_metadata.set_time_key("purchase_date") match = "Time index must be of type Datetime" with pytest.raises(ValueError, match=match): - single_metadata.set_time_index("card_type") + single_metadata.set_time_key("card_type") def test_from_dataframe_single(): @@ -174,9 +174,9 @@ def test_set_time_index_multi(multitable_metadata): }, ) multitable_metadata.set_primary_key("products", "column_9") - multitable_metadata.set_time_index("products", "column_10") + multitable_metadata.set_time_key("products", "column_10") with pytest.raises(ValueError): - multitable_metadata.set_time_index("products", "column_9") + multitable_metadata.set_time_key("products", "column_9") def test_set_type_multi(multitable_metadata): diff --git a/trane/core/problem_generator.py b/trane/core/problem_generator.py index 5bdd7b1..f49812a 100644 --- a/trane/core/problem_generator.py +++ b/trane/core/problem_generator.py @@ -28,7 +28,7 @@ class ProblemGenerator: def __init__( self, metadata, - window_size, + window_size=None, target_table: str = None, entity_columns: List[str] = None, ): diff --git a/trane/metadata/metadata.py b/trane/metadata/metadata.py index 0a842fa..4447e98 100644 --- a/trane/metadata/metadata.py +++ b/trane/metadata/metadata.py @@ -11,10 +11,16 @@ def __init__(self): def set_primary_key(self): raise NotImplementedError - def set_time_index(self): + def set_time_key(self): raise NotImplementedError - def get_ml_type(self): + def get_type(self): + raise NotImplementedError + + def set_type(self): + raise NotImplementedError + + def set_types(self): raise NotImplementedError @@ -32,7 +38,7 @@ def __init__( self.set_primary_key(primary_key) self.time_index = None if time_index: - self.set_time_index(time_index) + self.set_time_key(time_index) self.original_multi_table_metadata = original_multi_table_metadata def set_primary_key(self, primary_key): @@ -51,18 +57,18 @@ def reset_primary_key(self): self.ml_types[self.primary_key].remove_tag("primary_key") self.primary_key = None - def set_time_index(self, time_index): + def set_time_key(self, time_index): if time_index not in self.ml_types: raise ValueError("Time index does not exist in ml_types") elif self.time_index and self.ml_types[self.time_index].has_tag("time_index"): self.ml_types[self.time_index].remove_tag("time_index") - if time_index and not isinstance(self.get_ml_type(time_index), Datetime): + if time_index and not isinstance(self.get_type(time_index), Datetime): raise ValueError("Time index must be of type Datetime") self.time_index = time_index self.ml_types[time_index].tags.add("time_index") - def reset_time_index(self): + def reset_time_key(self): if self.time_index: self.ml_types[self.time_index].remove_tag("time_index") self.time_index = None @@ -72,7 +78,10 @@ def set_type(self, column, ml_type, tags=set()): ml_type.tags = tags self.ml_types[column] = ml_type - def get_ml_type(self, column): + def set_types(self, ml_types: dict): + self.ml_types = _parse_ml_types(ml_types, type_=self.get_metadata_type()) + + def get_type(self, column): return self.ml_types[column] @staticmethod @@ -110,11 +119,11 @@ def get_metadata_type(): def set_time_indices(self, time_indices): for table, time_index_column in time_indices.items(): - self.set_time_index(table, time_index_column) + self.set_time_key(table, time_index_column) - def set_time_index(self, table, column): + def set_time_key(self, table, column): self.check_if_table_exists(table) - if not isinstance(self.get_ml_type(table, column), Datetime): + if not isinstance(self.get_type(table, column), Datetime): raise ValueError("Time index must be of type Datetime") self.time_indices[table] = column self.ml_types[table][column] = Datetime() @@ -145,7 +154,7 @@ def add_table(self, table, ml_types): type_="single", ) - def get_ml_type(self, table, column): + def get_type(self, table, column): self.check_if_table_exists(table) return self.ml_types[table][column] @@ -153,6 +162,9 @@ def set_type(self, table, column, ml_type): ml_type = check_ml_type(ml_type) self.ml_types[table][column] = ml_type + def set_types(self, table, ml_types): + self.ml_types[table] = _parse_ml_types(ml_types, type_="single") + def add_relationships(self, relationships): if not isinstance(relationships, list): relationships = [relationships] @@ -161,6 +173,14 @@ def add_relationships(self, relationships): if rel not in self.relationships: self.relationships.append(rel) + parent_table_name, parent_key, _, _ = rel + if ( + "primary_key" + not in self.ml_types[parent_table_name][parent_key].get_tags() + ): + self.ml_types[parent_table_name][parent_key].tags.add("primary_key") + self.primary_keys[parent_table_name] = parent_key + def check_if_table_exists(self, table): if table not in self.ml_types: raise ValueError(f"Table: {table} does not exist")