Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Add diagnostic visualization tools #1631

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/beanmachine/ppl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from beanmachine.ppl.diagnostics.tools import viz
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll need to add this to __all__ so that the linter won't complain that this is imported but not use :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, I need to setup pre-commit to point to setup.cfg as that's where the flake8 config exists. Do you know of a tool that will add copyright notices for Python?

from torch.distributions import Distribution

from . import experimental
Expand Down Expand Up @@ -60,4 +61,5 @@
"random_variable",
"simulate",
"split_r_hat",
"viz",
]
6 changes: 6 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Visual diagnostic tools for Bean Machine models."""
78 changes: 78 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/accessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Accessor definition for extending Bean Machine `MonteCarloSamples` objects."""
import contextlib
import warnings
from typing import Callable, TypeVar

from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples


T = TypeVar("T", bound="CachedAccessor")


class CachedAccessor:
"""A descriptor for caching accessors.

Parameters
----------
name : str
Namespace that will be accessed under, e.g. ``samples.accessor_name``.
accessor : cls
Class with the extension methods.
"""

def __init__(self: T, name: str, accessor: object) -> None:
"""Initialize."""
self._name = name
self._accessor = accessor

def __get__(self: T, obj: object, cls: object) -> object:
"""Access the accessor object."""
if obj is None:
return self._accessor

try:
cache = obj._cache # type: ignore
except AttributeError:
cache = obj._cache = {}

try:
return cache[self._name]
except KeyError:
contextlib.suppress(KeyError)

try:
accessor_obj = self._accessor(obj) # type: ignore
except Exception as error:
msg = f"error initializing {self._name!r} accessor."
raise RuntimeError(msg) from error

cache[self._name] = accessor_obj
return accessor_obj # noqa: R504


def _register_accessor(name: str, cls: object) -> Callable:
"""Register the accessor to the object."""

def decorator(accessor: object) -> object:
if hasattr(cls, name):
warnings.warn(
f"registration of accessor {repr(accessor)} under name "
f"{repr(name)} for type {repr(cls)} is overriding a preexisting "
f"attribute with the same name.",
UserWarning,
stacklevel=2,
)
setattr(cls, name, CachedAccessor(name, accessor))
return accessor

return decorator


def register_mcs_accessor(name: str) -> Callable:
"""Register the accessor to object."""
return _register_accessor(name, MonteCarloSamples)
99 changes: 99 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/autocorrelation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Autocorrelation diagnostic tool for a Bean Machine model."""
from typing import Any, Tuple, TypeVar

import arviz as az

import beanmachine.ppl.diagnostics.tools.helpers.autocorrelation as tool
from bokeh.models.callbacks import CustomJS
from bokeh.plotting import show

T = TypeVar("T", bound="Autocorrelation")


class Autocorrelation:
"""Autocorrelation diagnostic tool."""

def __init__(self: T, idata: az.InferenceData) -> None:
"""Initialize."""
self.idata = idata
self.rv_identifiers = list(self.idata["posterior"].data_vars)
self.rv_names = sorted(
[str(rv_identifier) for rv_identifier in self.rv_identifiers],
)
self.num_chains = self.idata["posterior"].dims["chain"]
self.num_draws = self.idata["posterior"].dims["draw"]

def modify_doc(self: T, doc: Any) -> None:
"""Modify the Jupyter document in order to display the tool."""
# Initialize the widgets.
rv_name = self.rv_names[0]
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)]

# Compute the initial data displayed in the tool.
rv_data = self.idata["posterior"][rv_identifier].values
computed_data = tool.compute_data(rv_data)

# Create the Bokeh source(s).
sources = tool.create_sources(computed_data)

# Create the figure(s).
figures = tool.create_figures(self.num_chains)

# Create the glyph(s) and attach them to the figure(s).
glyphs = tool.create_glyphs(self.num_chains)
tool.add_glyphs(figures, glyphs, sources)

# Create the annotation(s) and attache them to the figure(s).
annotations = tool.create_annotations(computed_data)
tool.add_annotations(figures, annotations)

# Create the tool tip(s) and attach them to the figure(s).
tooltips = tool.create_tooltips(figures)
tool.add_tooltips(figures, tooltips)

# Create the widget(s) for the tool.
widgets = tool.create_widgets(rv_name, self.rv_names, self.num_draws)

# Create the callback(s) for the widget(s).
def update_rv_select(attr: Any, old: str, new: str) -> None:
rv_name = new
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)]
rv_data = self.idata["posterior"][rv_identifier].values
tool.update(rv_data, sources)
end = 10 if self.num_draws <= 2 * 100 else 100
widgets["range_slider"].value = (0, end)

def update_range_slider(
attr: Any,
old: Tuple[int, int],
new: Tuple[int, int],
) -> None:
fig = figures[list(figures.keys())[0]]
fig.x_range.start, fig.x_range.end = new

widgets["rv_select"].on_change("value", update_rv_select)
# NOTE: We are using Bokeh's CustomJS model in order to reset the ranges of the
# figures.
widgets["rv_select"].js_on_change(
"value",
CustomJS(args={"p": list(figures.values())[0]}, code="p.reset.emit()"),
)
widgets["range_slider"].on_change("value", update_range_slider)

tool_view = tool.create_view(widgets, figures)
doc.add_root(tool_view)

def show_tool(self: T) -> None:
"""Show the diagnostic tool.

