Skip to content

Commit

Permalink
Merge pull request #50 from transferwise/tree_depth
Browse files Browse the repository at this point in the history
Correctly handle max_depth in tree solver
  • Loading branch information
AlxdrPolyakov authored May 13, 2024
2 parents 53d8cde + 023ef37 commit 2d9dcd9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 16 deletions.
1 change: 1 addition & 0 deletions wise_pizza/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def fit(
dims=dims,
time_basis=self.time_basis,
num_leaves=max_segments,
max_depth=max_depth,
)
self.nonzeros = np.array(range(self.X.shape[1]))
Xw = csc_matrix(diags(self.weights) @ self.X)
Expand Down
39 changes: 23 additions & 16 deletions wise_pizza/solve/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def tree_solver(
dim_df: pd.DataFrame,
dims: List[str],
time_basis: Optional[pd.DataFrame] = None,
max_depth: int = 3,
max_depth: Optional[int] = None,
num_leaves: Optional[int] = None,
):
if time_basis is None:
Expand All @@ -27,7 +27,7 @@ def tree_solver(
df["__avg"] = df["totals"] / df["weights"]
df["__avg"] = df["__avg"].fillna(df["__avg"].mean())

root = ModelNode(df=df, fitter=fitter, dims=dims)
root = ModelNode(df=df, fitter=fitter, dims=dims, max_depth=max_depth)

build_tree(root=root, num_leaves=num_leaves, max_depth=max_depth)

Expand Down Expand Up @@ -66,19 +66,23 @@ def __init__(
df: pd.DataFrame,
fitter: Fitter,
dims: List[str],
max_depth: Optional[int] = None,
dim_split: Optional[Dict[str, List]] = None,
depth: int = 0,
):
self.df = df.copy()
self.fitter = fitter
self.dims = dims
self.max_depth = max_depth
self._best_submodels = None
self._error_improvement = float("-inf")
self.children = None
self.dim_split = dim_split or {}
self.depth = depth
self.model = None

@property
def depth(self):
    """Return the number of dimensions recorded in ``self.dim_split``.

    The depth is derived from the split mapping rather than stored
    separately, so it always equals the count of keys in ``dim_split``.
    """
    split_dims = self.dim_split
    return len(split_dims)

@property
def error(self):
if self.model is None:
Expand All @@ -94,9 +98,19 @@ def error(self):

@property
def error_improvement(self):
if self.max_depth is None:
self.max_depth = float("inf")
if self._best_submodels is None:
best_error = float("inf")
for dim in self.dims:

if self.depth > self.max_depth:
raise ValueError("Max depth exceeded")
elif self.depth == self.max_depth:
iter_dims = list(self.dim_split.keys())
else:
iter_dims = self.dims

for dim in iter_dims:
if len(self.df[dim].unique()) == 1:
continue
enc_map = target_encode(self.df, dim)
Expand All @@ -122,14 +136,14 @@ def error_improvement(self):
fitter=self.fitter,
dims=self.dims,
dim_split={**self.dim_split, **{dim: dim_values1}},
depth=self.depth + 1,
max_depth=self.max_depth,
)
right_candidate = ModelNode(
df=right,
fitter=self.fitter,
dims=self.dims,
dim_split={**self.dim_split, **{dim: dim_values2}},
depth=self.depth + 1,
max_depth=self.max_depth,
)

err = left_candidate.error + right_candidate.error
Expand All @@ -141,13 +155,6 @@ def error_improvement(self):
return self._error_improvement


def mod_improvement(
    improvement: float, depth: int, max_depth: Optional[int]
) -> float:
    """Mask a node's improvement once it has reached the depth limit.

    Args:
        improvement: Error improvement reported for the node.
        depth: Current depth of the node.
        max_depth: Depth limit; ``None`` means no limit is applied.

    Returns:
        ``improvement`` unchanged while ``depth`` is below ``max_depth``
        (or when there is no limit), otherwise ``float("-inf")`` so the
        node loses any ``>`` comparison against another improvement.
    """
    # None is treated as "unbounded" instead of raising TypeError on
    # the `depth < None` comparison.
    if max_depth is None or depth < max_depth:
        return improvement
    return float("-inf")


def get_best_subtree_result(
node: ModelNode, max_depth: Optional[int] = 1000
) -> ModelNode:
Expand All @@ -156,8 +163,8 @@ def get_best_subtree_result(
else:
node1 = get_best_subtree_result(node.children[0])
node2 = get_best_subtree_result(node.children[1])
improvement1 = mod_improvement(node1.error_improvement, node1.depth, max_depth)
improvement2 = mod_improvement(node2.error_improvement, node2.depth, max_depth)
improvement1 = node1.error_improvement
improvement2 = node2.error_improvement
if improvement1 > improvement2:
return node1
else:
Expand Down

0 comments on commit 2d9dcd9

Please sign in to comment.