Skip to content

Commit

Permalink
cluster values fix, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander Polyakov committed May 13, 2024
1 parent 2d9dcd9 commit 2f95f4f
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 36 deletions.
34 changes: 17 additions & 17 deletions notebooks/Finding interesting segments.ipynb

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
(explain_changes_in_average, explain_changes_in_totals), # function
(0.0, 90.0), # nan_percent
(0.0, 90.0), # size_one_percent
(True, False), # cluster_values
]
# possible variants for explain methods
deltas_test_cases = list(itertools.product(*deltas_test_values))
Expand Down Expand Up @@ -243,7 +244,7 @@ def test_synthetic_ts_template(nan_percent: float):


@pytest.mark.parametrize(
"how, solver, plot_is_static, function, nan_percent, size_one_percent",
"how, solver, plot_is_static, function, nan_percent, size_one_percent, cluster_values",
deltas_test_cases,
)
def test_deltas(
Expand All @@ -253,6 +254,7 @@ def test_deltas(
function: Callable,
nan_percent: float,
size_one_percent: float,
cluster_values: bool,
):
all_data = monthly_driver_data()
df = all_data.data
Expand All @@ -277,6 +279,7 @@ def test_deltas(
max_depth=1,
min_segments=10,
solver=solver,
cluster_values=cluster_values
)
# sf.plot(plot_is_static=plot_is_static)
print("yay!")
Expand Down
3 changes: 0 additions & 3 deletions wise_pizza/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ def explain_changes_in_totals(
plot_is_static=plot_is_static,
width=width,
height=height,
cluster_values=cluster_values,
cluster_key_width=cluster_key_width,
cluster_value_width=cluster_value_width,
return_fig=return_fig,
Expand Down Expand Up @@ -258,7 +257,6 @@ def explain_changes_in_totals(
plot_is_static=plot_is_static,
width=width,
height=height,
cluster_values=cluster_values,
cluster_key_width=cluster_key_width,
cluster_value_width=cluster_value_width,
return_fig=return_fig,
Expand Down Expand Up @@ -341,7 +339,6 @@ def explain_levels(
width=width,
height=height,
return_fig=return_fig,
cluster_values=cluster_values,
cluster_key_width=cluster_key_width,
cluster_value_width=cluster_value_width,
)
Expand Down
27 changes: 12 additions & 15 deletions wise_pizza/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def plot_split_segments(
plot_is_static: bool = False,
width: int = 2000,
height: int = 500,
cluster_values: bool = False,
cluster_key_width: int = 180,
cluster_value_width: int = 318,
return_fig: bool = False,
Expand Down Expand Up @@ -113,7 +112,7 @@ def plot_split_segments(
for i in range(1, 3):
fig.update_yaxes(autorange="reversed", row=i)

if cluster_values:
if sf_size.relevant_cluster_names:
data_dict = sf_size.relevant_cluster_names
keys = list(data_dict.keys())
values = list(data_dict.values())
Expand All @@ -139,7 +138,7 @@ def plot_split_segments(
# Convert the figure to a static image
image_bytes = to_image(fig, format="png", scale=2)

if cluster_values:
if sf_size.relevant_cluster_names:
display(
Image(
image_bytes,
Expand All @@ -158,12 +157,12 @@ def plot_split_segments(
)
else:
if return_fig:
if cluster_values:
if sf_size.relevant_cluster_names:
return [fig, fig2]
else:
return fig
fig.show()
if cluster_values:
if sf_size.relevant_cluster_names:
fig2.show()


Expand All @@ -173,7 +172,6 @@ def plot_segments(
width: int = 2000,
height: int = 500,
return_fig: bool = False,
cluster_values: bool = False,
cluster_key_width: int = 180,
cluster_value_width: int = 318,
):
Expand Down Expand Up @@ -244,7 +242,7 @@ def plot_segments(
annotation_text="Global average",
)

if cluster_values:
if sf.relevant_cluster_names:
data_dict = sf.relevant_cluster_names
keys = list(data_dict.keys())
values = list(data_dict.values())
Expand All @@ -271,7 +269,7 @@ def plot_segments(
image_bytes = to_image(fig, format="png", scale=2)

# Display the static image in the Jupyter notebook
if cluster_values:
if sf.relevant_cluster_names:
image_bytes2 = to_image(fig2, format="png", scale=2)
display(
Image(
Expand All @@ -289,13 +287,13 @@ def plot_segments(
)
else:
if return_fig:
if cluster_values:
if sf.relevant_cluster_names:
return [fig, fig2]
else:
return fig
else:
fig.show()
if cluster_values:
if sf.relevant_cluster_names:
fig2.show()


Expand Down Expand Up @@ -360,7 +358,6 @@ def plot_waterfall(
plot_is_static: bool = False,
width: int = 1000,
height: int = 1000,
cluster_values: bool = False,
cluster_key_width: int = 180,
cluster_value_width: int = 318,
return_fig: bool = False,
Expand All @@ -387,7 +384,7 @@ def plot_waterfall(
**waterfall_layout_args(sf, width, height),
)

if cluster_values:
if sf.relevant_cluster_names:
data_dict = sf.relevant_cluster_names
keys = list(data_dict.keys())
values = list(data_dict.values())
Expand All @@ -412,18 +409,18 @@ def plot_waterfall(
if plot_is_static:
# Convert the figure to a static image
image_bytes = to_image(fig, format="png", scale=2)
if cluster_values:
if sf.relevant_cluster_names:
display(Image(image_bytes, height=height, width=width))
fig2.show()
else:
# Display the static image in the Jupyter notebook
display(Image(image_bytes, width=width, height=height))
else:
if return_fig:
if cluster_values:
if sf.relevant_cluster_names:
return [fig, fig2]
else:
return fig
fig.show()
if cluster_values:
if sf.relevant_cluster_names:
fig2.show()

0 comments on commit 2f95f4f

Please sign in to comment.