diff --git a/wise_pizza/slicer.py b/wise_pizza/slicer.py index 9f53c08..de0b267 100644 --- a/wise_pizza/slicer.py +++ b/wise_pizza/slicer.py @@ -195,6 +195,7 @@ def fit( dims=dims, time_basis=self.time_basis, num_leaves=max_segments, + max_depth=max_depth, ) self.nonzeros = np.array(range(self.X.shape[1])) Xw = csc_matrix(diags(self.weights) @ self.X) diff --git a/wise_pizza/solve/tree.py b/wise_pizza/solve/tree.py index 0a1504e..e279722 100644 --- a/wise_pizza/solve/tree.py +++ b/wise_pizza/solve/tree.py @@ -14,7 +14,7 @@ def tree_solver( dim_df: pd.DataFrame, dims: List[str], time_basis: Optional[pd.DataFrame] = None, - max_depth: int = 3, + max_depth: Optional[int] = None, num_leaves: Optional[int] = None, ): if time_basis is None: @@ -27,7 +27,7 @@ def tree_solver( df["__avg"] = df["totals"] / df["weights"] df["__avg"] = df["__avg"].fillna(df["__avg"].mean()) - root = ModelNode(df=df, fitter=fitter, dims=dims) + root = ModelNode(df=df, fitter=fitter, dims=dims, max_depth=max_depth) build_tree(root=root, num_leaves=num_leaves, max_depth=max_depth) @@ -66,19 +66,23 @@ def __init__( df: pd.DataFrame, fitter: Fitter, dims: List[str], + max_depth: Optional[int] = None, dim_split: Optional[Dict[str, List]] = None, - depth: int = 0, ): self.df = df.copy() self.fitter = fitter self.dims = dims + self.max_depth = max_depth self._best_submodels = None self._error_improvement = float("-inf") self.children = None self.dim_split = dim_split or {} - self.depth = depth self.model = None + @property + def depth(self): + return len(self.dim_split) + @property def error(self): if self.model is None: @@ -94,9 +98,19 @@ def error(self): @property def error_improvement(self): + if self.max_depth is None: + self.max_depth = float("inf") if self._best_submodels is None: best_error = float("inf") - for dim in self.dims: + + if self.depth > self.max_depth: + raise ValueError("Max depth exceeded") + elif self.depth == self.max_depth: + iter_dims = list(self.dim_split.keys()) + else: + iter_dims = self.dims + + for dim in iter_dims: if len(self.df[dim].unique()) == 1: continue enc_map = target_encode(self.df, dim) @@ -122,14 +136,14 @@ def error_improvement(self): fitter=self.fitter, dims=self.dims, dim_split={**self.dim_split, **{dim: dim_values1}}, - depth=self.depth + 1, + max_depth=self.max_depth, ) right_candidate = ModelNode( df=right, fitter=self.fitter, dims=self.dims, dim_split={**self.dim_split, **{dim: dim_values2}}, - depth=self.depth + 1, + max_depth=self.max_depth, ) err = left_candidate.error + right_candidate.error @@ -141,13 +155,6 @@ def error_improvement(self): return self._error_improvement -def mod_improvement(improvement: float, depth: int, max_depth: int) -> float: - if depth < max_depth: - return improvement - else: - return float("-inf") - - def get_best_subtree_result( node: ModelNode, max_depth: Optional[int] = 1000 ) -> ModelNode: @@ -156,8 +163,8 @@ def get_best_subtree_result( else: node1 = get_best_subtree_result(node.children[0]) node2 = get_best_subtree_result(node.children[1]) - improvement1 = mod_improvement(node1.error_improvement, node1.depth, max_depth) - improvement2 = mod_improvement(node2.error_improvement, node2.depth, max_depth) + improvement1 = node1.error_improvement + improvement2 = node2.error_improvement if improvement1 > improvement2: return node1 else: