Skip to content

Commit

Permalink
added relevant_cluster_names plot, added info to the readme
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander Polyakov committed Dec 6, 2023
1 parent 445509b commit 8400ec3
Show file tree
Hide file tree
Showing 6 changed files with 478 additions and 267 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ sf1 = explain_changes_in_average(

![plot](https://github.com/transferwise/wise-pizza/blob/main/docs/explain_changes_in_average(totals).png?raw=True)

***In addition to single-value slices, consider slices that consist of a
group of segments from the same dimension with similar naive averages***
For that goal you can use cluster_values=True parameter.

![plot](https://github.com/transferwise/wise-pizza/blob/main/docs/cluster_values.png?raw=True)

And then you can visualize differences:

```Python
Expand All @@ -132,6 +138,12 @@ And check segments:
```Python
sf.segments
```

if you use cluster values, you can also check relevant cluster names:
```Python
sf.relevant_cluster_names
```

Please see the full example [here](https://github.com/transferwise/wise-pizza/blob/main/notebooks/Finding%20interesting%20segments.ipynb)

## For Developers
Expand Down
Binary file added docs/cluster_values.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
191 changes: 68 additions & 123 deletions notebooks/Finding interesting segments (continuous segments).ipynb

Large diffs are not rendered by default.

343 changes: 240 additions & 103 deletions notebooks/Finding interesting segments.ipynb

Large diffs are not rendered by default.

32 changes: 25 additions & 7 deletions wise_pizza/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def explain_changes_in_average(
to the difference between dataset totals
@param constrain_signs: Whether to constrain weights of segments to have the same
sign as naive segment averages
@param cluster_values: In addition to single-value slices, consider slices that consist of a
group of segments from the same dimension with similar naive averages
@param verbose: If set to a truish value, lots of debug info is printed to console
@return: A fitted object
"""
Expand Down Expand Up @@ -143,7 +145,7 @@ def explain_changes_in_totals(
to the difference between dataset totals
@param constrain_signs: Whether to constrain weights of segments to have the same
sign as naive segment averages
@param cluster_values In addition to single-value slices, consider slices that consist of a
@param cluster_values: In addition to single-value slices, consider slices that consist of a
group of segments from the same dimension with similar naive averages
@param verbose: If set to a truish value, lots of debug info is printed to console
@return: A fitted object
Expand Down Expand Up @@ -207,12 +209,15 @@ def explain_changes_in_totals(
sf_size.final_size = final_size
sf_avg.final_size = final_size
sp = SlicerPair(sf_size, sf_avg)
sp.plot = lambda plot_is_static=False, width=2000, height=500: plot_split_segments(
sp.plot = lambda plot_is_static=False, width=2000, height=500, cluster_key_width=180, cluster_value_width=318: plot_split_segments(
sp.s1,
sp.s2,
plot_is_static=plot_is_static,
width=width,
height=height,
cluster_values=cluster_values,
cluster_key_width=cluster_key_width,
cluster_value_width=cluster_value_width
)
return sp

Expand All @@ -238,8 +243,14 @@ def explain_changes_in_totals(
sf.pre_total = df1[total_name].sum()
sf.post_total = df2[total_name].sum()

sf.plot = lambda plot_is_static=False, width=1000, height=1000: plot_waterfall(
sf, plot_is_static=plot_is_static, width=width, height=height
sf.plot = lambda plot_is_static=False, width=1000, height=1000, cluster_key_width=180, cluster_value_width=318: plot_waterfall(
sf,
plot_is_static=plot_is_static,
width=width,
height=height,
cluster_values=cluster_values,
cluster_key_width=cluster_key_width,
cluster_value_width=cluster_value_width
)
sf.task = "changes in totals"
return sf
Expand Down Expand Up @@ -274,7 +285,7 @@ def explain_levels(
@param verbose: If set to a truish value, lots of debug info is printed to console
@param force_add_up: Force the contributions of chosen segments to add up to zero
@param constrain_signs: Whether to constrain weights of segments to have the same sign as naive segment averages
@param cluster_values In addition to single-value slices, consider slices that consist of a
@param cluster_values: In addition to single-value slices, consider slices that consist of a
group of segments from the same dimension with similar naive averages
@return: A fitted object
"""
Expand Down Expand Up @@ -313,8 +324,15 @@ def explain_levels(
s["total"] += average * s["seg_size"]
# print(average)
sf.reg.intercept_ = average
sf.plot = lambda plot_is_static=False, width=2000, height=500, return_fig=False: plot_segments(
sf, plot_is_static=plot_is_static, width=width, height=height, return_fig=return_fig
sf.plot = lambda plot_is_static=False, width=2000, height=500, return_fig=False, cluster_key_width=180, cluster_value_width=318: plot_segments(
sf,
plot_is_static=plot_is_static,
width=width,
height=height,
return_fig=return_fig,
cluster_values=cluster_values,
cluster_key_width=cluster_key_width,
cluster_value_width=cluster_value_width
)
sf.task = "levels"
return sf
167 changes: 133 additions & 34 deletions wise_pizza/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def plot_split_segments(
plot_is_static: bool = False,
width: int = 2000,
height: int = 500,
cluster_values: bool=False,
cluster_key_width: int = 180,
cluster_value_width: int = 318
):
"""
Plot split segments for explain_changes: split_fits
Expand Down Expand Up @@ -112,26 +115,66 @@ def plot_split_segments(

for i in range(1, 3):
fig.update_yaxes(autorange="reversed", row=i)

if cluster_values:
data_dict = sf_size.relevant_cluster_names
keys = list(data_dict.keys())
values = list(data_dict.values())
key_column_width = cluster_key_width # Adjust the multiplier as needed
value_column_width = cluster_value_width # Adjust the multiplier as needed

# Create a table trace with specified column widths
table_trace = go.Table(
header=dict(values=['Cluster', 'Segments']),
cells=dict(values=[keys, values]),
columnwidth=[key_column_width, value_column_width]
)

# Create a layout
layout = go.Layout(title='Relevant cluster names',
title_x=0 # Center the title
)

# Create a figure
fig2 = go.Figure(data=[table_trace], layout=layout)

if plot_is_static:
# Convert the figure to a static image
image_bytes = to_image(fig, format="png", scale=2)

# Display the static image in the Jupyter notebook
return Image(
image_bytes,
height=height + len(size_data.index) * 30,
width=width + len(size_data.index) * 30,
)
if cluster_values:
image_bytes2 = to_image(fig2, format="png", scale=2)
display(
Image(
image_bytes,
height=height + len(size_data.index) * 30,
width=width + len(size_data.index) * 30,
)
)
fig2.show()

else:
# Display the static image in the Jupyter notebook
return Image(
image_bytes,
height=height + len(size_data.index) * 30,
width=width + len(size_data.index) * 30,
)
else:
fig.show()
if cluster_values:
fig2.show()


def plot_segments(
sf: SliceFinder,
plot_is_static: bool = False,
width: int = 2000,
height: int = 500,
return_fig: bool = False
return_fig: bool = False,
cluster_values: bool=False,
cluster_key_width: int = 180,
cluster_value_width: int = 318
):
"""
Plot segments for explain_levels
Expand Down Expand Up @@ -200,22 +243,62 @@ def plot_segments(
annotation_text="Global average",
)

if cluster_values:
data_dict = sf.relevant_cluster_names
keys = list(data_dict.keys())
values = list(data_dict.values())
key_column_width = cluster_key_width # Adjust the multiplier as needed
value_column_width = cluster_value_width # Adjust the multiplier as needed

# Create a table trace with specified column widths
table_trace = go.Table(
header=dict(values=['Cluster', 'Segments']),
cells=dict(values=[keys, values]),
columnwidth=[key_column_width, value_column_width]
)

# Create a layout
layout = go.Layout(title='Relevant cluster names',
title_x=0 # Center the title
)

# Create a figure
fig2 = go.Figure(data=[table_trace], layout=layout)

if plot_is_static:
# Convert the figure to a static image
image_bytes = to_image(fig, format="png", scale=2)

# Display the static image in the Jupyter notebook
return Image(
image_bytes,
height=height + len(sf.segment_labels) * 30,
width=width + len(sf.segment_labels) * 30,
)
if cluster_values:
image_bytes2 = to_image(fig2, format="png", scale=2)
display(
Image(
image_bytes,
height=height + len(sf.segment_labels) * 30,
width=width + len(sf.segment_labels) * 30,
)
)
display(
Image(
image_bytes2,
height=height,
width=width
)
)
else:
return Image(
image_bytes,
height=height + len(sf.segment_labels) * 30,
width=width + len(sf.segment_labels) * 30,
)
else:
if return_fig:
return fig
else:
fig.show()

if cluster_values:
fig2.show()

def waterfall_args(sf: SliceFinder):
"""
Expand Down Expand Up @@ -274,7 +357,13 @@ def waterfall_layout_args(sf: SliceFinder, width: int = 1000, height: int = 1000


def plot_waterfall(
sf: SliceFinder, plot_is_static: bool = False, width: int = 1000, height: int = 1000
sf: SliceFinder,
plot_is_static: bool = False,
width: int = 1000,
height: int = 1000,
cluster_values: bool=False,
cluster_key_width: int = 180,
cluster_value_width: int = 318
):
"""
Plot waterfall and Bar for explain_changes
Expand All @@ -287,39 +376,49 @@ def plot_waterfall(
data = pd.DataFrame(sf.segments, index=np.array(sf.segment_labels))
trace1 = go.Waterfall(name="Segments waterfall", **waterfall_args(sf))

trace2 = go.Bar(
x=data["naive_avg"],
y=data.index,
orientation="h",
name="Diff in averages",
marker_color="#ff685f",
)

fig = go.Figure()
fig2 = go.Figure()

fig.add_trace(trace1)
fig.update_layout(title="Segments contributing most to the change")
fig2.add_trace(trace2)
fig2["layout"]["yaxis"].update(autorange="reversed")
fig2.update_layout(title="Segment averages")

fig2.update_layout(width=width, height=height)

fig.update_layout(
title="Segments contributing most to the change",
# showlegend = True,
**waterfall_layout_args(sf, width, height)
)

if cluster_values:
data_dict = sf.relevant_cluster_names
keys = list(data_dict.keys())
values = list(data_dict.values())
key_column_width = cluster_key_width # Adjust the multiplier as needed
value_column_width = cluster_value_width # Adjust the multiplier as needed

# Create a table trace with specified column widths
table_trace = go.Table(
header=dict(values=['Cluster', 'Segments']),
cells=dict(values=[keys, values]),
columnwidth=[key_column_width, value_column_width]
)

# Create a layout
layout = go.Layout(title='Relevant cluster names',
title_x=0 # Center the title
)

# Create a figure
fig2 = go.Figure(data=[table_trace], layout=layout)

if plot_is_static:
# Convert the figure to a static image
image_bytes = to_image(fig, format="png", scale=2)
image_bytes2 = to_image(fig2, format="png", scale=2)

# Display the static image in the Jupyter notebook
display(Image(image_bytes, width=width, height=height))
display(Image(image_bytes2, width=width, height=height))
if cluster_values:
display(Image(image_bytes, height=height, width=width))
fig2.show()
else:
# Display the static image in the Jupyter notebook
display(Image(image_bytes, width=width, height=height))
else:
fig.show()
fig2.show()
if cluster_values:
fig2.show()

0 comments on commit 8400ec3

Please sign in to comment.