diff --git a/wise_pizza/solve/tree.py b/wise_pizza/solve/tree.py index e279722..22fe2b0 100644 --- a/wise_pizza/solve/tree.py +++ b/wise_pizza/solve/tree.py @@ -78,6 +78,8 @@ def __init__( self.children = None self.dim_split = dim_split or {} self.model = None + # For dimension splitting candidates, hardwired for now + self.num_bins = 10 @property def depth(self): @@ -118,7 +120,7 @@ def error_improvement(self): if np.any(np.isnan(self.df[dim + "_encoded"])): # pragma: no cover raise ValueError("NaNs in encoded values") # Get split candidates for brute force search - deciles = np.array([q / 10.0 for q in range(1, 10)]) + deciles = np.array([q / self.num_bins for q in range(1, self.num_bins)]) splits = weighted_quantiles( self.df[dim + "_encoded"], deciles, self.df["weights"]