Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Add diagnostic visualization tools
Browse files Browse the repository at this point in the history
This commit adds the feature to run diagnostic tools using ArviZ and
Bokeh in a Jupyter environment.

- Added a `tools` sub-package in the `diagnostics` package. This new
  sub-package adds the following files, each for a specific tool that
  runs ArviZ model diagnostics.

  - autocorrelation
  - effective_sample_size
  - marginal1d
  - marginal2d
  - trace

- Each of the tools listed above also has two corresponding files: one
  for its types and one for the methods used by the tool.

Resolves #1490
  • Loading branch information
ndmlny-qs committed Aug 29, 2022
1 parent d5d089a commit 168794f
Show file tree
Hide file tree
Showing 23 changed files with 7,803 additions and 185 deletions.
1 change: 1 addition & 0 deletions src/beanmachine/ppl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from beanmachine.ppl.diagnostics.tools import viz
from torch.distributions import Distribution

from . import experimental
Expand Down
Empty file.
75 changes: 75 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/accessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Accessor definition for extending Bean Machine `MonteCarloSamples` objects."""
from __future__ import annotations

import contextlib
import warnings
from typing import Callable, TypeVar

from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples


T = TypeVar("T", bound="CachedAccessor")


class CachedAccessor:
    """A descriptor for caching accessors.

    The accessor class is instantiated at most once per host object; the
    instance is stored in the host object's ``_cache`` dictionary under the
    accessor name and reused on subsequent attribute accesses.

    Parameters
    ----------
    name : str
        Namespace that will be accessed under, e.g. ``samples.accessor_name``.
    accessor : cls
        Class with the extension methods.
    """

    def __init__(self: T, name: str, accessor: object) -> None:
        """Store the accessor name and the accessor class."""
        self._name = name
        self._accessor = accessor

    def __get__(self: T, obj: object, cls: object) -> object:
        """Access the accessor object.

        Parameters
        ----------
        obj : object
            Instance the attribute is looked up on, or ``None`` when the
            attribute is looked up on the class itself.
        cls : object
            Owner class of the descriptor (unused).

        Returns
        -------
        object
            The accessor class for class-level access, otherwise the cached
            accessor instance bound to ``obj``.

        Raises
        ------
        RuntimeError
            If constructing the accessor for ``obj`` raises any exception; the
            original exception is attached as the cause.
        """
        if obj is None:
            # Class-level access (``Owner.name``) exposes the accessor class.
            return self._accessor

        # Create the per-object cache lazily on first use.
        try:
            cache = obj._cache  # type: ignore
        except AttributeError:
            cache = obj._cache = {}

        # A cache miss falls through to construction below. The lookup is kept
        # out of an ``except`` block so the ``KeyError`` is not attached as
        # context to an error raised while building the accessor.
        with contextlib.suppress(KeyError):
            return cache[self._name]

        try:
            accessor_obj = self._accessor(obj)  # type: ignore
        except Exception as error:
            msg = f"error initializing {self._name!r} accessor."
            raise RuntimeError(msg) from error

        cache[self._name] = accessor_obj
        return accessor_obj  # noqa: R504


def _register_accessor(name: str, cls: object) -> Callable:
    """Build a decorator that installs an accessor on ``cls`` under ``name``.

    Parameters
    ----------
    name : str
        Attribute name the accessor is exposed under.
    cls : object
        Class that receives the ``CachedAccessor`` descriptor.

    Returns
    -------
    Callable
        Decorator that registers the decorated accessor class and returns it
        unchanged.
    """

    def _attach(accessor: object) -> object:
        # Warn (but still proceed) when the name is already taken on ``cls``.
        already_defined = hasattr(cls, name)
        if already_defined:
            message = (
                f"registration of accessor {repr(accessor)} under name "
                f"{repr(name)} for type {repr(cls)} is overriding a preexisting "
                f"attribute with the same name."
            )
            warnings.warn(message, UserWarning, stacklevel=2)

        descriptor = CachedAccessor(name, accessor)
        setattr(cls, name, descriptor)
        return accessor

    return _attach


def register_mcs_accessor(name: str) -> Callable:
    """Return a decorator registering an accessor on ``MonteCarloSamples``.

    Parameters
    ----------
    name : str
        Attribute name the accessor is exposed under.

    Returns
    -------
    Callable
        The decorator produced by ``_register_accessor`` for
        ``MonteCarloSamples``.
    """
    target_cls = MonteCarloSamples
    register = _register_accessor(name, target_cls)
    return register
96 changes: 96 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/autocorrelation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Autocorrelation diagnostic tool for a Bean Machine model."""
from __future__ import annotations

