This repository has been archived by the owner on Dec 18, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This commit adds the feature to run diagnostic tools using ArviZ and Bokeh in a Jupyter environment. - Added a `tools` sub-package in the `diagnostics` package. This new sub-package adds the following files, each for a specific tool that runs ArviZ model diagnostics. - autocorrelation - effective_sample_size - marginal1d - marginal2d - trace - The listed tools above also have two corresponding files, one for types and the other for methods used in the tool. Resolves #1490
- Loading branch information
Showing
23 changed files
with
7,803 additions
and
185 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
"""Accessor definition for extending Bean Machine `MonteCarloSamples` objects.""" | ||
from __future__ import annotations | ||
|
||
import contextlib | ||
import warnings | ||
from typing import Callable, TypeVar | ||
|
||
from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples | ||
|
||
|
||
T = TypeVar("T", bound="CachedAccessor") | ||
|
||
|
||
class CachedAccessor: | ||
"""A descriptor for caching accessors. | ||
Parameters | ||
---------- | ||
name : str | ||
Namespace that will be accessed under, e.g. ``samples.accessor_name``. | ||
accessor : cls | ||
Class with the extension methods. | ||
""" | ||
|
||
def __init__(self: T, name: str, accessor: object) -> None: | ||
"""Initialize.""" | ||
self._name = name | ||
self._accessor = accessor | ||
|
||
def __get__(self: T, obj: object, cls: object) -> object: | ||
"""Access the accessor object.""" | ||
if obj is None: | ||
return self._accessor | ||
|
||
try: | ||
cache = obj._cache # type: ignore | ||
except AttributeError: | ||
cache = obj._cache = {} | ||
|
||
try: | ||
return cache[self._name] | ||
except KeyError: | ||
contextlib.suppress(KeyError) | ||
|
||
try: | ||
accessor_obj = self._accessor(obj) # type: ignore | ||
except Exception as error: | ||
msg = f"error initializing {self._name!r} accessor." | ||
raise RuntimeError(msg) from error | ||
|
||
cache[self._name] = accessor_obj | ||
return accessor_obj # noqa: R504 | ||
|
||
|
||
def _register_accessor(name: str, cls: object) -> Callable: | ||
"""Register the accessor to the object.""" | ||
|
||
def decorator(accessor: object) -> object: | ||
if hasattr(cls, name): | ||
warnings.warn( | ||
f"registration of accessor {repr(accessor)} under name " | ||
f"{repr(name)} for type {repr(cls)} is overriding a preexisting " | ||
f"attribute with the same name.", | ||
UserWarning, | ||
stacklevel=2, | ||
) | ||
setattr(cls, name, CachedAccessor(name, accessor)) | ||
return accessor | ||
|
||
return decorator | ||
|
||
|
||
def register_mcs_accessor(name: str) -> Callable: | ||
"""Register the accessor to object.""" | ||
return _register_accessor(name, MonteCarloSamples) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
"""Autocorrelation diagnostic tool for a Bean Machine model.""" | ||
from __future__ import annotations | ||
|
||
from typing import Any, TypeVar | ||
|
||
import arviz as az | ||
|
||
import beanmachine.ppl.diagnostics.tools.helpers.autocorrelation as tool | ||
from bokeh.models.callbacks import CustomJS | ||
from bokeh.plotting import show | ||
|
||
T = TypeVar("T", bound="Autocorrelation") | ||
|
||
|
||
class Autocorrelation: | ||
"""Autocorrelation diagnostic tool.""" | ||
|
||
def __init__(self: T, idata: az.InferenceData) -> None: | ||
"""Initialize.""" | ||
self.idata = idata | ||
self.rv_identifiers = list(self.idata["posterior"].data_vars) | ||
self.rv_names = sorted( | ||
[str(rv_identifier) for rv_identifier in self.rv_identifiers], | ||
) | ||
self.num_chains = self.idata["posterior"].dims["chain"] | ||
self.num_draws = self.idata["posterior"].dims["draw"] | ||
|
||
def modify_doc(self: T, doc: Any) -> None: | ||
"""Modify the Jupyter document in order to display the tool.""" | ||
# Initialize the widgets. | ||
rv_name = self.rv_names[0] | ||
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)] | ||
|
||
# Compute the initial data displayed in the tool. | ||
rv_data = self.idata["posterior"][rv_identifier].values | ||
computed_data = tool.compute_data(rv_data) | ||
|
||
# Create the Bokeh source(s). | ||
sources = tool.create_sources(computed_data) | ||
|
||
# Create the figure(s). | ||
figures = tool.create_figures(self.num_chains) | ||
|
||
# Create the glyph(s) and attach them to the figure(s). | ||
glyphs = tool.create_glyphs(self.num_chains) | ||
tool.add_glyphs(figures, glyphs, sources) | ||
|
||
# Create the annotation(s) and attache them to the figure(s). | ||
annotations = tool.create_annotations(computed_data) | ||
tool.add_annotations(figures, annotations) | ||
|
||
# Create the tool tip(s) and attach them to the figure(s). | ||
tooltips = tool.create_tooltips(figures) | ||
tool.add_tooltips(figures, tooltips) | ||
|
||
# Create the widget(s) for the tool. | ||
widgets = tool.create_widgets(rv_name, self.rv_names, self.num_draws) | ||
|
||
# Create the callback(s) for the widget(s). | ||
def update_rv_select(attr: Any, old: str, new: str) -> None: | ||
rv_name = new | ||
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)] | ||
rv_data = self.idata["posterior"][rv_identifier].values | ||
tool.update(rv_data, sources) | ||
end = 10 if self.num_draws <= 2 * 100 else 100 | ||
widgets["range_slider"].value = (0, end) | ||
|
||
def update_range_slider( | ||
attr: Any, | ||
old: tuple[int, int], | ||
new: tuple[int, int], | ||
) -> None: | ||
fig = figures[list(figures.keys())[0]] | ||
fig.x_range.start, fig.x_range.end = new | ||
|
||
widgets["rv_select"].on_change("value", update_rv_select) | ||
# NOTE: We are using Bokeh's CustomJS model in order to reset the ranges of the | ||
# figures. | ||
widgets["rv_select"].js_on_change( | ||
"value", | ||
CustomJS(args={"p": list(figures.values())[0]}, code="p.reset.emit()"), | ||
) | ||
widgets["range_slider"].on_change("value", update_range_slider) | ||
|
||
tool_view = tool.create_view(widgets, figures) | ||
doc.add_root(tool_view) | ||
|
||
def show_tool(self: T) -> None: | ||
"""Show the diagnostic tool. | ||
Returns | ||
------- | ||
None | ||
Directly displays the tool in Jupyter. | ||
""" | ||
show(self.modify_doc) |
87 changes: 87 additions & 0 deletions
87
src/beanmachine/ppl/diagnostics/tools/effective_sample_size.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
"""Effective Sample Size (ESS) diagnostic tool for a Bean Machine model.""" | ||
from __future__ import annotations | ||
|
||
from typing import Any, TypeVar | ||
|
||
import arviz as az | ||
|
||
import beanmachine.ppl.diagnostics.tools.helpers.effective_sample_size as tool | ||
from bokeh.core.enums import LegendClickPolicy | ||
from bokeh.models.callbacks import CustomJS | ||
from bokeh.plotting import show | ||
|
||
|
||
T = TypeVar("T", bound="EffectiveSampleSize") | ||
|
||
|
||
class EffectiveSampleSize: | ||
"""Effective Sample Size (ESS) diagnostic tool.""" | ||
|
||
def __init__(self: T, idata: az.InferenceData) -> None: | ||
"""Initialize.""" | ||
self.idata = idata | ||
self.rv_identifiers = list(self.idata["posterior"].data_vars) | ||
self.rv_names = sorted( | ||
[str(rv_identifier) for rv_identifier in self.rv_identifiers], | ||
) | ||
self.num_chains = self.idata["posterior"].dims["chain"] | ||
|
||
def modify_doc(self: T, doc: Any) -> None: | ||
"""Modify the Jupyter document in order to display the tool.""" | ||
# Initialize the widgets. | ||
rv_name = self.rv_names[0] | ||
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)] | ||
|
||
# Compute the initial data displayed in the tool. | ||
rv_data = self.idata["posterior"][rv_identifier].values | ||
computed_data = tool.compute_data(rv_data) | ||
|
||
# Create the Bokeh source(s). | ||
sources = tool.create_sources(computed_data) | ||
|
||
# Create the figure(s). | ||
figures = tool.create_figures() | ||
|
||
# Create the glyph(s) and attach them to the figure(s). | ||
glyphs = tool.create_glyphs() | ||
tool.add_glyphs(figures, glyphs, sources) | ||
|
||
# Create the annotation(s) and attache them to the figure(s). | ||
annotations = tool.create_annotations(figures) | ||
annotations["ess"]["legend"].click_policy = LegendClickPolicy.hide | ||
tool.add_annotations(figures, annotations) | ||
|
||
# Create the tool tip(s) and attach them to the figure(s). | ||
tooltips = tool.create_tooltips(figures) | ||
tool.add_tooltips(figures, tooltips) | ||
|
||
# Create the widget(s) for the tool. | ||
widgets = tool.create_widgets(rv_name, self.rv_names) | ||
|
||
# Create the callback(s) for the widget(s). | ||
def update_rv_select(attr: Any, old: str, new: str) -> None: | ||
rv_name = new | ||
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)] | ||
rv_data = self.idata["posterior"][rv_identifier].values | ||
tool.update(rv_data, sources) | ||
|
||
widgets["rv_select"].on_change("value", update_rv_select) | ||
# NOTE: We are using Bokeh's CustomJS model in order to reset the ranges of the | ||
# figures. | ||
widgets["rv_select"].js_on_change( | ||
"value", | ||
CustomJS(args={"p": list(figures.values())[0]}, code="p.reset.emit()"), | ||
) | ||
|
||
tool_view = tool.create_view(widgets, figures) | ||
doc.add_root(tool_view) | ||
|
||
def show_tool(self: T) -> None: | ||
"""Show the diagnostic tool. | ||
Returns | ||
------- | ||
None | ||
Directly displays the tool in Jupyter. | ||
""" | ||
show(self.modify_doc) |
Empty file.
Oops, something went wrong.