Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in tree solver that was causing wrong fits #49

Merged
merged 1 commit into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions wise_pizza/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def fit(
@param min_segments: Minimum number of segments to find
@param max_segments: Maximum number of segments to find, defaults to min_segments
@param min_depth: Minimum number of dimensions to constrain in segment definition
@param max_depth: Maximum number of dimension to constrain in segment definition
@param max_depth: Maximum number of dimensions to constrain in segment definition; also the max depth of the tree in the tree solver
@param solver: Valid values are "lasso" (default), "tree" (for non-overlapping segments), "omp", or "lp"
@param verbose: If set to a truish value, lots of debug info is printed to console
@param force_dim: To add dim
Expand Down Expand Up @@ -287,7 +287,8 @@ def fit(
# assert wgt == wgts[i]
s["orig_i"] = i
s["coef"] = self.reg.coef_[i]
s["impact"] = np.abs(s["coef"]) * (np.abs(this_vec) * self.weights).sum()
# TODO: does dropping the abs of coef here break time series?
s["impact"] = s["coef"] * (np.abs(this_vec) * self.weights).sum()
s["avg_impact"] = s["impact"] / sum(self.weights)
s["total"] = (self.totals * dummy).sum()
s["seg_size"] = wgt
Expand Down
6 changes: 4 additions & 2 deletions wise_pizza/solve/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ def fit_predict(self, X, y, sample_weight=None):
return self.predict(X)

def error(self, X, y, sample_weight=None):
# Error is chosen so that it's minimized by the weighted mean of y
err = y - self.predict(X)
errsq = err**2
if sample_weight is not None:
err *= sample_weight
return np.nansum(err**2)
errsq *= sample_weight
return np.nansum(errsq)


class AverageFitter(Fitter):
Expand Down
4 changes: 2 additions & 2 deletions wise_pizza/solve/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(
dim_split: Optional[Dict[str, List]] = None,
depth: int = 0,
):
self.df = df
self.df = df.copy()
self.fitter = fitter
self.dims = dims
self._best_submodels = None
Expand All @@ -85,7 +85,7 @@ def error(self):
self.model = copy.deepcopy(self.fitter)
self.model.fit(
X=self.df[self.dims],
y=self.df["totals"],
y=self.df["__avg"],
sample_weight=self.df["weights"],
)
return self.model.error(
Expand Down
Loading