Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: option to disable legend #20

Merged
merged 7 commits into from
Nov 5, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions qiskit_addon_obp/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .metadata import OBPMetadata


def plot_accumulated_error(metadata: OBPMetadata, axes: Axes) -> None:
def plot_accumulated_error(metadata: OBPMetadata, axes: Axes, show_legend: bool = True) -> None:
boonware marked this conversation as resolved.
Show resolved Hide resolved
"""Plot the accumulated error.

This method populates the provided figure axes with a line-plot of the
Expand Down Expand Up @@ -72,6 +72,7 @@ def plot_accumulated_error(metadata: OBPMetadata, axes: Axes) -> None:
Args:
metadata: the metadata to be visualized.
axes: the matplotlib axes in which to plot.
show_legend: enable/disable showing the legend in the plot.
"""
if not np.isinf(metadata.truncation_error_budget.max_error_total):
axes.axhline(
Expand All @@ -93,10 +94,12 @@ def plot_accumulated_error(metadata: OBPMetadata, axes: Axes) -> None:
)
axes.set_xlabel("backpropagated slice number")
axes.set_ylabel("accumulated error")
axes.legend()
_set_legend(axes, show_legend)


def plot_left_over_error_budget(metadata: OBPMetadata, axes: Axes) -> None:
def plot_left_over_error_budget(
metadata: OBPMetadata, axes: Axes, show_legend: bool = True
) -> None:
"""Plot the left-over error budget.

This method populates the provided figure axes with a line-plot of the
Expand Down Expand Up @@ -127,6 +130,7 @@ def plot_left_over_error_budget(metadata: OBPMetadata, axes: Axes) -> None:
Args:
metadata: the metadata to be visualized.
axes: the matplotlib axes in which to plot.
show_legend: enable/disable showing the legend in the plot.
"""
for obs_idx in range(len(metadata.backpropagation_history[0].slice_errors)):
axes.plot(
Expand All @@ -139,10 +143,10 @@ def plot_left_over_error_budget(metadata: OBPMetadata, axes: Axes) -> None:
)
axes.set_xlabel("backpropagated slice number")
axes.set_ylabel("left-over error budget")
axes.legend()
_set_legend(axes, show_legend)


def plot_slice_errors(metadata: OBPMetadata, axes: Axes) -> None:
def plot_slice_errors(metadata: OBPMetadata, axes: Axes, show_legend: bool = True) -> None:
"""Plot the slice errors.

This method populates the provided figure axes with a bar-plot of the truncation error incurred
Expand Down Expand Up @@ -176,6 +180,7 @@ def plot_slice_errors(metadata: OBPMetadata, axes: Axes) -> None:
Args:
metadata: the metadata to be visualized.
axes: the matplotlib axes in which to plot.
show_legend: enable/disable showing the legend in the plot.
"""
num_observables = len(metadata.backpropagation_history[0].slice_errors)
width = 0.8 / num_observables
Expand All @@ -193,9 +198,10 @@ def plot_slice_errors(metadata: OBPMetadata, axes: Axes) -> None:
axes.set_xlabel("backpropagated slice number")
axes.set_ylabel("incurred slice error")
axes.legend()
_set_legend(axes, show_legend)


def plot_num_paulis(metadata: OBPMetadata, axes: Axes) -> None:
def plot_num_paulis(metadata: OBPMetadata, axes: Axes, show_legend: bool = True) -> None:
"""Plot the number of Pauli terms.

This method populates the provided figure axes with a line-plot of the number of Pauli terms at
Expand Down Expand Up @@ -229,6 +235,7 @@ def plot_num_paulis(metadata: OBPMetadata, axes: Axes) -> None:
Args:
metadata: the metadata to be visualized.
axes: the matplotlib axes in which to plot.
show_legend: enable/disable showing the legend in the plot.
"""
for obs_idx in range(len(metadata.backpropagation_history[0].slice_errors)):
axes.plot(
Expand All @@ -238,10 +245,10 @@ def plot_num_paulis(metadata: OBPMetadata, axes: Axes) -> None:
)
axes.set_xlabel("backpropagated slice number")
axes.set_ylabel("# Pauli terms")
axes.legend()
_set_legend(axes, show_legend)


def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes) -> None:
def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes, show_legend: bool = True) -> None:
"""Plot the number of truncated Pauli terms.

This method populates the provided figure axes with a bar-plot of the number of the truncated
Expand Down Expand Up @@ -275,6 +282,7 @@ def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes) -> None:
Args:
metadata: the metadata to be visualized.
axes: the matplotlib axes in which to plot.
show_legend: enable/disable showing the legend in the plot.
"""
num_observables = len(metadata.backpropagation_history[0].slice_errors)
width = 0.8 / num_observables
Expand All @@ -291,10 +299,10 @@ def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes) -> None:
offset += width
axes.set_xlabel("backpropagated slice number")
axes.set_ylabel("# truncated Pauli terms")
axes.legend()
_set_legend(axes, show_legend)


def plot_sum_paulis(metadata: OBPMetadata, axes: Axes) -> None:
def plot_sum_paulis(metadata: OBPMetadata, axes: Axes, show_legend: bool = True) -> None:
"""Plot the total number of all Pauli terms.

This method populates the provided figure axes with a line-plot of the total number of all Pauli
Expand Down Expand Up @@ -329,6 +337,7 @@ def plot_sum_paulis(metadata: OBPMetadata, axes: Axes) -> None:
Args:
metadata: the metadata to be visualized.
axes: the matplotlib axes in which to plot.
show_legend: enable/disable showing the legend in the plot.
"""
if metadata.operator_budget.max_paulis is not None:
axes.axhline(
Expand All @@ -346,10 +355,10 @@ def plot_sum_paulis(metadata: OBPMetadata, axes: Axes) -> None:
)
axes.set_xlabel("backpropagated slice number")
axes.set_ylabel("total # of Pauli terms")
axes.legend()
_set_legend(axes, show_legend)


def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes) -> None:
def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes, show_legend: bool = True) -> None:
"""Plot the number of qubit-wise commuting Pauli groups.

This method populates the provided figure axes with a line-plot of the number of qubit-wise
Expand Down Expand Up @@ -380,6 +389,7 @@ def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes) -> None:
Args:
metadata: the metadata to be visualized.
axes: the matplotlib axes in which to plot.
show_legend: enable/disable showing the legend in the plot.
"""
if metadata.operator_budget.max_qwc_groups is not None:
axes.axhline(
Expand All @@ -397,4 +407,9 @@ def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes) -> None:
)
axes.set_xlabel("backpropagated slice number")
axes.set_ylabel("# of qubit-wise commuting Pauli groups")
axes.legend()
_set_legend(axes, show_legend)


def _set_legend(axes: Axes, show_legend: bool) -> None:
if show_legend:
axes.legend()