diff --git a/src/beanmachine/ppl/diagnostics/tools/__init__.py b/src/beanmachine/ppl/diagnostics/tools/__init__.py index 9cd2d21afd..43cad634da 100644 --- a/src/beanmachine/ppl/diagnostics/tools/__init__.py +++ b/src/beanmachine/ppl/diagnostics/tools/__init__.py @@ -6,7 +6,6 @@ # flake8: noqa """Visual diagnostic tools for Bean Machine models.""" - import sys from pathlib import Path @@ -16,6 +15,7 @@ # accepted, see https://peps.python.org/pep-0655/. This is to follow the # interface objects in JavaScript that allow keys to not be required using ?. from typing import TypedDict + from typing_extensions import NotRequired else: from typing_extensions import NotRequired, TypedDict diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/callbacks.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/callbacks.ts index e6c002fbd6..d4c85f5e28 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/callbacks.ts +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/callbacks.ts @@ -72,7 +72,7 @@ export const computeStats = ( // Compute the point statistics for the KDE, and create labels to display them in the // figures. - const mean = computeMean(marginalX); + const mean = computeMean(rawData); const hdiBounds = hdiInterval(rawData, hdiProbability); const x = [hdiBounds.lowerBound, mean, hdiBounds.upperBound]; const y = interpolatePoints({x: marginalX, y: marginalY, points: x}); diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/histogram.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/histogram.ts index 394a632c24..40a4868571 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/histogram.ts +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/histogram.ts @@ -1,4 +1,10 @@ -/* import {calculateHistogram} from 'compute-histogram'; */ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + import {linearRange, numericalSort, shape} from './array'; import {rankData, scaleToOne} from './dataTransformation'; import {mean as computeMean} from './pointStatistic'; diff --git a/src/beanmachine/ppl/diagnostics/tools/js/yarn.lock b/src/beanmachine/ppl/diagnostics/tools/js/yarn.lock index a6ef03786a..7f512032f9 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/yarn.lock +++ b/src/beanmachine/ppl/diagnostics/tools/js/yarn.lock @@ -106,7 +106,7 @@ "@jridgewell/sourcemap-codec" "^1.4.10" "@jridgewell/trace-mapping" "^0.3.9" -"@jridgewell/resolve-uri@3.1.0", "@jridgewell/resolve-uri@^3.0.3": +"@jridgewell/resolve-uri@3.1.0": version "3.1.0" resolved "https://registry.yarnpkg.com/@jridgewell/resolve-uri/-/resolve-uri-3.1.0.tgz#2203b118c157721addfe69d47b70465463066d78" integrity sha512-F2msla3tad+Mfht5cJq7LSXcdudKTWCVYUgw6pLFOOHSTtZlj6SWNYAp+AhuqLmWdBO2X5hPrLcu8cVP8fy28w== @@ -201,7 +201,7 @@ resolved "https://registry.yarnpkg.com/@types/json5/-/json5-0.0.29.tgz#ee28707ae94e11d2b827bcbe5270bcea7f3e71ee" integrity sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ== -"@types/node@*", "@types/node@^18.0.4": +"@types/node@*": version "18.8.5" resolved "https://registry.yarnpkg.com/@types/node/-/node-18.8.5.tgz#6a31f820c1077c3f8ce44f9e203e68a176e8f59e" integrity sha512-Bq7G3AErwe5A/Zki5fdD3O6+0zDChhg671NfPjtIcbtzDNZTv4NPKMRFr7gtYPG7y+B8uTiNK4Ngd9T0FTar6Q== @@ -218,20 +218,6 @@ dependencies: "@types/jquery" "*" -"@typescript-eslint/eslint-plugin@^5.30.5": - version "5.40.0" - resolved "https://registry.yarnpkg.com/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.40.0.tgz#0159bb71410eec563968288a17bd4478cdb685bd" - integrity sha512-FIBZgS3DVJgqPwJzvZTuH4HNsZhHMa9SjxTKAZTlMsPw/UzpEjcf9f4dfgDJEHjK+HboUJo123Eshl6niwEm/Q== - dependencies: - "@typescript-eslint/scope-manager" "5.40.0" - "@typescript-eslint/type-utils" "5.40.0" - "@typescript-eslint/utils" "5.40.0" - debug "^4.3.4" - ignore "^5.2.0" - regexpp "^3.2.0" - semver "^7.3.7" - tsutils "^3.21.0" - "@typescript-eslint/parser@^5.30.5": version "5.40.0" resolved "https://registry.yarnpkg.com/@typescript-eslint/parser/-/parser-5.40.0.tgz#432bddc1fe9154945660f67c1ba6d44de5014840" @@ -250,16 +236,6 @@ "@typescript-eslint/types" "5.40.0" "@typescript-eslint/visitor-keys" "5.40.0" -"@typescript-eslint/type-utils@5.40.0": - version "5.40.0" - resolved "https://registry.yarnpkg.com/@typescript-eslint/type-utils/-/type-utils-5.40.0.tgz#4964099d0158355e72d67a370249d7fc03331126" - integrity sha512-nfuSdKEZY2TpnPz5covjJqav+g5qeBqwSHKBvz7Vm1SAfy93SwKk/JeSTymruDGItTwNijSsno5LhOHRS1pcfw== - dependencies: - "@typescript-eslint/typescript-estree" "5.40.0" - "@typescript-eslint/utils" "5.40.0" - debug "^4.3.4" - tsutils "^3.21.0" - "@typescript-eslint/types@5.40.0": version "5.40.0" resolved "https://registry.yarnpkg.com/@typescript-eslint/types/-/types-5.40.0.tgz#8de07e118a10b8f63c99e174a3860f75608c822e" @@ -278,19 +254,6 @@ semver "^7.3.7" tsutils "^3.21.0" -"@typescript-eslint/utils@5.40.0": - version "5.40.0" - resolved "https://registry.yarnpkg.com/@typescript-eslint/utils/-/utils-5.40.0.tgz#647f56a875fd09d33c6abd70913c3dd50759b772" - integrity sha512-MO0y3T5BQ5+tkkuYZJBjePewsY+cQnfkYeRqS6tPh28niiIwPnQ1t59CSRcs1ZwJJNOdWw7rv9pF8aP58IMihA== - dependencies: - "@types/json-schema" "^7.0.9" - "@typescript-eslint/scope-manager" "5.40.0" - "@typescript-eslint/types" "5.40.0" - "@typescript-eslint/typescript-estree" "5.40.0" - eslint-scope "^5.1.1" - eslint-utils "^3.0.0" - semver "^7.3.7" - "@typescript-eslint/visitor-keys@5.40.0": version "5.40.0" resolved "https://registry.yarnpkg.com/@typescript-eslint/visitor-keys/-/visitor-keys-5.40.0.tgz#dd2d38097f68e0d2e1e06cb9f73c0173aca54b68" @@ -684,11 +647,6 @@ core-js-pure@^3.25.1: resolved "https://registry.yarnpkg.com/core-js-pure/-/core-js-pure-3.25.5.tgz#79716ba54240c6aa9ceba6eee08cf79471ba184d" integrity sha512-oml3M22pHM+igfWHDfdLVq2ShWmjM2V4L+dQEBs0DWVIqEm9WHCwGAlZ6BmyBQGy5sFrJmcx+856D9lVKyGWYg== -create-require@^1.1.0: - version "1.1.1" - resolved "https://registry.yarnpkg.com/create-require/-/create-require-1.1.1.tgz#c1d7e8f1e5f6cfc9ff65f9cd352d37348756c333" - integrity sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ== - cross-spawn@^7.0.2, cross-spawn@^7.0.3: version "7.0.3" resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.3.tgz#f73a85b9d5d41d045551c177e2882d4ac85728a6" @@ -786,11 +744,6 @@ emoji-regex@^9.2.2: resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-9.2.2.tgz#840c8803b0d8047f4ff0cf963176b32d4ef3ed72" integrity sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg== -emoji-regex@^9.2.2: - version "9.2.2" - resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-9.2.2.tgz#840c8803b0d8047f4ff0cf963176b32d4ef3ed72" - integrity sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg== - enhanced-resolve@^5.0.0, enhanced-resolve@^5.10.0: version "5.10.0" resolved "https://registry.yarnpkg.com/enhanced-resolve/-/enhanced-resolve-5.10.0.tgz#0dc579c3bb2a1032e357ac45b8f3a6f3ad4fb1e6" @@ -1630,11 +1583,6 @@ js-sdsl@^4.1.4: resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499" integrity sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ== -"js-tokens@^3.0.0 || ^4.0.0": - version "4.0.0" - resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499" - integrity sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ== - js-yaml@^4.1.0: version "4.1.0" resolved "https://registry.yarnpkg.com/js-yaml/-/js-yaml-4.1.0.tgz#c1fb65f8f5017901cdd2c951864ba18458a10602" @@ -2008,11 +1956,6 @@ prelude-ls@^1.2.1: resolved "https://registry.yarnpkg.com/prelude-ls/-/prelude-ls-1.2.1.tgz#debc6489d7a6e6b0e7611888cec880337d316396" integrity sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g== -prettier@^2.7.1: - version "2.7.1" - resolved "https://registry.yarnpkg.com/prettier/-/prettier-2.7.1.tgz#e235806850d057f97bb08368a4f7d899f7760c64" - integrity sha512-ujppO+MkdPqoVINuDFDRLClm7D78qbDt0/NR+wp5FqEZOoTNAjPHWj17QRhu7geIHJfcNhRk1XVQmF8Bp3ye+g== - proj4@^2.7.5: version "2.8.0" resolved "https://registry.yarnpkg.com/proj4/-/proj4-2.8.0.tgz#b2cb8f3ccd56d4dcc7c3e46155cd02caa804b170" diff --git a/src/beanmachine/ppl/diagnostics/tools/marginal1d/tool.py b/src/beanmachine/ppl/diagnostics/tools/marginal1d/tool.py index e6ce876e1c..1e122cba87 100644 --- a/src/beanmachine/ppl/diagnostics/tools/marginal1d/tool.py +++ b/src/beanmachine/ppl/diagnostics/tools/marginal1d/tool.py @@ -4,8 +4,7 @@ # LICENSE file in the root directory of this source tree. """Marginal 1D diagnostic tool for a Bean Machine model.""" - -from typing import TypeVar +from __future__ import annotations from beanmachine.ppl.diagnostics.tools.marginal1d import utils from beanmachine.ppl.diagnostics.tools.utils.diagnostic_tool_base import ( @@ -16,9 +15,6 @@ from bokeh.models.callbacks import CustomJS -T = TypeVar("T", bound="Marginal1d") - - class Marginal1d(DiagnosticToolBaseClass): """ Marginal 1D diagnostic tool. @@ -40,10 +36,10 @@ class Marginal1d(DiagnosticToolBaseClass): independently from a Python server. """ - def __init__(self: T, mcs: MonteCarloSamples) -> None: + def __init__(self: Marginal1d, mcs: MonteCarloSamples) -> None: super(Marginal1d, self).__init__(mcs) - def create_document(self: T) -> Model: + def create_document(self: Marginal1d) -> Model: # Initialize widget values using Python. rv_name = self.rv_names[0] bw_factor = 1.0 diff --git a/src/beanmachine/ppl/diagnostics/tools/marginal1d/typing.py b/src/beanmachine/ppl/diagnostics/tools/marginal1d/typing.py index 881cd308e4..8796cc70e3 100644 --- a/src/beanmachine/ppl/diagnostics/tools/marginal1d/typing.py +++ b/src/beanmachine/ppl/diagnostics/tools/marginal1d/typing.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. """Marginal 1D diagnostic tool types for a Bean Machine model.""" - from typing import Any, Dict, List, Union from beanmachine.ppl.diagnostics.tools import TypedDict diff --git a/src/beanmachine/ppl/diagnostics/tools/marginal1d/utils.py b/src/beanmachine/ppl/diagnostics/tools/marginal1d/utils.py index 6ef4dfebf7..ff148175c3 100644 --- a/src/beanmachine/ppl/diagnostics/tools/marginal1d/utils.py +++ b/src/beanmachine/ppl/diagnostics/tools/marginal1d/utils.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. """Methods used to generate the diagnostic tool.""" - from typing import List import numpy as np diff --git a/src/beanmachine/ppl/diagnostics/tools/trace/tool.py b/src/beanmachine/ppl/diagnostics/tools/trace/tool.py index 7c9555e0b7..710acc38b5 100644 --- a/src/beanmachine/ppl/diagnostics/tools/trace/tool.py +++ b/src/beanmachine/ppl/diagnostics/tools/trace/tool.py @@ -4,8 +4,7 @@ # LICENSE file in the root directory of this source tree. """Trace diagnostic tool for a Bean Machine model.""" - -from typing import TypeVar +from __future__ import annotations from beanmachine.ppl.diagnostics.tools.trace import utils from beanmachine.ppl.diagnostics.tools.utils.diagnostic_tool_base import ( @@ -16,9 +15,6 @@ from bokeh.models.callbacks import CustomJS -T = TypeVar("T", bound="Trace") - - class Trace(DiagnosticToolBaseClass): """Trace tool. @@ -39,10 +35,10 @@ class Trace(DiagnosticToolBaseClass): independently from a Python server. """ - def __init__(self: T, mcs: MonteCarloSamples) -> None: + def __init__(self: Trace, mcs: MonteCarloSamples) -> None: super(Trace, self).__init__(mcs) - def create_document(self: T) -> Model: + def create_document(self: Trace) -> Model: # Initialize widget values using Python. rv_name = self.rv_names[0] @@ -88,10 +84,22 @@ def create_document(self: T) -> Model: # Create the widgets for the tool using Python. widgets = utils.create_widgets(rv_names=self.rv_names, rv_name=rv_name) + # Create the view of the tool and serialize it into HTML using static resources + # from Bokeh. Embedding the tool in this manner prevents external CDN calls for + # JavaScript resources, and prevents the user from having to know where the + # Bokeh server is. + tool_view = utils.create_view(figures=figures, widgets=widgets) + # Create callbacks for the tool using JavaScript. callback_js = f""" const rvName = widgets.rv_select.value; const rvData = data[rvName]; + let bw = 0.0; + // Remove the CSS classes that dim the tool output on initial load. + const toolTab = toolView.tabs[0]; + const toolChildren = toolTab.child.children; + const dimmedComponent = toolChildren[1]; + dimmedComponent.css_classes = []; try {{ trace.update( rvData, @@ -125,6 +133,7 @@ def create_document(self: T) -> Model: "sources": sources, "figures": figures, "tooltips": tooltips, + "toolView": tool_view, } # Each widget requires slightly different JS. @@ -149,5 +158,4 @@ def create_document(self: T) -> Model: widgets["bw_factor_slider"].js_on_change("value", slider_callback) widgets["hdi_slider"].js_on_change("value", slider_callback) - tool_view = utils.create_view(figures=figures, widgets=widgets) return tool_view diff --git a/src/beanmachine/ppl/diagnostics/tools/trace/typing.py b/src/beanmachine/ppl/diagnostics/tools/trace/typing.py index 85d9ab5f14..1841bc7c70 100644 --- a/src/beanmachine/ppl/diagnostics/tools/trace/typing.py +++ b/src/beanmachine/ppl/diagnostics/tools/trace/typing.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. """Trace diagnostic tool types for a Bean Machine model.""" - from typing import Any, Dict, List, Union from beanmachine.ppl.diagnostics.tools import NotRequired, TypedDict diff --git a/src/beanmachine/ppl/diagnostics/tools/trace/utils.py b/src/beanmachine/ppl/diagnostics/tools/trace/utils.py index 58f95c670a..df99a5cb0a 100644 --- a/src/beanmachine/ppl/diagnostics/tools/trace/utils.py +++ b/src/beanmachine/ppl/diagnostics/tools/trace/utils.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. """Methods used to generate the diagnostic tool.""" - from typing import List from beanmachine.ppl.diagnostics.tools.trace import typing @@ -201,7 +200,6 @@ def create_figures(rv_name: str, num_chains: int) -> typing.Figures: if figure_name == "marginals": fig.title = "Marginal" fig.xaxis.axis_label = rv_name - # fig.x_range = Range1d() fig.yaxis.visible = False elif figure_name == "forests": fig.title = "Forest" @@ -209,13 +207,11 @@ def create_figures(rv_name: str, num_chains: int) -> typing.Figures: fig.yaxis.axis_label = "Chain" fig.yaxis.minor_tick_line_color = None fig.yaxis.ticker.desired_num_ticks = num_chains - # fig.x_range = Range1d() elif figure_name == "traces": fig.title = "Trace" fig.xaxis.axis_label = "Draw from single chain" fig.yaxis.axis_label = rv_name fig.width = TRACE_PLOT_WIDTH - # fig.x_range = Range1d() elif figure_name == "ranks": fig.title = "Rank" fig.xaxis.axis_label = "Rank from all chains" @@ -399,8 +395,6 @@ def add_glyphs( None Adds data bound glyphs to the given figures directly. """ - # range_min = [] - # range_max = [] for figure_name, figure_sources in sources.items(): fig = figures[figure_name] for chain_name, source in figure_sources.items(): @@ -417,11 +411,6 @@ def add_glyphs( # its range stable are linked to the marginal figure's range below. if figure_name == "marginals": pass - # data = source["line"].data["x"] - # minimum = min(data) if len(data) != 0 else 0 - # maximum = max(data) if len(data) != 0 else 1 - # range_min.append(minimum) - # range_max.append(maximum) elif figure_name == "forests": fig.add_glyph( source_or_glyph=source["circle"], @@ -437,12 +426,7 @@ def add_glyphs( name=chain_glyphs["quad"]["glyph"].name, ) # Link figure ranges together. - # figures["marginals"].x_range = Range1d( - # start=min(range_min) if len(range_min) != 0 else 0, - # end=max(range_max) if len(range_max) != 0 else 1, - # ) figures["forests"].x_range = figures["marginals"].x_range - # figures["traces"].y_range = figures["marginals"].x_range def create_annotations(figures: typing.Figures, num_chains: int) -> typing.Annotations: @@ -573,13 +557,15 @@ def create_tooltips( { "line": HoverTool( renderers=plotting_utils.filter_renderers( - fig, f"{figure_name}{chain_name.title()}LineGlyph" + fig, + f"{figure_name}{chain_name.title()}LineGlyph", ), tooltips=[("Chain", "@chain"), ("Rank mean", "@rankMean")], ), "quad": HoverTool( renderers=plotting_utils.filter_renderers( - fig, f"{figure_name}{chain_name.title()}QuadGlyph" + fig, + f"{figure_name}{chain_name.title()}QuadGlyph", ), tooltips=[ ("Chain", "@chain"), @@ -587,7 +573,7 @@ def create_tooltips( ("Rank", "@rank"), ], ), - } + }, ) return output diff --git a/src/beanmachine/ppl/diagnostics/tools/utils/accessor.py b/src/beanmachine/ppl/diagnostics/tools/utils/accessor.py index e4e23f6bc6..9146b3c8aa 100644 --- a/src/beanmachine/ppl/diagnostics/tools/utils/accessor.py +++ b/src/beanmachine/ppl/diagnostics/tools/utils/accessor.py @@ -10,17 +10,15 @@ - `pandas`: https://github.com/pandas-dev/pandas/blob/main/pandas/core/accessor.py - `xarray`: https://github.com/pydata/xarray/blob/main/xarray/core/extensions.py """ +from __future__ import annotations import contextlib import warnings -from typing import Callable, TypeVar +from typing import Callable from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples -T = TypeVar("T", bound="CachedAccessor") - - class CachedAccessor: """ A descriptor for caching accessors. @@ -38,11 +36,11 @@ class CachedAccessor: object. """ - def __init__(self: T, name: str, accessor: object) -> None: + def __init__(self: CachedAccessor, name: str, accessor: object) -> None: self._name = name self._accessor = accessor - def __get__(self: T, obj: object, cls: object) -> object: + def __get__(self: CachedAccessor, obj: object, cls: object) -> object: """ Method to retrieve the accessor namespace. @@ -123,7 +121,8 @@ def register_mcs_accessor(name: str) -> Callable: object. Example: - >>> from typing import Dict, List, TypeVar + >>> from __future__ import annotations + >>> from typing import Dict, List >>> >>> import beanmachine.ppl as bm >>> import numpy as np @@ -132,8 +131,6 @@ def register_mcs_accessor(name: str) -> Callable: >>> from beanmachine.ppl.diagnostics.tools.utils import accessor >>> from torch import tensor >>> - >>> T = TypeVar("T", bound="MagicAccessor") - >>> >>> @bm.random_variable >>> def alpha(): >>> return dist.Normal(0, 1) @@ -144,9 +141,9 @@ def register_mcs_accessor(name: str) -> Callable: >>> >>> @accessor.register_mcs_accessor("magic") >>> class MagicAccessor: - >>> def __init__(self: T, mcs: MonteCarloSamples) -> None: + >>> def __init__(self: MagicAccessor, mcs: MonteCarloSamples) -> None: >>> self.mcs = mcs - >>> def show_me(self: T) -> Dict[str, List[List[float]]]: + >>> def show_me(self: MagicAccessor) -> Dict[str, List[List[float]]]: >>> # Return a JSON serializable object from a MonteCarloSamples object. >>> return dict( >>> sorted( diff --git a/src/beanmachine/ppl/diagnostics/tools/utils/diagnostic_tool_base.py b/src/beanmachine/ppl/diagnostics/tools/utils/diagnostic_tool_base.py index d423bd06b2..b4c4b79fe1 100644 --- a/src/beanmachine/ppl/diagnostics/tools/utils/diagnostic_tool_base.py +++ b/src/beanmachine/ppl/diagnostics/tools/utils/diagnostic_tool_base.py @@ -4,10 +4,11 @@ # LICENSE file in the root directory of this source tree. """Base class for diagnostic tools of a Bean Machine model.""" +from __future__ import annotations import re from abc import ABC, abstractmethod -from typing import Any, Mapping, TypeVar +from typing import Any, Mapping from beanmachine.ppl.diagnostics.tools import JS_DIST_DIR from beanmachine.ppl.diagnostics.tools.utils import plotting_utils @@ -18,9 +19,6 @@ from bokeh.resources import INLINE -T = TypeVar("T", bound="DiagnosticToolBaseClass") - - class DiagnosticToolBaseClass(ABC): """ Base class for visual diagnostic tools. @@ -43,7 +41,7 @@ class DiagnosticToolBaseClass(ABC): """ @abstractmethod - def __init__(self: T, mcs: MonteCarloSamples) -> None: + def __init__(self: DiagnosticToolBaseClass, mcs: MonteCarloSamples) -> None: self.data = serialize_bm(mcs) self.rv_names = ["Select a random variable..."] + list(self.data.keys()) self.num_chains = mcs.num_chains @@ -51,7 +49,7 @@ def __init__(self: T, mcs: MonteCarloSamples) -> None: self.palette = plotting_utils.choose_palette(self.num_chains) self.tool_js = self.load_tool_js() - def load_tool_js(self: T) -> str: + def load_tool_js(self: DiagnosticToolBaseClass) -> str: """ Load the JavaScript for the diagnostic tool. @@ -75,7 +73,7 @@ def load_tool_js(self: T) -> str: tool_js = f.read() return tool_js - def show(self: T) -> None: + def show(self: DiagnosticToolBaseClass) -> None: """ Show the diagnostic tool in the notebook. @@ -97,7 +95,7 @@ def show(self: T) -> None: html = file_html(doc, resources=INLINE, template=self.html_template()) display(HTML(html)) - def html_template(self: T) -> str: + def html_template(self: DiagnosticToolBaseClass) -> str: """ HTML template object used to inject CSS styles for Bokeh Applications. @@ -145,11 +143,11 @@ def html_template(self: T) -> str: """ @abstractmethod - def create_document(self: T) -> Model: + def create_document(self: DiagnosticToolBaseClass) -> Model: """To be implemented by the inheriting class.""" ... - def _tool_json(self: T) -> Mapping[Any, Any]: + def _tool_json(self: DiagnosticToolBaseClass) -> Mapping[Any, Any]: """ Debugging method used primarily when creating a new diagnostic tool. diff --git a/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py b/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py index 58e96dd3a6..27f1d1f8be 100644 --- a/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py +++ b/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. """Collection of serializers for the diagnostics tool use.""" - from typing import Dict, List from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples @@ -25,6 +24,7 @@ def serialize_bm(samples: MonteCarloSamples) -> Dict[str, List[List[float]]]: rv_data = samples[rv_identifier] rv_shape = rv_data.shape num_rv_chains = rv_shape[0] + reshaped_data[f"{str(rv_identifier)}"] = [] for rv_chain in range(num_rv_chains): chain_data = rv_data[rv_chain, :] chain_shape = chain_data.shape @@ -40,8 +40,12 @@ def serialize_bm(samples: MonteCarloSamples) -> Dict[str, List[List[float]]]: reshape_dimensions = chain_shape[1] for i, reshape_dimension in enumerate(range(reshape_dimensions)): data = rv_data[rv_chain, :, reshape_dimension].reshape(-1) - reshaped_data[f"{str(rv_identifier)}[{i}]"] = data.tolist() + if f"{str(rv_identifier)}[{i}]" not in reshaped_data: + reshaped_data[f"{str(rv_identifier)}[{i}]"] = [] + reshaped_data[f"{str(rv_identifier)}[{i}]"].append(data.tolist()) elif len(chain_shape) == 1: - reshaped_data[f"{str(rv_identifier)}"] = rv_data[rv_chain, :].tolist() + reshaped_data[f"{str(rv_identifier)}"].append( + rv_data[rv_chain, :].tolist(), + ) model = dict(sorted(reshaped_data.items(), key=lambda item: item[0])) return model diff --git a/src/beanmachine/ppl/diagnostics/tools/utils/plotting_utils.py b/src/beanmachine/ppl/diagnostics/tools/utils/plotting_utils.py index 82b48f2a1f..c116ef778a 100644 --- a/src/beanmachine/ppl/diagnostics/tools/utils/plotting_utils.py +++ b/src/beanmachine/ppl/diagnostics/tools/utils/plotting_utils.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. """Plotting utilities for the diagnostic tools.""" - from typing import List from bokeh.core.property.nullable import Nullable diff --git a/src/beanmachine/ppl/diagnostics/tools/viz.py b/src/beanmachine/ppl/diagnostics/tools/viz.py index d92fcb8f4d..c149ba9d7e 100644 --- a/src/beanmachine/ppl/diagnostics/tools/viz.py +++ b/src/beanmachine/ppl/diagnostics/tools/viz.py @@ -49,21 +49,24 @@ def __init__(self: DiagnosticsTools, mcs: MonteCarloSamples) -> None: @_requires_dev_packages def marginal1d(self: DiagnosticsTools) -> None: - """Marginal 1D tool.""" + """ + Marginal 1D diagnostic tool for a Bean Machine model. + + Returns: + None: Displays the tool directly in a Jupyter notebook. + """ from beanmachine.ppl.diagnostics.tools.marginal1d.tool import Marginal1d Marginal1d(self.mcs).show() - def trace(self: T, name: Optional[str] = None) -> None: + @_requires_dev_packages + def trace(self: DiagnosticsTools) -> None: """ Trace diagnostic tool for a Bean Machine model. - Args: - name (:obj:`str`, optional): Optional name for the tool. This is used to - persist data as JSON to disk when converting a Jupyter notebook to an - MDX file. - Returns: None: Displays the tool directly in a Jupyter notebook. """ - Trace(self.mcs).show(name=name) + from beanmachine.ppl.diagnostics.tools.trace.tool import Trace + + Trace(self.mcs).show()