from typing import Any, TypeVar

import arviz as az

import beanmachine.ppl.diagnostics.tools.helpers.autocorrelation as tool
from bokeh.models.callbacks import CustomJS
from bokeh.plotting import show

T = TypeVar("T", bound="Autocorrelation")


class Autocorrelation:
    """Autocorrelation diagnostic tool.

    Parameters
    ----------
    idata : az.InferenceData
        Inference data; its ``"posterior"`` group supplies the random
        variables, the number of chains and the number of draws.
    """

    def __init__(self: T, idata: az.InferenceData) -> None:
        """Initialize."""
        self.idata = idata
        posterior = self.idata["posterior"]
        self.rv_identifiers = list(posterior.data_vars)
        self.rv_names = sorted(
            str(rv_identifier) for rv_identifier in self.rv_identifiers
        )
        # ``rv_names`` is sorted while ``rv_identifiers`` keeps the posterior's
        # own order, so positions in the two lists do not correspond. Resolve
        # names through an explicit mapping instead of a shared index.
        self._identifier_by_name = {
            str(rv_identifier): rv_identifier
            for rv_identifier in self.rv_identifiers
        }
        self.num_chains = posterior.dims["chain"]
        self.num_draws = posterior.dims["draw"]

    def _rv_values(self: T, rv_name: str) -> Any:
        """Return the posterior values of the random variable named ``rv_name``."""
        rv_identifier = self._identifier_by_name[rv_name]
        return self.idata["posterior"][rv_identifier].values

    def modify_doc(self: T, doc: Any) -> None:
        """Modify the Jupyter document in order to display the tool.

        Parameters
        ----------
        doc : Any
            Document that receives the tool's view as a root.
        """
        # Initialize the widgets with the first (alphabetically sorted) name.
        rv_name = self.rv_names[0]

        # Compute the initial data displayed in the tool.
        computed_data = tool.compute_data(self._rv_values(rv_name))

        # Create the Bokeh source(s).
        sources = tool.create_sources(computed_data)

        # Create the figure(s).
        figures = tool.create_figures(self.num_chains)
        first_figure = next(iter(figures.values()))

        # Create the glyph(s) and attach them to the figure(s).
        glyphs = tool.create_glyphs(self.num_chains)
        tool.add_glyphs(figures, glyphs, sources)

        # Create the annotation(s) and attach them to the figure(s).
        annotations = tool.create_annotations(computed_data)
        tool.add_annotations(figures, annotations)

        # Create the tool tip(s) and attach them to the figure(s).
        tooltips = tool.create_tooltips(figures)
        tool.add_tooltips(figures, tooltips)

        # Create the widget(s) for the tool.
        widgets = tool.create_widgets(rv_name, self.rv_names, self.num_draws)

        # Create the callback(s) for the widget(s).
        def update_rv_select(attr: Any, old: str, new: str) -> None:
            # Refresh the sources with the newly selected variable's data and
            # reset the range slider's selected window.
            tool.update(self._rv_values(new), sources)
            end = 10 if self.num_draws <= 2 * 100 else 100
            widgets["range_slider"].value = (0, end)

        def update_range_slider(
            attr: Any,
            old: tuple[int, int],
            new: tuple[int, int],
        ) -> None:
            # Only the first figure's x-range is driven by the slider.
            first_figure.x_range.start, first_figure.x_range.end = new

        widgets["rv_select"].on_change("value", update_rv_select)
        # NOTE: We are using Bokeh's CustomJS model in order to reset the ranges of the
        #       figures.
        widgets["rv_select"].js_on_change(
            "value",
            CustomJS(args={"p": first_figure}, code="p.reset.emit()"),
        )
        widgets["range_slider"].on_change("value", update_range_slider)

        tool_view = tool.create_view(widgets, figures)
        doc.add_root(tool_view)

    def show_tool(self: T) -> None:
        """Show the diagnostic tool.

        Returns
        -------
        None
            Directly displays the tool in Jupyter.
        """
        show(self.modify_doc)
