This repository has been archived by the owner on Dec 18, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 49
Add diagnostic visualization tools #1631
Open
ndmlny-qs
wants to merge
8
commits into
main
Choose a base branch
from
issue1490/diagnostic-tools
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
168794f
Add diagnostic visualization tools
ndmlny-qs f71b069
Revert typing
ndmlny-qs 58799d6
Add copyright notice
ndmlny-qs 51c1719
Fix linting errors
ndmlny-qs 421fe45
Fix bug associated with reverted types
ndmlny-qs 17e95b7
Rebase and fix more type bugs
ndmlny-qs 3d4aa71
Make flake8 ignore py37 TypedDict unused import
ndmlny-qs 5cb40e3
Merge branch 'main' into issue1490/diagnostic-tools
ndmlny-qs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Visual diagnostic tools for Bean Machine models.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Accessor definition for extending Bean Machine `MonteCarloSamples` objects.""" | ||
import contextlib | ||
import warnings | ||
from typing import Callable, TypeVar | ||
|
||
from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples | ||
|
||
|
||
T = TypeVar("T", bound="CachedAccessor") | ||
|
||
|
||
class CachedAccessor: | ||
"""A descriptor for caching accessors. | ||
|
||
Parameters | ||
---------- | ||
name : str | ||
Namespace that will be accessed under, e.g. ``samples.accessor_name``. | ||
accessor : cls | ||
Class with the extension methods. | ||
""" | ||
|
||
def __init__(self: T, name: str, accessor: object) -> None: | ||
"""Initialize.""" | ||
self._name = name | ||
self._accessor = accessor | ||
|
||
def __get__(self: T, obj: object, cls: object) -> object: | ||
"""Access the accessor object.""" | ||
if obj is None: | ||
return self._accessor | ||
|
||
try: | ||
cache = obj._cache # type: ignore | ||
except AttributeError: | ||
cache = obj._cache = {} | ||
|
||
try: | ||
return cache[self._name] | ||
except KeyError: | ||
contextlib.suppress(KeyError) | ||
|
||
try: | ||
accessor_obj = self._accessor(obj) # type: ignore | ||
except Exception as error: | ||
msg = f"error initializing {self._name!r} accessor." | ||
raise RuntimeError(msg) from error | ||
|
||
cache[self._name] = accessor_obj | ||
return accessor_obj # noqa: R504 | ||
|
||
|
||
def _register_accessor(name: str, cls: object) -> Callable: | ||
"""Register the accessor to the object.""" | ||
|
||
def decorator(accessor: object) -> object: | ||
if hasattr(cls, name): | ||
warnings.warn( | ||
f"registration of accessor {repr(accessor)} under name " | ||
f"{repr(name)} for type {repr(cls)} is overriding a preexisting " | ||
f"attribute with the same name.", | ||
UserWarning, | ||
stacklevel=2, | ||
) | ||
setattr(cls, name, CachedAccessor(name, accessor)) | ||
return accessor | ||
|
||
return decorator | ||
|
||
|
||
def register_mcs_accessor(name: str) -> Callable: | ||
"""Register the accessor to object.""" | ||
return _register_accessor(name, MonteCarloSamples) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Autocorrelation diagnostic tool for a Bean Machine model.""" | ||
from typing import Any, Tuple, TypeVar | ||
|
||
import arviz as az | ||
|
||
import beanmachine.ppl.diagnostics.tools.helpers.autocorrelation as tool | ||
from bokeh.models.callbacks import CustomJS | ||
from bokeh.plotting import show | ||
|
||
T = TypeVar("T", bound="Autocorrelation") | ||
|
||
|
||
class Autocorrelation: | ||
"""Autocorrelation diagnostic tool.""" | ||
|
||
def __init__(self: T, idata: az.InferenceData) -> None: | ||
"""Initialize.""" | ||
self.idata = idata | ||
self.rv_identifiers = list(self.idata["posterior"].data_vars) | ||
self.rv_names = sorted( | ||
[str(rv_identifier) for rv_identifier in self.rv_identifiers], | ||
) | ||
self.num_chains = self.idata["posterior"].dims["chain"] | ||
self.num_draws = self.idata["posterior"].dims["draw"] | ||
|
||
def modify_doc(self: T, doc: Any) -> None: | ||
"""Modify the Jupyter document in order to display the tool.""" | ||
# Initialize the widgets. | ||
rv_name = self.rv_names[0] | ||
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)] | ||
|
||
# Compute the initial data displayed in the tool. | ||
rv_data = self.idata["posterior"][rv_identifier].values | ||
computed_data = tool.compute_data(rv_data) | ||
|
||
# Create the Bokeh source(s). | ||
sources = tool.create_sources(computed_data) | ||
|
||
# Create the figure(s). | ||
figures = tool.create_figures(self.num_chains) | ||
|
||
# Create the glyph(s) and attach them to the figure(s). | ||
glyphs = tool.create_glyphs(self.num_chains) | ||
tool.add_glyphs(figures, glyphs, sources) | ||
|
||
# Create the annotation(s) and attache them to the figure(s). | ||
annotations = tool.create_annotations(computed_data) | ||
tool.add_annotations(figures, annotations) | ||
|
||
# Create the tool tip(s) and attach them to the figure(s). | ||
tooltips = tool.create_tooltips(figures) | ||
tool.add_tooltips(figures, tooltips) | ||
|
||
# Create the widget(s) for the tool. | ||
widgets = tool.create_widgets(rv_name, self.rv_names, self.num_draws) | ||
|
||
# Create the callback(s) for the widget(s). | ||
def update_rv_select(attr: Any, old: str, new: str) -> None: | ||
rv_name = new | ||
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)] | ||
rv_data = self.idata["posterior"][rv_identifier].values | ||
tool.update(rv_data, sources) | ||
end = 10 if self.num_draws <= 2 * 100 else 100 | ||
widgets["range_slider"].value = (0, end) | ||
|
||
def update_range_slider( | ||
attr: Any, | ||
old: Tuple[int, int], | ||
new: Tuple[int, int], | ||
) -> None: | ||
fig = figures[list(figures.keys())[0]] | ||
fig.x_range.start, fig.x_range.end = new | ||
|
||
widgets["rv_select"].on_change("value", update_rv_select) | ||
# NOTE: We are using Bokeh's CustomJS model in order to reset the ranges of the | ||
# figures. | ||
widgets["rv_select"].js_on_change( | ||
"value", | ||
CustomJS(args={"p": list(figures.values())[0]}, code="p.reset.emit()"), | ||
) | ||
widgets["range_slider"].on_change("value", update_range_slider) | ||
|
||
tool_view = tool.create_view(widgets, figures) | ||
doc.add_root(tool_view) | ||
|
||
def show_tool(self: T) -> None: | ||
"""Show the diagnostic tool. | ||
|
||
Returns | ||
------- | ||
None | ||
Directly displays the tool in Jupyter. | ||
""" | ||
show(self.modify_doc) |
90 changes: 90 additions & 0 deletions
90
src/beanmachine/ppl/diagnostics/tools/effective_sample_size.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Effective Sample Size (ESS) diagnostic tool for a Bean Machine model.""" | ||
from typing import Any, TypeVar | ||
|
||
import arviz as az | ||
|
||
import beanmachine.ppl.diagnostics.tools.helpers.effective_sample_size as tool | ||
from bokeh.core.enums import LegendClickPolicy | ||
from bokeh.models.callbacks import CustomJS | ||
from bokeh.plotting import show | ||
|
||
|
||
T = TypeVar("T", bound="EffectiveSampleSize") | ||
|
||
|
||
class EffectiveSampleSize: | ||
"""Effective Sample Size (ESS) diagnostic tool.""" | ||
|
||
def __init__(self: T, idata: az.InferenceData) -> None: | ||
"""Initialize.""" | ||
self.idata = idata | ||
self.rv_identifiers = list(self.idata["posterior"].data_vars) | ||
self.rv_names = sorted( | ||
[str(rv_identifier) for rv_identifier in self.rv_identifiers], | ||
) | ||
self.num_chains = self.idata["posterior"].dims["chain"] | ||
|
||
def modify_doc(self: T, doc: Any) -> None: | ||
"""Modify the Jupyter document in order to display the tool.""" | ||
# Initialize the widgets. | ||
rv_name = self.rv_names[0] | ||
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)] | ||
|
||
# Compute the initial data displayed in the tool. | ||
rv_data = self.idata["posterior"][rv_identifier].values | ||
computed_data = tool.compute_data(rv_data) | ||
|
||
# Create the Bokeh source(s). | ||
sources = tool.create_sources(computed_data) | ||
|
||
# Create the figure(s). | ||
figures = tool.create_figures() | ||
|
||
# Create the glyph(s) and attach them to the figure(s). | ||
glyphs = tool.create_glyphs() | ||
tool.add_glyphs(figures, glyphs, sources) | ||
|
||
# Create the annotation(s) and attache them to the figure(s). | ||
annotations = tool.create_annotations(figures) | ||
annotations["ess"]["legend"].click_policy = LegendClickPolicy.hide | ||
tool.add_annotations(figures, annotations) | ||
|
||
# Create the tool tip(s) and attach them to the figure(s). | ||
tooltips = tool.create_tooltips(figures) | ||
tool.add_tooltips(figures, tooltips) | ||
|
||
# Create the widget(s) for the tool. | ||
widgets = tool.create_widgets(rv_name, self.rv_names) | ||
|
||
# Create the callback(s) for the widget(s). | ||
def update_rv_select(attr: Any, old: str, new: str) -> None: | ||
rv_name = new | ||
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)] | ||
rv_data = self.idata["posterior"][rv_identifier].values | ||
tool.update(rv_data, sources) | ||
|
||
widgets["rv_select"].on_change("value", update_rv_select) | ||
# NOTE: We are using Bokeh's CustomJS model in order to reset the ranges of the | ||
# figures. | ||
widgets["rv_select"].js_on_change( | ||
"value", | ||
CustomJS(args={"p": list(figures.values())[0]}, code="p.reset.emit()"), | ||
) | ||
|
||
tool_view = tool.create_view(widgets, figures) | ||
doc.add_root(tool_view) | ||
|
||
def show_tool(self: T) -> None: | ||
"""Show the diagnostic tool. | ||
|
||
Returns | ||
------- | ||
None | ||
Directly displays the tool in Jupyter. | ||
""" | ||
show(self.modify_doc) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Visual diagnostics tool methods.""" |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You'll need to add this to
__all__
so that the linter won't complain that this is imported but not use :)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks, I need to setup pre-commit to point to setup.cfg as that's where the flake8 config exists. Do you know of a tool that will add copyright notices for Python?