Returns
-------
None
Directly displays the tool in Jupyter.
"""
show(self.modify_doc)
90 changes: 90 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/effective_sample_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Effective Sample Size (ESS) diagnostic tool for a Bean Machine model."""
from typing import Any, TypeVar

import arviz as az

import beanmachine.ppl.diagnostics.tools.helpers.effective_sample_size as tool
from bokeh.core.enums import LegendClickPolicy
from bokeh.models.callbacks import CustomJS
from bokeh.plotting import show


T = TypeVar("T", bound="EffectiveSampleSize")


class EffectiveSampleSize:
"""Effective Sample Size (ESS) diagnostic tool."""

def __init__(self: T, idata: az.InferenceData) -> None:
"""Initialize."""
self.idata = idata
self.rv_identifiers = list(self.idata["posterior"].data_vars)
self.rv_names = sorted(
[str(rv_identifier) for rv_identifier in self.rv_identifiers],
)
self.num_chains = self.idata["posterior"].dims["chain"]

def modify_doc(self: T, doc: Any) -> None:
"""Modify the Jupyter document in order to display the tool."""
# Initialize the widgets.
rv_name = self.rv_names[0]
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)]

# Compute the initial data displayed in the tool.
rv_data = self.idata["posterior"][rv_identifier].values
computed_data = tool.compute_data(rv_data)

# Create the Bokeh source(s).
sources = tool.create_sources(computed_data)

# Create the figure(s).
figures = tool.create_figures()

# Create the glyph(s) and attach them to the figure(s).
glyphs = tool.create_glyphs()
tool.add_glyphs(figures, glyphs, sources)

# Create the annotation(s) and attache them to the figure(s).
annotations = tool.create_annotations(figures)
annotations["ess"]["legend"].click_policy = LegendClickPolicy.hide
tool.add_annotations(figures, annotations)

# Create the tool tip(s) and attach them to the figure(s).
tooltips = tool.create_tooltips(figures)
tool.add_tooltips(figures, tooltips)

# Create the widget(s) for the tool.
widgets = tool.create_widgets(rv_name, self.rv_names)

# Create the callback(s) for the widget(s).
def update_rv_select(attr: Any, old: str, new: str) -> None:
rv_name = new
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)]
rv_data = self.idata["posterior"][rv_identifier].values
tool.update(rv_data, sources)

widgets["rv_select"].on_change("value", update_rv_select)
# NOTE: We are using Bokeh's CustomJS model in order to reset the ranges of the
# figures.
widgets["rv_select"].js_on_change(
"value",
CustomJS(args={"p": list(figures.values())[0]}, code="p.reset.emit()"),
)

tool_view = tool.create_view(widgets, figures)
doc.add_root(tool_view)

def show_tool(self: T) -> None:
"""Show the diagnostic tool.

Returns
-------
None
Directly displays the tool in Jupyter.
"""
show(self.modify_doc)
6 changes: 6 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Visual diagnostics tool methods."""
Loading