87 changes: 87 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/effective_sample_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Effective Sample Size (ESS) diagnostic tool for a Bean Machine model."""
from __future__ import annotations

from typing import Any, TypeVar

import arviz as az

import beanmachine.ppl.diagnostics.tools.helpers.effective_sample_size as tool
from bokeh.core.enums import LegendClickPolicy
from bokeh.models.callbacks import CustomJS
from bokeh.plotting import show


T = TypeVar("T", bound="EffectiveSampleSize")


class EffectiveSampleSize:
    """Effective Sample Size (ESS) diagnostic tool.

    Parameters
    ----------
    idata : az.InferenceData
        Inference data; its ``"posterior"`` group supplies the random
        variables and the number of chains.
    """

    def __init__(self: T, idata: az.InferenceData) -> None:
        """Initialize."""
        self.idata = idata
        posterior = self.idata["posterior"]
        self.rv_identifiers = list(posterior.data_vars)
        self.rv_names = sorted(
            str(rv_identifier) for rv_identifier in self.rv_identifiers
        )
        # ``rv_names`` is sorted while ``rv_identifiers`` keeps the posterior's
        # own order, so positions in the two lists do not correspond. Resolve
        # names through an explicit mapping instead of a shared index.
        self._identifier_by_name = {
            str(rv_identifier): rv_identifier
            for rv_identifier in self.rv_identifiers
        }
        self.num_chains = posterior.dims["chain"]

    def _rv_values(self: T, rv_name: str) -> Any:
        """Return the posterior values of the random variable named ``rv_name``."""
        rv_identifier = self._identifier_by_name[rv_name]
        return self.idata["posterior"][rv_identifier].values

    def modify_doc(self: T, doc: Any) -> None:
        """Modify the Jupyter document in order to display the tool.

        Parameters
        ----------
        doc : Any
            Document that receives the tool's view as a root.
        """
        # Initialize the widgets with the first (alphabetically sorted) name.
        rv_name = self.rv_names[0]

        # Compute the initial data displayed in the tool.
        computed_data = tool.compute_data(self._rv_values(rv_name))

        # Create the Bokeh source(s).
        sources = tool.create_sources(computed_data)

        # Create the figure(s).
        figures = tool.create_figures()

        # Create the glyph(s) and attach them to the figure(s).
        glyphs = tool.create_glyphs()
        tool.add_glyphs(figures, glyphs, sources)

        # Create the annotation(s) and attach them to the figure(s).
        annotations = tool.create_annotations(figures)
        annotations["ess"]["legend"].click_policy = LegendClickPolicy.hide
        tool.add_annotations(figures, annotations)

        # Create the tool tip(s) and attach them to the figure(s).
        tooltips = tool.create_tooltips(figures)
        tool.add_tooltips(figures, tooltips)

        # Create the widget(s) for the tool.
        widgets = tool.create_widgets(rv_name, self.rv_names)

        # Create the callback(s) for the widget(s).
        def update_rv_select(attr: Any, old: str, new: str) -> None:
            # Refresh the sources with the newly selected variable's data.
            tool.update(self._rv_values(new), sources)

        widgets["rv_select"].on_change("value", update_rv_select)
        # NOTE: We are using Bokeh's CustomJS model in order to reset the ranges of the
        #       figures.
        widgets["rv_select"].js_on_change(
            "value",
            CustomJS(args={"p": list(figures.values())[0]}, code="p.reset.emit()"),
        )

        tool_view = tool.create_view(widgets, figures)
        doc.add_root(tool_view)

    def show_tool(self: T) -> None:
        """Show the diagnostic tool.

        Returns
        -------
        None
            Directly displays the tool in Jupyter.
        """
        show(self.modify_doc)
Empty file.
Loading

0 comments on commit 168794f

Please sign in to comment.