diff --git a/tests/test_fit.py b/tests/test_fit.py index 1ba9802..4bb8eaf 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -139,7 +139,7 @@ def test_time_series_tree_solver(fit_sizes: bool): sf = explain_timeseries( df=data.data, dims=data.dimensions, - max_segments=7, + num_segments=7, max_depth=2, total_name=data.segment_total, size_name=data.segment_size, diff --git a/wise_pizza/solve/tree.py b/wise_pizza/solve/tree.py index 639fb21..a2eaf47 100644 --- a/wise_pizza/solve/tree.py +++ b/wise_pizza/solve/tree.py @@ -34,6 +34,8 @@ def tree_solver( """ df = dim_df.copy().reset_index(drop=True) + if "total_adjustment" not in df.columns: + df["total_adjustment"] = 0.0 df["totals"] -= df["total_adjustment"] df["__avg"] = df["totals"] / df["weights"] df["__avg"] = df["__avg"].fillna(df["__avg"].mean())